Unverified commit 86c62cc9, authored by Eugen Hotaj and committed by GitHub
Browse files

Use correct node names for param counting in auto_shard. (#830)



Fixes #827.
Co-authored-by: Eugen Hotaj <ehotaj@fb.com>
parent eadfdc49
...@@ -52,8 +52,7 @@ def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3 ...@@ -52,8 +52,7 @@ def _split_nodes(traced_graph_module: torch.fx.GraphModule, shard_count: int = 3
# Find the total number of params in the model and # Find the total number of params in the model and
# the number of params per shard we are aiming for. # the number of params per shard we are aiming for.
for name, module in traced_graph_module.named_modules(): for name, module in traced_graph_module.named_modules():
if "." in name: name = name.replace(".", "_")
continue
param_count[name] = sum([x.numel() for x in module.parameters()]) param_count[name] = sum([x.numel() for x in module.parameters()])
logging.info(f"Total number of params are {param_count['']}") logging.info(f"Total number of params are {param_count['']}")
per_shard_param = param_count[""] // shard_count per_shard_param = param_count[""] // shard_count
......
...@@ -140,9 +140,7 @@ def test_dynaimc_conditionals_auto_wrapped(): ...@@ -140,9 +140,7 @@ def test_dynaimc_conditionals_auto_wrapped():
model = BranchedNetwork(features) model = BranchedNetwork(features)
sharded_model = shard_model(model, 3) sharded_model = shard_model(model, 3)
# TODO(ehotaj): There might be a bug in our split code because we shard the assert len(sharded_model) == 3
# model into 10 shards even though we specify 3 shards above.
assert len(sharded_model) == 10
input_ = torch.randn(3, features) input_ = torch.randn(3, features)
model_output = model(input_) model_output = model(input_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment