Unverified Commit a3f55cea authored by Edenzzzz's avatar Edenzzzz Committed by GitHub
Browse files

Fixed optim update error with non-contiguous grads/params (#1187)



* Fixed optim update error with non-contiguous grads
* fix formatting

Thanks @Edenzzzz for this contribution!

---------
Co-authored-by: default avatarTitus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com>
parent 5212a0f2
......@@ -474,6 +474,10 @@ class Optimizer2State(Optimizer8bit):
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
# avoid update error from non-contiguous memory layout
p.data = p.data.contiguous()
p.grad = p.grad.contiguous()
state = self.state[p]
grad = p.grad
......@@ -685,6 +689,10 @@ class Optimizer1State(Optimizer8bit):
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
# avoid update error from non-contiguous memory layout
p.data = p.data.contiguous()
p.grad = p.grad.contiguous()
state = self.state[p]
grad = p.grad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment