"container/vscode:/vscode.git/clone" did not exist on "e97493eb0065285c2775bfb5fcee7cd821f08842"
Unverified commit 5b24987f, authored by YuliangLiu0306, committed by GitHub
Browse files

[autoparallel] fix parameters sharding bug (#2716)

parent 2045d45a
...@@ -428,6 +428,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -428,6 +428,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
param = torch.nn.Parameter( param = torch.nn.Parameter(
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec, shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
target_sharding_spec).detach().clone()) target_sharding_spec).detach().clone())
return param
for node in nodes: for node in nodes:
if node.op == 'call_module': if node.op == 'call_module':
...@@ -438,7 +439,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -438,7 +439,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
setattr(target_module, 'processed', True) setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters(): for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name) target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
_shard_param(param, target_sharding_spec) param = _shard_param(param, target_sharding_spec)
setattr(target_module, name, param) setattr(target_module, name, param)
_add_hook_for_grad_communication(node, param) _add_hook_for_grad_communication(node, param)
...@@ -469,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes ...@@ -469,7 +470,7 @@ def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMes
target = getattr(target_module, atoms[-1]) target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec target_sharding_spec = node.sharding_spec
_shard_param(target, target_sharding_spec) target = _shard_param(target, target_sharding_spec)
assert hasattr(target_module, atoms[-1]) assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target) setattr(target_module, atoms[-1], target)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment