Unverified Commit c293d3a8 authored by Paweł Gadziński's avatar Paweł Gadziński Committed by GitHub
Browse files

[PyTorch] Fix typo in GrouppedLinear (#1867)



typo fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
parent 0efc7daf
...@@ -241,8 +241,8 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -241,8 +241,8 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = ctx.main_grads main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in ctx.num_gemms: for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment