Commit 59b49b47 authored by maxiao3's avatar maxiao3 Committed by wenjh
Browse files

Add NVTE_USE_LIGHTOP env var to control lightop import


Signed-off-by: default avatarmaxiao3 <maxiao3@sugon.com>

See merge request dcutoolkit/deeplearing/TransformerEngine!71
parent 0fce42f7
...@@ -9,10 +9,12 @@ import os ...@@ -9,10 +9,12 @@ import os
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
import warnings import warnings
try: enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
import lightop import lightop
enable_lightop = True enable_lightop = True
except ImportError: except ImportError:
enable_lightop = False enable_lightop = False
from ..constants import TE_DType, TE_DType_To_Torch from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor from ..utils import get_sm_count, _empty_tensor
......
...@@ -16,10 +16,13 @@ from ..export import is_in_onnx_export_mode ...@@ -16,10 +16,13 @@ from ..export import is_in_onnx_export_mode
from ..utils import get_default_init_method from ..utils import get_default_init_method
import warnings import warnings
try: import os
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward,rmsnorm_backward from lightop import rmsnorm_forward,rmsnorm_backward
enable_lightop = True enable_lightop = True
except ImportError: except ImportError:
enable_lightop = False enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning) warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
......
...@@ -80,10 +80,14 @@ from ..cpp_extensions import ( ...@@ -80,10 +80,14 @@ from ..cpp_extensions import (
general_gemm, general_gemm,
) )
import warnings import warnings
try:
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward,rmsnorm_backward from lightop import rmsnorm_forward,rmsnorm_backward
enable_lightop = True enable_lightop = True
except ImportError: except ImportError:
enable_lightop = False enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning) warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
......
...@@ -88,10 +88,12 @@ from ..cpp_extensions import ( ...@@ -88,10 +88,12 @@ from ..cpp_extensions import (
from ..export import is_in_onnx_export_mode, assert_warmed_up from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_state import TEDebugState
try: enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward, rmsnorm_backward from lightop import rmsnorm_forward, rmsnorm_backward
enable_lightop = True enable_lightop = True
except ImportError: except ImportError:
enable_lightop = False enable_lightop = False
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment