Commit 36587c01 authored by Tri Dao

[LayerNorm] Update layer_norm_linear

parent 6bbc5323
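Taken together, the hunks below adapt LayerNormLinearFn to updated _layer_norm_fwd/_layer_norm_bwd signatures: the forward call unpacks extra return values with a placeholder and a starred target, and the backward call switches three positional arguments to keywords and discards three additional outputs.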
@@ -987,7 +987,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             if residual is not None
             else (torch.float32 if residual_in_fp32 else None)
         )
-        y, mean, rstd, residual_out = _layer_norm_fwd(
+        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
             x,
             norm_weight,
             norm_bias,
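The forward-pass change suggests _layer_norm_fwd now returns more values than the four the old call site unpacked. A minimal sketch of the new unpacking pattern follows; the stub body and the extra return slots are assumptions for illustration, not flash-attn's actual kernel:

```python
import torch

def _layer_norm_fwd_stub(x, eps=1e-6):
    # Reference layer norm, standing in for the Triton kernel.
    mean = x.mean(dim=-1, keepdim=True)
    rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
    y = (x - mean) * rstd
    # The second output and the two trailing extras are placeholders for
    # whatever the updated kernel now returns (an assumption, not the real API).
    return y, None, mean.squeeze(-1), rstd.squeeze(-1), x, None, None

x = torch.randn(4, 8)
y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd_stub(x)
assert len(rest) == 2  # extra outputs are captured without naming them
```

Starred unpacking (*rest) keeps this call site valid even if the kernel later grows further outputs, while the `_` placeholder discards the second output this caller does not need.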
@@ -1031,7 +1031,7 @@ class LayerNormLinearFn(torch.autograd.Function):
             assert dresidual.shape == x.shape
         else:
             dresidual = None
-        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
+        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
             dy,
             x,
             norm_weight,
@@ -1039,9 +1039,9 @@ class LayerNormLinearFn(torch.autograd.Function):
             ctx.eps,
             mean,
             rstd,
-            dresidual,
-            ctx.has_residual,
-            ctx.is_rms_norm,
+            dresidual=dresidual,
+            has_residual=ctx.has_residual,
+            is_rms_norm=ctx.is_rms_norm,
             x_dtype=ctx.x_dtype,
             recompute_output=True,
         )
...
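On the backward side, the diff switches dresidual, has_residual, and is_rms_norm to keyword arguments and discards three new return values before the recomputed output. A hedged sketch of the resulting call shape (parameter names are taken from the diff; the stub body is illustrative only, not flash-attn's implementation):

```python
import torch

def _layer_norm_bwd_stub(
    dy, x, norm_weight, norm_bias, eps, mean, rstd,
    dresidual=None, has_residual=False, is_rms_norm=False,
    x_dtype=None, recompute_output=False,
):
    dx = dy.to(x_dtype) if x_dtype is not None else dy  # placeholder gradient
    dnorm_weight = dnorm_bias = None
    dresidual_in = dresidual if has_residual else None
    y = x if recompute_output else None  # stands in for the recomputed output
    # Three extra slots (stubbed as None) mirror the `_, _, _` in the diff.
    return dx, dnorm_weight, dnorm_bias, dresidual_in, None, None, None, y

dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd_stub(
    torch.ones(4, 8), torch.randn(4, 8), None, None, 1e-6, None, None,
    dresidual=None, has_residual=False, is_rms_norm=False,
    x_dtype=None, recompute_output=True,
)
```

Passing these flags by keyword makes the call site robust to further positional-parameter insertions in the kernel's signature, which is consistent with the forward side's starred unpacking.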