model_chunk (ModuleList or Module): Model chunk to be trained. The original interleaved schedule uses a module list, whereas Shardformer uses the entire model plus a layer specification.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments, model outputs and inputs, and return a loss tensor.
optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
self.forward_only = not torch.is_grad_enabled()
if optimizer is None:
    assert self.forward_only, "Optimizer should be passed when doing backward."
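For context, here is a hedged usage sketch of this step function. The method name `forward_backward_step`, the `schedule` object, `dataloader`, and the HuggingFace-style `outputs.loss` access are assumptions for illustration; only the argument names and the returned dict keys come from the docstring above.

```python
import torch

# Criterion matching the documented signature: takes model outputs and the
# original inputs, and returns a loss tensor.
def criterion(outputs, inputs):
    return outputs.loss  # assumes HuggingFace-style model outputs

# Forward-only evaluation: grad is disabled, so the optimizer may be None.
with torch.no_grad():
    result = schedule.forward_backward_step(  # hypothetical method name
        model_chunk,
        iter(dataloader),
        criterion,
        optimizer=None,
        return_loss=True,
        return_outputs=True,
    )

# The returned dict has keys 'loss' and 'outputs'.
loss, outputs = result["loss"], result["outputs"]
```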
@@ -79,9 +79,9 @@ Following are the description `ShardConfig`'s arguments:
...
- `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
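As a hedged illustration of the arguments above, a minimal `ShardConfig` setup might look like the following. The import path and the field values are assumptions based on common Shardformer usage, not taken from this diff.

```python
from colossalai.shardformer import ShardConfig

# Flags below mirror the argument list above; values are illustrative.
shard_config = ShardConfig(
    enable_sequence_parallelism=True,  # must be True before enabling overlap
    enable_sequence_overlap=True,      # overlaps computation and communication
    enable_all_optimization=False,     # or True to turn on every tool at once
    extra_kwargs={},                   # extra kwargs passed through to ShardFormer
)
```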
### Write your own policy
...
@@ -116,17 +116,18 @@ We will follow this roadmap to develop Shardformer:
...