"...include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "39c58da187c9cd54e54560601be4b7ceef320da4"
Unverified commit 77d64552 authored by Marks101, committed by GitHub
Browse files

[PyTorch] fix fuse_wgrad_accumulation in LayerNormMLP backward (#1618)



* [PyTorch] fix general_gemm argument out_dtype in LayerNormMLP backward
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cf9a7c2f
......@@ -734,7 +734,11 @@ class _LayerNormMLP(torch.autograd.Function):
act_out,
grad_output,
get_workspace(),
out_dtype=ctx.activation_dtype,
out_dtype=(
fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
quantization_params=None, # wgrad in high precision
layout="NT",
grad=True,
......@@ -894,7 +898,11 @@ class _LayerNormMLP(torch.autograd.Function):
ln_out_total,
dact,
get_workspace(),
out_dtype=ctx.activation_dtype,
out_dtype=(
fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
layout="NT",
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment