Unverified Commit 7a053671 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] shared model returns cpu state_dict (#1328)

parent b2475d8c
......@@ -439,7 +439,8 @@ class ShardedModelV2(nn.Module):
for p in sharded_params:
p.data = p.colo_attr.data_payload
module_to_load = module_to_load or self
gathered_state_dict = deepcopy(state_dict_func(module_to_load, destination, prefix, keep_vars))
gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
if shard_strategy is not None:
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
for p in sharded_params:
......
......@@ -39,7 +39,7 @@ def run_zero_state_dict(shard_strategy_class):
zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key])
assert torch.equal(val, zero_state_dict[key].to(val.device))
def run_dist(rank, world_size, port):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment