Commit 04afba37 authored by wenjh's avatar wenjh
Browse files

Enable test_attention


Signed-off-by: wenjh <wenjh@sugon.com>
parent ea5cc27a
@@ -50,6 +50,7 @@ from transformer_engine.pytorch.quantized_tensor import (
     prepare_for_saving,
     restore_from_saved,
 )
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 _current_file = pathlib.Path(__file__).resolve()
 sys.path.append(str(_current_file.parent.parent))
@@ -65,7 +66,8 @@ from utils import (
 fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
 fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
 device_compute_capability = get_device_compute_capability()
-if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
+if not IS_HIP_EXTENSION:
+    if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
     fp8_attn_available = False
     reason_for_no_fp8_attn = (
         "FP8 attention is not supported for compute capability ="
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.