"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ff182ad6694ada3c01b3514eeae03392b2761b92"
Commit 0edd811d authored by zms1999's avatar zms1999
Browse files

fix shape mismatch in set_grad

parent d90ff389
...@@ -53,6 +53,6 @@ def set_grads(e, grads): ...@@ -53,6 +53,6 @@ def set_grads(e, grads):
seg = grads[offset:offset + p.numel()] seg = grads[offset:offset + p.numel()]
offset += p.numel() offset += p.numel()
if p.grad is None: if p.grad is None:
p.grad = seg.clone() p.grad = seg.clone().reshape(p.shape)
else: else:
p.grad += seg.reshape(p.shape) p.grad += seg.reshape(p.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment