Commit 36587c01 authored by Tri Dao
Browse files

[LayerNorm] Update layer_norm_linear

parent 6bbc5323
......@@ -987,7 +987,7 @@ class LayerNormLinearFn(torch.autograd.Function):
if residual is not None
else (torch.float32 if residual_in_fp32 else None)
)
y, mean, rstd, residual_out = _layer_norm_fwd(
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
x,
norm_weight,
norm_bias,
......@@ -1031,7 +1031,7 @@ class LayerNormLinearFn(torch.autograd.Function):
assert dresidual.shape == x.shape
else:
dresidual = None
dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
dy,
x,
norm_weight,
......@@ -1039,9 +1039,9 @@ class LayerNormLinearFn(torch.autograd.Function):
ctx.eps,
mean,
rstd,
dresidual,
ctx.has_residual,
ctx.is_rms_norm,
dresidual=dresidual,
has_residual=ctx.has_residual,
is_rms_norm=ctx.is_rms_norm,
x_dtype=ctx.x_dtype,
recompute_output=True,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment