Unverified commit 4ab32566, authored by Michael Goin, committed by GitHub
Browse files

[Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 (#10095)


Signed-off-by: mgoin <michael@neuralmagic.com>
parent 719c1ca4
@@ -7,8 +7,7 @@ from vllm.platforms import current_platform
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
-if current_platform.is_rocm() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 def cutlass_fp8_supported() -> bool:
@@ -166,8 +165,7 @@ def apply_fp8_linear(
 # Making sure the dummy tensor is on the same device as the weight
 global TORCH_DEVICE_IDENTITY
-if (TORCH_DEVICE_IDENTITY is not None
-and TORCH_DEVICE_IDENTITY.device != weight.device):
+if TORCH_DEVICE_IDENTITY.device != weight.device:
 TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 # GEMM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment