Commit 59b49b47 authored by maxiao3's avatar maxiao3 Committed by wenjh
Browse files

Add NVTE_USE_LIGHTOP env var to control lightop import


Signed-off-by: default avatarmaxiao3 <maxiao3@sugon.com>

See merge request dcutoolkit/deeplearing/TransformerEngine!71
parent 0fce42f7
......@@ -9,10 +9,12 @@ import os
import torch
import transformer_engine_torch as tex
import warnings
try:
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
import lightop
enable_lightop = True
except ImportError:
except ImportError:
enable_lightop = False
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
......
......@@ -16,10 +16,13 @@ from ..export import is_in_onnx_export_mode
from ..utils import get_default_init_method
import warnings
try:
import os
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward,rmsnorm_backward
enable_lightop = True
except ImportError:
except ImportError:
enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
......
......@@ -80,10 +80,14 @@ from ..cpp_extensions import (
general_gemm,
)
import warnings
try:
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward,rmsnorm_backward
enable_lightop = True
except ImportError:
except ImportError:
enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
......
......@@ -88,10 +88,12 @@ from ..cpp_extensions import (
from ..export import is_in_onnx_export_mode, assert_warmed_up
from ...debug.pytorch.debug_state import TEDebugState
try:
enable_lightop = os.getenv("NVTE_USE_LIGHTOP", "false").strip().lower() in ["true", "1"]
if enable_lightop:
try:
from lightop import rmsnorm_forward, rmsnorm_backward
enable_lightop = True
except ImportError:
except ImportError:
enable_lightop = False
__all__ = ["LayerNormMLP"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment