"vscode:/vscode.git/clone" did not exist on "dcfb18a2d340d8e1f0ff001b06d2931ffa8648da"
Commit 0edd811d authored by zms1999's avatar zms1999
Browse files

fix shape mismatch in set_grad

parent d90ff389
......@@ -53,6 +53,6 @@ def set_grads(e, grads):
seg = grads[offset:offset + p.numel()]
offset += p.numel()
if p.grad is None:
p.grad = seg.clone()
p.grad = seg.clone().reshape(p.shape)
else:
p.grad += seg.reshape(p.shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment