Commit b15412aa authored by yuguo
Browse files

[DCU] fix

parent 803be71d
...@@ -80,6 +80,12 @@ from ..cpp_extensions import ( ...@@ -80,6 +80,12 @@ from ..cpp_extensions import (
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
try:
from lightop import rmsnorm_forward, rmsnorm_backward
enable_lightop = True
except ImportError:
enable_lightop = False
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -1264,6 +1270,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1264,6 +1270,9 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.zero_centered_gamma, ctx.zero_centered_gamma,
) )
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
else:
dgrad, dgamma = tex.rmsnorm_bwd( dgrad, dgamma = tex.rmsnorm_bwd(
dgrad, dgrad,
inputmat, inputmat,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment