"...ggml-cann/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "0cb78a2fc2a5463ce8e84b56388f28e02e1c1d06"
Commit 1d95abb9 authored by wenjh

Enable fp8 on nmz


Signed-off-by: wenjh <wenjh@sugon.com>
parent 5c63251d
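
Note: the tests touched below toggle the INT8-simulated FP8 path through environment variables that fp8.py reads at import time (NVTE_INT8_SIM_FP8, NVTE_INT8_SIM_FP8_TENSORWISE), which is why they reload te.pytorch.fp8 after changing them. A minimal sketch of that toggle pattern, assuming the `import transformer_engine as te` alias the tests use to reach `te.pytorch.fp8`:

```python
import importlib
import os

import transformer_engine as te  # assumed alias; the tests below access te.pytorch.fp8

def set_int8_sim_fp8(enabled: bool, tensorwise: bool = False) -> None:
    """Flip the INT8-simulated FP8 flags and reload fp8.py so its
    module-level int8_simulation_fp8* globals are re-read."""
    if enabled:
        os.environ["NVTE_INT8_SIM_FP8"] = "1"
        os.environ["NVTE_INT8_SIM_FP8_TENSORWISE"] = "1" if tensorwise else "0"
    else:
        os.environ.pop("NVTE_INT8_SIM_FP8", None)
        os.environ.pop("NVTE_INT8_SIM_FP8_TENSORWISE", None)
    importlib.reload(te.pytorch.fp8)
```
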
......@@ -36,10 +36,12 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_P
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
# channelwise int8 test
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8"
python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 NVTE_INT8_SIM_FP8_TENSORWISE=1 python3 -m pytest -v -s --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_current_scaling_exact_int8_tensorwise.xml $TE_PATH/tests/pytorch/test_float8_current_scaling_exact.py || test_fail "test_float8_current_scaling_exact.py_int8_tensorwise"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact_int8.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py_int8"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 $TE_PATH/tests/pytorch/test_int8_blockwise_gemm_exact.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
......
......@@ -36,8 +36,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8_int8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py_int8"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
# debug tests
......@@ -47,7 +47,8 @@ NVTE_INT8_SIM_FP8=1 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
NVTE_INT8_SIM_FP8=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
NVTE_INT8_SIM_FP8=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py_int8"
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
......
......@@ -16,7 +16,7 @@ import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, int8_simulation_fp8, int8_simulation_fp8_tensorwise
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from test_numerics import (
......@@ -587,7 +587,7 @@ def test_fake_quant_fp8(
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if IS_HIP_EXTENSION:
if IS_HIP_EXTENSION and int8_simulation_fp8:
if fp8_kwargs["fprop_fp8"] or fp8_kwargs["dgrad_fp8"] or fp8_kwargs["wgrad_fp8"]:
return # Output type 32 (FP32) does not support int8 simulation.
if WORLD_RANK == 0:
......
......@@ -51,9 +51,8 @@ def _run_test(quantization):
all_boolean = [True, False]
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
def test_int8_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available:
......@@ -74,3 +73,15 @@ def test_distributed(quantization):
else:
del os.environ["NVTE_INT8_SIM_FP8"]
importlib.reload(te.pytorch.fp8)
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
if quantization == "fp8" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "fp8_cs" and not fp8_available:
pytest.skip(reason_for_no_fp8)
if quantization == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
_run_test(quantization)
......@@ -46,7 +46,7 @@ def cublas_gemm_fp8_blockwise_case(
atol: float = 0.0,
rtol: float = 0.0
):
if IS_HIP_EXTENSION and int8_simulation_fp8:
if IS_HIP_EXTENSION:
if use_bias or use_gelu:
pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
......@@ -167,7 +167,7 @@ def cublas_gemm_fp8_blockwise_case(
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if IS_HIP_EXTENSION and int8_simulation_fp8:
if IS_HIP_EXTENSION and int8_simulation_fp8:
if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
......@@ -248,7 +248,7 @@ def cublas_gemm_test_constraint_enforced(
expected_err_cls=RuntimeError
):
if IS_HIP_EXTENSION:
pytest.skip("ROCm does not support cuBLAS GEMM. No need to test constraint enforcement.")
pytest.skip("ROCm does not support cuBLAS blockwise FP8 gemm. No need to test constraint enforcement.")
if not fp8_blockwise_gemm_supported():
pytest.skip("CUDA version does not support blockwise FP8 gemm.")
# Setup device and random seed
......
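
For readers unfamiliar with the mode exercised above: INT8 simulation replaces the FP8 GEMM with a channelwise int8 quantize, matmul, and rescale sequence. Below is a rough, self-contained illustration of that idea; it is not the `w8a8_int8_general_gemm` kernel called in this file, only a reference for the technique the exact-match tests are built around.

```python
import torch

def int8_sim_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Illustrative channelwise INT8-simulated GEMM computing y = x @ w.T.

    x: [M, K] activations, w: [N, K] weights. Each row gets its own
    int8 scale; accumulation here is done in fp32 (exact for these sizes;
    real kernels accumulate in int32) and rescaled back afterwards.
    """
    x_scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qx = (x / x_scale).round().clamp(-127, 127)  # integer-valued, stored in fp32
    qw = (w / w_scale).round().clamp(-127, 127)
    acc = qx @ qw.T
    return (acc * x_scale * w_scale.T).to(x.dtype)

# Quick sanity check against the unquantized reference.
x, w = torch.randn(16, 64), torch.randn(32, 64)
print((int8_sim_gemm(x, w) - x @ w.T).abs().max())
```
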
......@@ -495,6 +495,9 @@ def test_quantization_block_tiling_extrema_versus_reference(
rtol=0.0,
)
def fp8_blockwise_scaling_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported
# FP8 per tensor current scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
......@@ -530,6 +533,59 @@ class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
dgrad_error=1,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_int8_sim_fp8_current_scaling_with_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
if IS_HIP_EXTENSION:
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
......@@ -601,6 +657,65 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
fp8_zero_tolerance_tensor_dumps_recipe2 = None
# check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
# if we cannot get all four tensors, then still set the tensor dump to None
tensor_map = self._check_golden_tensor_dumps(
TENSOR_DUMP_DIR,
recipe2,
(batch_size, hidden_size, out_size),
dtype,
use_bias,
"LayerNorm",
)
if tensor_map is not None:
fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map
self.compare_recipe(
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5,
ln_out_error=0.5,
dgrad_error=1.6,
wgrad_error=1,
bgrad_error=0.5,
recipe1_golden_tensors=None,
recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
)
@pytest.mark.parametrize(
"batch_size, hidden_size, out_size",
[
(16, 256, 128),
],
)
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize(
"recipe1, recipe2",
[
(GetRecipes.none, GetRecipes.fp8_blockwise),
],
)
def test_int8_sim_fp8_current_scaling_with_layernorm_linear_module(
self,
recipe1,
recipe2,
batch_size,
hidden_size,
out_size,
dtype,
use_bias=True,
):
if not fp8_blockwise_scaling_supported():
pytest.skip("CUDA version does not support blockwise FP8.")
if IS_HIP_EXTENSION:
import importlib
ori_int8_sim_fp8 = os.environ.get("NVTE_INT8_SIM_FP8", None)
......@@ -629,7 +744,7 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
use_bias,
seed=torch.initial_seed(),
dtype=dtype,
y_error=0.5 if not IS_HIP_EXTENSION else 0.9,
y_error=0.5 if not IS_HIP_EXTENSION else 0.9,
ln_out_error=0.5,
dgrad_error=1.6 if not IS_HIP_EXTENSION else 1.0,
wgrad_error=1,
......@@ -642,4 +757,4 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
os.environ["NVTE_INT8_SIM_FP8"] = ori_int8_sim_fp8
else:
del os.environ["NVTE_INT8_SIM_FP8"]
importlib.reload(te.pytorch.fp8)
importlib.reload(te.pytorch.fp8)
\ No newline at end of file
......@@ -13,6 +13,7 @@ from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Recipe,
......@@ -24,53 +25,51 @@ from transformer_engine.common.recipe import (
)
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .utils import get_device_compute_capability, is_gfx928, is_gfx936, is_gfx938
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
int8_simulation_fp8_tensorwise = bool(int(os.getenv("NVTE_INT8_SIM_FP8_TENSORWISE", "0")))
blockwise_fp8_block_len = int(os.getenv("NVTE_BLOCKWISE_FP8_BLOCK_LEN", "128"))
__all__ = ["fp8_autocast", "fp8_model_init"]
if IS_HIP_EXTENSION:
from transformer_engine.pytorch.utils import is_K100_AI, is_BW
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if IS_HIP_EXTENSION:
if (is_K100_AI() or is_BW()) and int8_simulation_fp8:
return True, "DCU turn on fp8 simulation with int8"
else:
return False, "DCU not support fp8 for now"
else:
if get_device_compute_capability() >= (9, 0): # hopper and above
if is_gfx938():
return True, ""
if (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if float(torch.version.cuda) < 12.1:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return False, "DCU not support fp8 for now"
if get_device_compute_capability() >= (9, 0): # hopper and above
return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if float(torch.version.cuda) < 12.1:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if IS_HIP_EXTENSION:
return False, "DCU not support mxfp8 for now"
if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
def check_fp8_block_scaling_support() -> Tuple[bool, str]:
"""Return if fp8 block scaling support is available"""
if IS_HIP_EXTENSION:
if is_K100_AI() or is_BW():
if is_gfx938():
return True, ""
else:
return False, "DCU not support block_scaling fp8 for now"
if (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
return True, ""
return False, "DCU not support block_scaling fp8 for now"
if (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
......@@ -79,7 +78,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
return True, ""
return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."
def check_recipe_support(recipe: Recipe) -> None:
"""Check if the given recipe is supported."""
recipe_supported = True
......@@ -102,7 +100,6 @@ def get_default_fp8_recipe() -> Recipe:
return Float8CurrentScaling()
return DelayedScaling()
def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype:
"""Get fp8 data type according to recipe and tensor"""
if fp8_recipe.fp8_format == Format.E4M3 or (
......
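
None of the changes above alter how callers discover FP8 support; user code still asks FP8GlobalStateManager before entering fp8_autocast. A minimal, hedged usage sketch of that guard (standard Transformer Engine API, not specific to this commit):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_ok, reason = FP8GlobalStateManager.is_fp8_available()
blockwise_ok, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()

if fp8_ok:
    layer = te.Linear(256, 256, bias=True).cuda()
    x = torch.randn(16, 256, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y = layer(x)
else:
    print(f"Skipping FP8 path: {reason}")
```
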
......@@ -10,19 +10,20 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from torch.utils.cpp_extension import IS_HIP_EXTENSION
ActivationOffloadEnabled = False
def get_activation_offloading():
"""Get global status of get_activation_offloading"""
global ActivationOffloadEnabled
return ActivationOffloadEnabled
def set_activation_offloading(activation_offloading):
"""Set global status of get_activation_offloading"""
global ActivationOffloadEnabled
ActivationOffloadEnabled = activation_offloading
......@@ -467,34 +468,73 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
)
if IS_HIP_EXTENSION:
def is_mi200():
"""check whether this machine is mi200/210/250"""
import re
return (re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def is_K100_AI():
"""check whether this machine is K100_AI"""
import re
return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def is_BW():
"""check whether this machine is BW"""
import re
return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
@functools.lru_cache(maxsize=None)
def _get_gcn_arch_impl(device: torch.device) -> int:
props = torch.cuda.get_device_properties(device)
import re
if re.search('gfx906', props.gcnArchName) is not None:
return 906
if re.search('gfx926', props.gcnArchName) is not None:
return 926
if re.search('gfx928', props.gcnArchName) is not None:
return 928
if re.search('gfx936', props.gcnArchName) is not None:
return 936
if re.search('gfx938', props.gcnArchName) is not None:
return 938
raise RuntimeError(f"Unsupported GCN Arch {props.gcnArchName}")
def _get_gcn_arch() -> int:
return _get_gcn_arch_impl(torch.cuda.current_device())
def is_gfx906() -> bool:
"""check whether this machine is gfx906"""
return _get_gcn_arch() == 906
def is_gfx926() -> bool:
"""check whether this machine is gfx926"""
return _get_gcn_arch() == 926
def is_gfx928() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 928
def is_gfx936() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 936
def is_gfx938() -> bool:
"""check whether this machine is gfx928"""
return _get_gcn_arch() == 938
else:
def is_gfx906() -> bool:
"""gfx906 is only available on ROCm"""
return False
def is_gfx926() -> bool:
"""gfx926 is only available on ROCm"""
return False
def is_gfx928() -> bool:
"""gfx928 is only available on ROCm"""
return False
def is_gfx936() -> bool:
"""gfx936 is only available on ROCm"""
return False
def is_gfx938() -> bool:
"""gfx938 is only available on ROCm"""
return False
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.
"""
if IS_HIP_EXTENSION:
# only MI200 and MI300 machines support bf16
if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
return True
else:
return False
else:
return torch.cuda.get_device_capability()[0] >= 8
# only these architectures support bf16
return is_gfx928() or is_gfx936() or is_gfx938()
return torch.cuda.get_device_capability()[0] >= 8
@functools.lru_cache(maxsize=None)
......@@ -505,8 +545,7 @@ def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
if IS_HIP_EXTENSION:
if is_blockwise:
return False
else:
return True
return True
device_capability = torch.cuda.get_device_capability()
return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
......
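
Finally, a small hedged example of how the new gfx helpers compose with the simulation flags, mirroring the check_fp8_support / check_fp8_block_scaling_support logic added above (import paths follow the ones used in this diff):

```python
from transformer_engine.pytorch.fp8 import (
    int8_simulation_fp8,
    int8_simulation_fp8_tensorwise,
)
from transformer_engine.pytorch.utils import is_gfx928, is_gfx936, is_gfx938

# gfx938 runs FP8 natively; gfx928/gfx936 only via INT8 simulation,
# and per-tensor recipes additionally need the tensorwise flag.
if is_gfx938():
    mode = "native fp8"
elif (is_gfx928() or is_gfx936()) and int8_simulation_fp8 and int8_simulation_fp8_tensorwise:
    mode = "int8-simulated fp8 (tensorwise + blockwise)"
elif (is_gfx928() or is_gfx936()) and int8_simulation_fp8:
    mode = "int8-simulated fp8 (blockwise only)"
else:
    mode = "fp8 unsupported on this device"
print(mode)
```
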