Unverified commit d3352524 authored by Tim Moon, committed by GitHub

[PyTorch] Fix bug when reshaping norm output from LayerNormLinear (#1741)



Fix bug when reshaping norm output from LayerNormLinear
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 221fedc2
@@ -467,7 +467,7 @@ class _LayerNormLinear(torch.autograd.Function):
                 shape = list(inp_shape)
                 shape[0] *= tp_size
                 return out, ln_out_return.view(shape)
-            return out, ln_out_return.view_as(inp)
+            return out, ln_out_return.view(inp_shape)
         return out

     @staticmethod
@@ -612,7 +612,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 shape = list(inp_shape)
                 shape[0] *= tp_size
                 return fc2_out, ln_out_return.view(shape)
-            return fc2_out, ln_out_return.view_as(inp)
+            return fc2_out, ln_out_return.view(inp_shape)
         return fc2_out

     @staticmethod
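Both hunks are the same fix: when the caller asks for the normalized output, it must be reshaped back to the original input shape recorded in inp_shape before the module reshaped the input internally, not with view_as(inp), since inp may by that point refer to the flattened tensor. Below is a minimal sketch of the failure mode, with assumed shapes and variable names rather than the library's actual code:

import torch

# A 3D activation; record its shape before any internal reshaping.
inp = torch.randn(2, 3, 8)            # [batch, seq, hidden]
inp_shape = inp.shape

# Assume the module flattens the input to 2D for its matmul.
inp = inp.view(-1, inp.shape[-1])     # [6, 8]
ln_out = torch.nn.functional.layer_norm(inp, (inp.shape[-1],))

# Buggy: `inp` now names the flattened view, so the norm output
# keeps the 2D shape instead of the caller's 3D one.
assert ln_out.view_as(inp).shape == (6, 8)

# Fixed: the shape saved up front restores the caller's layout.
assert ln_out.view(inp_shape).shape == (2, 3, 8)

Restoring from the saved shape also composes with the gathered branch kept by the diff, where the leading dimension is scaled by tp_size because the norm output has been gathered across tensor-parallel ranks.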