Unverified Commit 7db3ccc7 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] remove duplicated param register to stateful tensor manager (#728)

parent 600e769a
...@@ -106,9 +106,7 @@ class ShardedModelV2(nn.Module): ...@@ -106,9 +106,7 @@ class ShardedModelV2(nn.Module):
GLOBAL_MODEL_DATA_TRACER.register_model(self) GLOBAL_MODEL_DATA_TRACER.register_model(self)
self._memstats_collector = MemStatsCollector() self._memstats_collector = MemStatsCollector()
self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector) self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
# for param in module.parameters(): for param in module.parameters():
for submodule in module.modules():
for param in submodule.parameters(recurse=False):
if hasattr(param, 'colo_attr'): if hasattr(param, 'colo_attr'):
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment