Commit b32741e2 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents a5892578 148b5bea
......@@ -13,9 +13,6 @@ import transformer_engine.common
try:
from . import pytorch
except ImportError:
try:
from . import pytorch
except ImportError:
pass
except FileNotFoundError as e:
if "Could not find shared object file" not in str(e):
......
......@@ -13,7 +13,6 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
......@@ -558,6 +557,7 @@ def round_up_to_nearest_multiple(value, multiple):
def needs_quantized_gemm(obj, rowwise=True):
"""Used to check if obj will need quantized gemm or normal gemm."""
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
if isinstance(obj, DebugQuantizedTensor):
return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck
torch.Tensor,
......@@ -652,3 +652,5 @@ if torch_version() >= (2, 4, 0):
gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
else:
gpu_autocast_ctx = torch.cuda.amp.autocast
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment