Unverified Commit 11f1e426 authored by littsk's avatar littsk Committed by GitHub
Browse files

[hotfix] Correct several erroneous code comments (#4794)

parent 54b3ad89
...@@ -50,7 +50,7 @@ class ModulePolicyDescription: ...@@ -50,7 +50,7 @@ class ModulePolicyDescription:
new_weight = shard_rowwise(weight, process_group) new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight) module.weight = torch.nn.Parameter(new_weight)
``` ```
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
object which specifies the module to be replaced and the target module used to replacement. object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
""" """
......
...@@ -92,7 +92,7 @@ class BucketStore(BaseStore): ...@@ -92,7 +92,7 @@ class BucketStore(BaseStore):
def get_flatten_grad(self) -> Tensor: def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
Returns: Returns:
Tensor: the flattened gradients slices in the bucket Tensor: the flattened gradients slices in the bucket
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment