Commit d8041744 authored by yuguo's avatar yuguo
Browse files

[DCU] fix all-gather usage

parent a397dcb7
......@@ -1004,10 +1004,10 @@ def _post_process_fp8_blockwise_gather(
return out
needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
)
need_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported()
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
)
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
......
......@@ -488,12 +488,15 @@ def is_bf16_compatible() -> None:
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported() -> bool:
def is_non_tn_fp8_gemm_supported(is_blockwise: Optional[bool] = False) -> bool:
"""Checks whether the device supports
non-TN layouts for FP8 GEMMs.
"""
if IS_HIP_EXTENSION:
return True
if is_blockwise:
return False
else:
return True
device_capability = torch.cuda.get_device_capability()
return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment