Commit 148b5bea authored by wenjh's avatar wenjh
Browse files

Fix pytorch module import error


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 793e0103
...@@ -12,9 +12,6 @@ import transformer_engine.common ...@@ -12,9 +12,6 @@ import transformer_engine.common
try: try:
from . import pytorch from . import pytorch
except ImportError as e: except ImportError as e:
try:
from . import pytorch
except ImportError as e:
pass pass
try: try:
......
...@@ -13,7 +13,6 @@ import torch ...@@ -13,7 +13,6 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version from . import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
...@@ -549,6 +548,7 @@ def round_up_to_nearest_multiple(value, multiple): ...@@ -549,6 +548,7 @@ def round_up_to_nearest_multiple(value, multiple):
def needs_quantized_gemm(obj, rowwise=True): def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm.""" """Used to check if obj will need quantized gemm or normal gemm."""
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
if isinstance(obj, DebugQuantizedTensor): if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor, torch.Tensor,
...@@ -643,3 +643,5 @@ if torch_version() >= (2, 4, 0): ...@@ -643,3 +643,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda") gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else: else:
gpu_autocast_ctx = torch.cuda.amp.autocast gpu_autocast_ctx = torch.cuda.amp.autocast
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment