Unverified Commit fa91ed72 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

mxfp8 (for all gemm layouts) is not supported on 120+ arch yet (#1939)



* mxfp8 is not supported on 120+ arch yet
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* change the default recipe for arch 120
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent bda29934
...@@ -46,6 +46,8 @@ def check_fp8_support() -> Tuple[bool, str]: ...@@ -46,6 +46,8 @@ def check_fp8_support() -> Tuple[bool, str]:
def check_mxfp8_support() -> Tuple[bool, str]: def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available""" """Return if fp8 support is available"""
if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
if get_device_compute_capability() >= (10, 0): # blackwell and above if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, "" return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution." return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
...@@ -64,7 +66,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ...@@ -64,7 +66,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
def get_default_fp8_recipe() -> Recipe: def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args.""" """FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above if check_mxfp8_support()[0]:
# This is a temporary restriction until MXFP8 is supported for all
# gemm layouts.
if get_device_compute_capability() >= (12, 0):
return Float8BlockScaling()
return MXFP8BlockScaling() return MXFP8BlockScaling()
return DelayedScaling() return DelayedScaling()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment