Commit e2cc2fc4 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.10' into release_v2.10

parents 96a104d5 59b49b47
...@@ -10,11 +10,13 @@ import functools ...@@ -10,11 +10,13 @@ import functools
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
import warnings import warnings
try: enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
import lightop if enable_lightop:
enable_lightop = True try:
except ImportError: import lightop
enable_lightop = False enable_lightop = True
except ImportError:
enable_lightop = False
from ..constants import TE_DType, TE_DType_To_Torch from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
......
...@@ -16,12 +16,15 @@ from ..export import is_in_onnx_export_mode ...@@ -16,12 +16,15 @@ from ..export import is_in_onnx_export_mode
from ..utils import get_default_init_method from ..utils import get_default_init_method
import warnings import warnings
try: import os
from lightop import rmsnorm_forward,rmsnorm_backward enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
enable_lightop = True if enable_lightop:
except ImportError: try:
enable_lightop = False from lightop import rmsnorm_forward,rmsnorm_backward
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning) enable_lightop = True
except ImportError:
enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
def _get_normalization_func(normalization: str, forward: bool): def _get_normalization_func(normalization: str, forward: bool):
fwd_normalization_funcs = { fwd_normalization_funcs = {
......
...@@ -80,12 +80,16 @@ from ..cpp_extensions import ( ...@@ -80,12 +80,16 @@ from ..cpp_extensions import (
general_gemm, general_gemm,
) )
import warnings import warnings
try:
from lightop import rmsnorm_forward,rmsnorm_backward enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
enable_lightop = True
except ImportError: if enable_lightop:
enable_lightop = False try:
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning) from lightop import rmsnorm_forward,rmsnorm_backward
enable_lightop = True
except ImportError:
enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
__all__ = ["LayerNormLinear"] __all__ = ["LayerNormLinear"]
......
...@@ -88,11 +88,13 @@ from ..cpp_extensions import ( ...@@ -88,11 +88,13 @@ from ..cpp_extensions import (
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
try: enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
from lightop import rmsnorm_forward, rmsnorm_backward if enable_lightop:
enable_lightop = True try:
except ImportError: from lightop import rmsnorm_forward, rmsnorm_backward
enable_lightop = False enable_lightop = True
except ImportError:
enable_lightop = False
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment