"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "32567b044c327a4d3cee179094f32646d8311c95"
Unverified Commit c4c02424 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[zero] sharded model manages ophooks individually (#492)

parent c9023d40
@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
         self._iter_cnter = 0
         # Register hooks
-        register_ophooks_recursively(self.module,
-                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
+        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
     def backward(self, loss):
         loss.backward()
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()
     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()
     @torch.no_grad()
     def _post_backward_operations(self) -> None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment