Commit b15412aa authored by yuguo's avatar yuguo
Browse files

[DCU] fix

parent 803be71d
@@ -80,6 +80,12 @@ from ..cpp_extensions import (
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ...debug.pytorch.debug_state import TEDebugState
# Optional DCU fast-path: use lightop's fused RMSNorm kernels when the
# package is installed, otherwise fall back to the built-in tex kernels.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
except ImportError:
    # lightop is not available in this environment.
    enable_lightop = False
else:
    enable_lightop = True
__all__ = ["LayerNormMLP"]
@@ -1264,14 +1270,17 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.zero_centered_gamma,
             )
         elif ctx.normalization == "RMSNorm":
-            dgrad, dgamma = tex.rmsnorm_bwd(
-                dgrad,
-                inputmat,
-                rsigma,
-                ln_weight,
-                ctx.bwd_ln_sm_margin,
-                ctx.zero_centered_gamma,
-            )
+            if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
+                dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
+            else:
+                dgrad, dgamma = tex.rmsnorm_bwd(
+                    dgrad,
+                    inputmat,
+                    rsigma,
+                    ln_weight,
+                    ctx.bwd_ln_sm_margin,
+                    ctx.zero_centered_gamma,
+                )
         dbeta = None
         clear_tensor_data(mu, rsigma)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment