Unverified Commit 77d64552 authored by Marks101, committed by GitHub

[PyTorch] fix fuse_wgrad_accumulation in LayerNormMLP backward (#1618)



* [PyTorch] fix general_gemm argument out_dtype in LayerNormMLP backward
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Markus Schnoes <markus.schnoes@gmx.de>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent cf9a7c2f
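
Background on the fix: when ctx.fuse_wgrad_accumulation is set, the weight-gradient GEMM accumulates its result into the weight's main_grad buffer (typically fp32), so passing out_dtype=ctx.activation_dtype (bf16/fp16) mismatches the accumulation buffer. A minimal sketch of that distinction, using a hypothetical wgrad_gemm helper rather than TransformerEngine's general_gemm:

import torch

# Hypothetical helper, not TransformerEngine's general_gemm: with fused wgrad
# accumulation, the weight-gradient GEMM accumulates directly into
# weight.main_grad, so its output dtype must match main_grad (usually fp32),
# not the bf16/fp16 activation dtype.
def wgrad_gemm(inp, grad_out, main_grad=None, activation_dtype=torch.bfloat16):
    """Compute dW = grad_out^T @ inp, optionally accumulating into main_grad."""
    fuse_wgrad_accumulation = main_grad is not None
    out_dtype = main_grad.dtype if fuse_wgrad_accumulation else activation_dtype
    wgrad = grad_out.t().to(out_dtype) @ inp.to(out_dtype)
    if fuse_wgrad_accumulation:
        main_grad += wgrad  # accumulation stays in main_grad's precision
        return main_grad
    return wgrad

x = torch.randn(8, 4, dtype=torch.bfloat16)          # layer input
gy = torch.randn(8, 2, dtype=torch.bfloat16)         # gradient w.r.t. layer output
main_grad = torch.zeros(2, 4, dtype=torch.float32)   # fp32 master gradient
print(wgrad_gemm(x, gy, main_grad).dtype)            # torch.float32, not bfloat16
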
@@ -734,7 +734,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 act_out,
                 grad_output,
                 get_workspace(),
-                out_dtype=ctx.activation_dtype,
+                out_dtype=(
+                    fc2_weight.main_grad.dtype
+                    if ctx.fuse_wgrad_accumulation
+                    else ctx.activation_dtype
+                ),
                 quantization_params=None,  # wgrad in high precision
                 layout="NT",
                 grad=True,
@@ -894,7 +898,11 @@ class _LayerNormMLP(torch.autograd.Function):
                 ln_out_total,
                 dact,
                 get_workspace(),
-                out_dtype=ctx.activation_dtype,
+                out_dtype=(
+                    fc1_weight.main_grad.dtype
+                    if ctx.fuse_wgrad_accumulation
+                    else ctx.activation_dtype
+                ),
                 layout="NT",
                 grad=fuse_gemm_and_bias_fc1_wgrad,
                 bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
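
A hedged usage sketch of the code path this diff touches, assuming transformer_engine.pytorch.LayerNormMLP accepts fuse_wgrad_accumulation and that each parameter carries a Megatron-style main_grad buffer; this is illustrative, not a test taken from the PR:

import torch
import transformer_engine.pytorch as te  # assumed import path

# Assumed constructor arguments; hidden/ffn sizes are arbitrary.
mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    params_dtype=torch.bfloat16,
    fuse_wgrad_accumulation=True,
).cuda()

# fp32 master-gradient buffers that the fused wgrad GEMMs accumulate into.
for p in mlp.parameters():
    p.main_grad = torch.zeros_like(p, dtype=torch.float32)

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
mlp(x).sum().backward()
# With this fix, the fc1/fc2 wgrad GEMMs use main_grad.dtype (fp32) as out_dtype
# instead of the bf16 activation dtype.
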