Unverified Commit 11f1e426 authored by littsk's avatar littsk Committed by GitHub
Browse files

[hotfix] Correct several erroneous code comments (#4794)

parent 54b3ad89
......@@ -50,7 +50,7 @@ class ModulePolicyDescription:
new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight)
```
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
......
......@@ -92,7 +92,7 @@ class BucketStore(BaseStore):
def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
Returns:
Tensor: the flattened gradients slices in the bucket
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment