Commit d8041744 authored by yuguo's avatar yuguo
Browse files

[DCU] fix all-gather usage

parent a397dcb7
...@@ -1004,10 +1004,10 @@ def _post_process_fp8_blockwise_gather( ...@@ -1004,10 +1004,10 @@ def _post_process_fp8_blockwise_gather(
return out return out
needs_columnwise_data_transpose = ( needs_columnwise_data_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
) )
need_rowwise_scale_transpose = ( need_rowwise_scale_transpose = (
quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported(is_blockwise=True)
) )
# CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024
......
...@@ -488,12 +488,15 @@ def is_bf16_compatible() -> None: ...@@ -488,12 +488,15 @@ def is_bf16_compatible() -> None:
@functools.lru_cache(maxsize=None)
def is_non_tn_fp8_gemm_supported(is_blockwise: bool = False) -> bool:
    """Check whether the device supports non-TN layouts for FP8 GEMMs.

    Args:
        is_blockwise: True when querying support for blockwise-quantized
            FP8 GEMMs. On HIP/ROCm builds these are reported as requiring
            the TN layout, while non-blockwise FP8 GEMMs are not.

    Returns:
        True if non-TN FP8 GEMM layouts are supported on the current device.
    """
    if IS_HIP_EXTENSION:
        # ROCm path: non-TN layouts are supported except for blockwise FP8.
        return not is_blockwise
    # CUDA path: support depends only on the compute capability,
    # not on the quantization scheme.
    device_capability = torch.cuda.get_device_capability()
    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment