"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "547d8dd865b60c9e080e1cfdabadbc72a2a73706"
Commit 4f79b7a9 authored by panning's avatar panning
Browse files

add lightop rmsnorm

parent a207db1d
...@@ -15,6 +15,13 @@ from .. import cpp_extensions as tex ...@@ -15,6 +15,13 @@ from .. import cpp_extensions as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import get_default_init_method from ..utils import get_default_init_method
from ..tensor.float8_tensor import Float8Tensor from ..tensor.float8_tensor import Float8Tensor
import warnings

# Optional fast-path: use the lightop fused RMSNorm kernels when the package
# is installed; otherwise fall back to the built-in implementation.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
except ImportError:
    enable_lightop = False
    warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
else:
    # Import succeeded — downstream code branches on this flag.
    enable_lightop = True
def _get_normalization_func(normalization: str, forward: bool): def _get_normalization_func(normalization: str, forward: bool):
...@@ -81,16 +88,18 @@ def apply_normalization( ...@@ -81,16 +88,18 @@ def apply_normalization(
normalization_func = _get_normalization_func(normalization, True) normalization_func = _get_normalization_func(normalization, True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if enable_lightop and (ln_bias is None):
return normalization_func( return rmsnorm_forward(inputmat, ln_weight,ln_out,eps,True)
*inputs, else:
eps, return normalization_func(
ln_out, *inputs,
output_quantizer, eps,
TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, ln_out,
fwd_ln_sm_margin, output_quantizer,
zero_centered_gamma, TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
) fwd_ln_sm_margin,
zero_centered_gamma,
)
class _NoopCatFunc(torch.autograd.Function): class _NoopCatFunc(torch.autograd.Function):
......
...@@ -61,6 +61,13 @@ from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param ...@@ -61,6 +61,13 @@ from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param
from ..cpp_extensions import ( from ..cpp_extensions import (
general_gemm, general_gemm,
) )
import warnings

# Detect the optional lightop package; when present its fused RMSNorm
# forward/backward kernels are preferred over the default path.
try:
    from lightop import rmsnorm_forward, rmsnorm_backward
except ImportError:
    enable_lightop = False
    warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
else:
    # Names imported successfully; enable the fast path.
    enable_lightop = True
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
...@@ -757,14 +764,18 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -757,14 +764,18 @@ class _LayerNormLinear(torch.autograd.Function):
) )
dgrad = dgrad.reshape(inputmat.size()) dgrad = dgrad.reshape(inputmat.size())
elif ctx.normalization == "RMSNorm": elif ctx.normalization == "RMSNorm":
dgrad, dgamma = tex.rmsnorm_bwd( if enable_lightop:
dgrad, dgrad, dgamma =rmsnorm_backward(dgrad,inputmat,rsigma,ln_weight)
inputmat, else:
rsigma,
ln_weight, dgrad, dgamma = tex.rmsnorm_bwd(
ctx.bwd_ln_sm_margin, dgrad,
ctx.zero_centered_gamma, inputmat,
) rsigma,
ln_weight,
ctx.bwd_ln_sm_margin,
ctx.zero_centered_gamma,
)
dgrad = dgrad.reshape(inputmat.size()) dgrad = dgrad.reshape(inputmat.size())
dbeta = None dbeta = None
nvtx_range_pop(f"{nvtx_label}.norm") nvtx_range_pop(f"{nvtx_label}.norm")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment