"vscode:/vscode.git/clone" did not exist on "7d648418b8b1aadb90489ef18cff1763ffc82ed5"
Unverified Commit 447de6da authored by guyueh1's avatar guyueh1 Committed by GitHub
Browse files

Fix layernorm output shape in LayernormLinear (#1906)



Fix layernorm output shape in LayernormLinear, make it align with input shape
Signed-off-by: default avatarGuyue Huang <guyueh@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9d173c93
......@@ -493,7 +493,7 @@ class _LayerNormLinear(torch.autograd.Function):
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp.shape)
shape = list(inp_shape)
shape[0] *= tp_size if with_input_all_gather else 1
return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment