Unverified commit fa91ed72 authored by Sudhakar Singh, committed by GitHub

mxfp8 (for all gemm layouts) is not supported on 120+ arch yet (#1939)



* mxfp8 is not supported on 120+ arch yet
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* change the default recipe for arch 120
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent bda29934
@@ -46,6 +46,8 @@ def check_fp8_support() -> Tuple[bool, str]:
 def check_mxfp8_support() -> Tuple[bool, str]:
     """Return if fp8 support is available"""
+    if get_device_compute_capability() >= (12, 0):
+        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
         return True, ""
     return False, "Device compute capability 10.0 or higher required for MXFP8 execution."
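To illustrate the gating the hunk above adds, here is a minimal, self-contained sketch of the same capability checks. It is not the Transformer Engine code itself: `get_device_compute_capability` is replaced by an explicit `compute_capability` argument so the logic can be exercised without a GPU.

```python
# Hypothetical sketch of the MXFP8 gating logic from this commit.
# The real function queries the device; here the capability is passed in.
from typing import Tuple


def check_mxfp8_support(compute_capability: Tuple[int, int]) -> Tuple[bool, str]:
    """MXFP8 needs sm100+ (Blackwell), but sm120+ is excluded until
    all gemm layouts are supported there."""
    if compute_capability >= (12, 0):
        return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
    if compute_capability >= (10, 0):  # blackwell and above
        return True, ""
    return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


print(check_mxfp8_support((10, 0))[0])  # True  (sm100, Blackwell)
print(check_mxfp8_support((12, 0))[0])  # False (sm120, excluded for now)
print(check_mxfp8_support((9, 0))[0])   # False (sm90, Hopper)
```

Note that the sm120 check must come first: Python's tuple comparison makes `(12, 0) >= (10, 0)` true, so the Blackwell branch would otherwise shadow it.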
@@ -64,7 +66,11 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]:
 def get_default_fp8_recipe() -> Recipe:
     """FP8 recipe with default args."""
     if get_device_compute_capability() >= (10, 0):  # blackwell and above
-        if check_mxfp8_support()[0]:
+        # This is a temporary restriction until MXFP8 is supported for all
+        # gemm layouts.
+        if get_device_compute_capability() >= (12, 0):
+            return Float8BlockScaling()
         return MXFP8BlockScaling()
     return DelayedScaling()
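The resulting default-recipe dispatch can be summarized with the following hedged sketch. The recipe classes are stand-ins (plain strings) for Transformer Engine's `DelayedScaling`, `MXFP8BlockScaling`, and `Float8BlockScaling`, and the compute capability is passed in rather than queried from the device.

```python
# Illustrative mapping from compute capability to the default FP8 recipe
# after this commit; names mirror the TE recipe classes but are just strings.
from typing import Tuple


def default_recipe_name(compute_capability: Tuple[int, int]) -> str:
    if compute_capability >= (10, 0):  # blackwell and above
        # Temporary restriction until MXFP8 supports all gemm layouts:
        # sm120+ falls back to Float8 block scaling instead of MXFP8.
        if compute_capability >= (12, 0):
            return "Float8BlockScaling"
        return "MXFP8BlockScaling"
    return "DelayedScaling"


print(default_recipe_name((9, 0)))   # DelayedScaling    (Hopper)
print(default_recipe_name((10, 0)))  # MXFP8BlockScaling (sm100)
print(default_recipe_name((12, 0)))  # Float8BlockScaling (sm120)
```

So the commit changes the sm120 default from MXFP8 to Float8 block scaling while leaving sm100/sm110 (MXFP8) and pre-Blackwell (delayed scaling) behavior unchanged.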