Commit b16169cf authored by yuguo's avatar yuguo
Browse files

Merge branch 'pann-rmsnorm' into 'main'

add lightop rmsnorm

See merge request dcutoolkit/deeplearing/TransformerEngine!1
parents a207db1d 4f79b7a9
......@@ -15,6 +15,13 @@ from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor
import warnings

# Optional fast path: the `lightop` package provides fused RMSNorm kernels.
# If it is unavailable, fall back to the built-in implementation and record
# that decision in `enable_lightop` so call sites can branch on it.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
    enable_lightop = True
except ImportError:
    # Define placeholders so both names always exist at module level and a
    # bypassed `enable_lightop` guard fails loudly with a TypeError on call
    # rather than a NameError. Callers must still check `enable_lightop`.
    rmsnorm_forward = None
    rmsnorm_backward = None
    enable_lightop = False
    warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
def _get_normalization_func(normalization: str, forward: bool):
......@@ -81,7 +88,9 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None):
return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
else:
return normalization_func(
*inputs,
eps,
......
......@@ -61,6 +61,13 @@ from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import (
general_gemm,
)
import warnings

# Optional fast path: the `lightop` package provides fused RMSNorm kernels.
# If it is unavailable, fall back to the built-in implementation and record
# that decision in `enable_lightop` so call sites can branch on it.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
    enable_lightop = True
except ImportError:
    # Define placeholders so both names always exist at module level and a
    # bypassed `enable_lightop` guard fails loudly with a TypeError on call
    # rather than a NameError. Callers must still check `enable_lightop`.
    rmsnorm_forward = None
    rmsnorm_backward = None
    enable_lightop = False
    warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
__all__ = ["LayerNormLinear"]
......@@ -757,6 +764,10 @@ class _LayerNormLinear(torch.autograd.Function):
)
dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm":
if enable_lightop:
dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight)
else:
dgrad, dgamma = tex.rmsnorm_bwd(
dgrad,
inputmat,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment