Commit b32741e2 authored by wenjh
Browse files

Merge branch 'develop_v2.4'

parents a5892578 148b5bea
...@@ -13,9 +13,6 @@ import transformer_engine.common ...@@ -13,9 +13,6 @@ import transformer_engine.common
try: try:
from . import pytorch from . import pytorch
except ImportError: except ImportError:
try:
from . import pytorch
except ImportError:
pass pass
except FileNotFoundError as e: except FileNotFoundError as e:
if "Could not find shared object file" not in str(e): if "Could not find shared object file" not in str(e):
......
...@@ -13,7 +13,6 @@ import torch ...@@ -13,7 +13,6 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version from . import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
...@@ -558,6 +557,7 @@ def round_up_to_nearest_multiple(value, multiple): ...@@ -558,6 +557,7 @@ def round_up_to_nearest_multiple(value, multiple):
def needs_quantized_gemm(obj, rowwise=True): def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm.""" """Used to check if obj will need quantized gemm or normal gemm."""
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
if isinstance(obj, DebugQuantizedTensor): if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor, torch.Tensor,
...@@ -652,3 +652,5 @@ if torch_version() >= (2, 4, 0): ...@@ -652,3 +652,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda") gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else: else:
gpu_autocast_ctx = torch.cuda.amp.autocast gpu_autocast_ctx = torch.cuda.amp.autocast
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment