Commit 79fa3eba authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.7' of...

Merge branch 'develop_v2.7' of http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine into release_v2.7
parents 117f9059 b15412aa
...@@ -80,6 +80,12 @@ from ..cpp_extensions import ( ...@@ -80,6 +80,12 @@ from ..cpp_extensions import (
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
# Optional fused RMSNorm kernels from the `lightop` extension; when the
# package is unavailable we fall back to the built-in tex.* implementations.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
except ImportError:
    # lightop not installed — downstream code must use the default path.
    enable_lightop = False
else:
    enable_lightop = True
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -1264,14 +1270,17 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1264,14 +1270,17 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.zero_centered_gamma, ctx.zero_centered_gamma,
) )
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
dgrad, dgamma = tex.rmsnorm_bwd( if enable_lightop and (rsigma.dtype is torch.bfloat16 or rsigma.dtype is torch.float16):
dgrad, dgrad, dgamma = rmsnorm_backward(dgrad, inputmat, rsigma, ln_weight)
inputmat, else:
rsigma, dgrad, dgamma = tex.rmsnorm_bwd(
ln_weight, dgrad,
ctx.bwd_ln_sm_margin, inputmat,
ctx.zero_centered_gamma, rsigma,
) ln_weight,
ctx.bwd_ln_sm_margin,
ctx.zero_centered_gamma,
)
dbeta = None dbeta = None
clear_tensor_data(mu, rsigma) clear_tensor_data(mu, rsigma)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment