Commit 148b5bea authored by wenjh's avatar wenjh
Browse files

Fix pytorch module import error


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 793e0103
......@@ -12,10 +12,7 @@ import transformer_engine.common
try:
from . import pytorch
except ImportError as e:
try:
from . import pytorch
except ImportError as e:
pass
pass
try:
from . import jax
......
......@@ -13,7 +13,6 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
......@@ -549,6 +548,7 @@ def round_up_to_nearest_multiple(value, multiple):
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
......@@ -643,3 +643,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment