"tools/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "0685f4ffc9b3171b02315155b44e4a700964dd1a"
Unverified Commit 7db3ccc7 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] remove duplicated param register to stateful tensor manager (#728)

parent 600e769a
...@@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module): ...@@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module):
GLOBAL_MODEL_DATA_TRACER.register_model(self) GLOBAL_MODEL_DATA_TRACER.register_model(self)
self._memstats_collector = MemStatsCollector() self._memstats_collector = MemStatsCollector()
self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector) self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
# for param in module.parameters(): for param in module.parameters():
for submodule in module.modules(): if hasattr(param, 'colo_attr'):
for param in submodule.parameters(recurse=False): self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
if hasattr(param, 'colo_attr'):
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment