Unverified Commit a1c8f379 authored by Jeff Daily's avatar Jeff Daily Committed by GitHub
Browse files

dynamic distpatch of fp8 kernels (#14245)


Signed-off-by: default avatarJeff Daily <jeff.daily@amd.com>
parent 08a1a112
...@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op ...@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
current_platform_fp8_dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else
torch.float8_e4m3fn)
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
...@@ -165,9 +161,7 @@ def input_to_float8( ...@@ -165,9 +161,7 @@ def input_to_float8(
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function quantizes input values to float8 values " """This function quantizes input values to float8 values "
"with tensor-wise quantization.""" "with tensor-wise quantization."""
if dtype is None: dtype = current_platform.fp8_dtype() if dtype is None else dtype
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
finfo = torch.finfo(dtype) finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax() min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
...@@ -311,9 +305,7 @@ def per_token_group_quant_fp8( ...@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization. scaling factor for quantization.
""" """
if dtype is None: dtype = current_platform.fp8_dtype() if dtype is None else dtype
dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else torch.float8_e4m3fn)
assert (x.shape[-1] % group_size == 0), ( assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible " f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}") f"by `group_size` {group_size}")
......
...@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform): ...@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform):
def get_device_communicator_cls(cls) -> str: def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
@classmethod
def supports_fp8(cls) -> bool:
return cls.has_device_capability(89)
# NVML utils # NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
......
...@@ -330,6 +330,36 @@ class Platform: ...@@ -330,6 +330,36 @@ class Platform:
""" """
return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa
@classmethod
def supports_fp8(cls) -> bool:
"""
Returns whether the current platform supports FP8 types.
"""
return False
@classmethod
def is_fp8_fnuz(cls) -> bool:
"""
Returns whether the preferred FP8 type is FNUZ on the current platform.
There are two representations of FP8, OCP FP8 and FNUZ FP8.
The OCP specification can be found at https://tinyurl.com/b7jvwpft.
The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
AMD's MI300 and MI325 have native hardware support for FNUZ. All other
hardware has converged on the OCP FP8 standard.
"""
return False
@classmethod
def fp8_dtype(cls) -> torch.dtype:
"""
Returns the preferred FP8 type on the current platform.
See the documentation for is_fp8_fnuz for details.
"""
return torch.float8_e4m3fn
@classmethod @classmethod
def use_all_gather(cls) -> bool: def use_all_gather(cls) -> bool:
""" """
......
...@@ -231,3 +231,20 @@ class RocmPlatform(Platform): ...@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
@classmethod @classmethod
def get_device_communicator_cls(cls) -> str: def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa
@classmethod
def supports_fp8(cls) -> bool:
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
@classmethod
def is_fp8_fnuz(cls) -> bool:
# only device 0 is checked, this assumes MI300 platforms are homogeneous
return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
@classmethod
def fp8_dtype(cls) -> torch.dtype:
if cls.is_fp8_fnuz():
return torch.float8_e4m3fnuz
else:
return torch.float8_e4m3fn
...@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8) CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8) Fp8LinearGenericOp, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize) scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
...@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_Q_UK, W_Q_UK_scales = scaled_quantize( W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK, W_Q_UK,
self.reqaunt_weight_group_shape, self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype) quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use # For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly # `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous() self.W_Q_UK = W_Q_UK.T.contiguous()
...@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UV_O, W_UV_O_scales = scaled_quantize( W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O, W_UV_O,
self.reqaunt_weight_group_shape, self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype) quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use # For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly # `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous() self.W_UV_O = W_UV_O.T.contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment