Unverified Commit c52edcf0 authored by junxu's avatar junxu Committed by GitHub
Browse files

Rename class method of ZeroDDP (#2692)

parent 6e4ac081
...@@ -294,7 +294,7 @@ class ZeroDDP(ColoDDP): ...@@ -294,7 +294,7 @@ class ZeroDDP(ColoDDP):
continue continue
p.grad = None p.grad = None
def _pre_bacward(self): def _pre_backward(self):
# set a visit label for all parameters # set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced # the label is used to check whether the parameter is correctly reduced
for param in self.param2name: for param in self.param2name:
...@@ -318,7 +318,7 @@ class ZeroDDP(ColoDDP): ...@@ -318,7 +318,7 @@ class ZeroDDP(ColoDDP):
self.gemini_manager.post_iter() self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
self._pre_bacward() self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward() loss.backward()
self._post_backward() self._post_backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment