Unverified Commit c4c02424 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[zero] sharded model manages ophooks individually (#492)

parent c9023d40
......@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
self._iter_cnter = 0
# Register hooks
register_ophooks_recursively(self.module,
[ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
register_ophooks_recursively(self.module, self._ophook_list)
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
......@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
def backward(self, loss):
    """Run the backward pass for *loss*, then finish the iteration.

    After autograd completes, ZeRO post-backward bookkeeping is performed
    (``_post_backward_operations``) and every registered op hook is told the
    iteration is over via ``post_iter()``.
    """
    loss.backward()
    self._post_backward_operations()
    # Each hook manages its own end-of-iteration cleanup.
    for hook in self._ophook_list:
        hook.post_iter()
def backward_by_grad(self, tensor, grad):
    """Backward pass driven by an explicit upstream gradient.

    Equivalent to :meth:`backward` except that autograd is started from
    ``tensor`` with ``grad`` as the incoming gradient instead of from a
    scalar loss. Post-backward bookkeeping and per-iteration hook
    notifications are identical.
    """
    torch.autograd.backward(tensors=tensor, grad_tensors=grad)
    self._post_backward_operations()
    # Each hook manages its own end-of-iteration cleanup.
    for hook in self._ophook_list:
        hook.post_iter()
@torch.no_grad()
def _post_backward_operations(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment