Unverified Commit 2e9034c9 authored by Maral's avatar Maral Committed by GitHub
Browse files

[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8...


[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: default avatarmaral <maralbahari.98@gmail.com>
Signed-off-by: default avatarMaral <maralbahari.98@gmail.com>
parent 8332078c
...@@ -6,8 +6,15 @@ import torch ...@@ -6,8 +6,15 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
...@@ -107,3 +114,54 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): ...@@ -107,3 +114,54 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# b to be [N, K] # b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
super().__init__(config)
n, k = config.weight_shape
self.use_triton = (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
)
@classmethod
def is_supported(cls, compute_capability=None):
return (
rocm_aiter_ops.is_linear_enabled(),
"Only supported on ROCm platform \
with aiter package installed.",
)
@classmethod
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
can_implement_base, reason = super().can_implement(config)
if not can_implement_base:
return can_implement_base, reason
act_quant_desc = config.activation_quant_key.scale
if act_quant_desc.group_shape != GroupShape(1, 128):
return (
False,
"Supports only dynamic per token group activation "
"quantization with group_shape=(1,128).",
)
return True, None
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
out_dtype = self.config.out_dtype
if self.use_triton:
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
else:
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
return gemm_a8w8_blockscale_op(
A, B, As, Bs, list(self.weight_group_shape), output_dtype=out_dtype
)
...@@ -5,12 +5,19 @@ ...@@ -5,12 +5,19 @@
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
convert_to_channelwise, convert_to_channelwise,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
...@@ -171,3 +178,143 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): ...@@ -171,3 +178,143 @@ class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
) )
return output.view(*output_shape) return output.view(*output_shape)
class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
super().__init__(config)
act_scale_descriptor = config.activation_quant_key.scale
self.weight_group_shape = config.weight_quant_key.scale.group_shape
self.quant_fp8 = QuantFP8(
static=act_scale_descriptor.static,
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_output_padding(),
use_ue8m0=False,
column_major_scales=True,
)
self.is_hopper = current_platform.is_device_capability(90)
@classmethod
def is_supported(cls, compute_capability=None):
if not CUTLASS_BLOCK_FP8_SUPPORTED:
return (
False,
"The device compute capability of"
f"{compute_capability} is not supported.",
)
return True, None
@classmethod
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
can_implement_base, reason = super().can_implement(config)
if not can_implement_base:
return can_implement_base, reason
act_quant_desc = config.activation_quant_key.scale
if act_quant_desc.group_shape != GroupShape(1, 128):
return (
False,
"Supports only dynamic per token group activation "
"quantization with group_shape=(1,128).",
)
return True, None
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
out_dtype = self.config.out_dtype
if self.is_hopper:
return torch.ops.vllm.padded_cutlass(
A,
B,
As,
Bs,
list(self.weight_group_shape),
out_dtype,
)
else:
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs.T,
)
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
scale_b=Bs.T,
)
def _padded_cutlass(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
pad_multiple = 4
dim = qx.shape[0]
padded = (
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
)
has_pad = padded > dim
if has_pad:
padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_x_scale = torch.ones(
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
)
return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"padded_cutlass",
_padded_cutlass,
fake_impl=_padded_cutlass_fake,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
deepgemm_post_process_fp8_weight_block,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_auto_disable_deep_gemm,
should_use_deepgemm_for_fp8_linear,
)
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class DeepGemmFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
super().__init__(config)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
act_scale_descriptor = config.activation_quant_key.scale
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.quant_fp8 = QuantFP8(
static=False,
group_shape=act_scale_descriptor.group_shape,
use_ue8m0=self.use_deep_gemm_e8m0,
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
column_major_scales=True,
)
@classmethod
def is_supported(cls, compute_capability=None):
if not current_platform.is_cuda():
return False, "DeepGEMM is only supported on cuda platform"
if not is_deep_gemm_supported():
return False, "Currently, only Hopper and Blackwell GPUs are supported."
return True, None
@classmethod
def can_implement(cls, config):
can_implement_base, reason = super().can_implement(config)
if not can_implement_base:
return can_implement_base, reason
if config.out_dtype != torch.bfloat16:
return (False, "Supports only output dtype of bfloat16")
act_quant_desc = config.activation_quant_key.scale
if act_quant_desc.group_shape != GroupShape(1, 128):
return (
False,
"Supports only dynamic per token group activation "
"quantization with group_shape=(1,128).",
)
model_config = get_current_vllm_config().model_config
if model_config is None:
return False, "Model configuration is required."
model_type = getattr(model_config.hf_text_config, "model_type", None)
if should_auto_disable_deep_gemm(model_type):
return False, f"Should not use deepgemm for model {model_type}"
if not should_use_deepgemm_for_fp8_linear(
config.out_dtype, config.weight_shape
):
return False, "The provided metadata is not supported."
return True, None
def process_weights_after_loading(self, layer):
super().process_weights_after_loading(layer)
params = self._get_layer_params(layer)
assert layer.weight_block_size is not None
if self.is_deep_gemm_supported:
weight_scale_invs = params.weight_scale_inv
scale_attr = (
params.WEIGHT_SCALE_INV
if weight_scale_invs is not None
else params.WEIGHT_SCALE
)
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=params.weight,
ws=weight_scale_invs
if weight_scale_invs is not None
else params.weight_scale,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=self.use_deep_gemm_e8m0,
)
replace_parameter(layer, params.WEIGHT, dg_weight)
replace_parameter(layer, scale_attr, dg_weight_scale)
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
out_dtype = self.config.out_dtype
output = torch.empty(
(A.shape[0], B.shape[0]),
dtype=out_dtype,
device=A.device,
)
torch.ops.vllm.fp8_gemm_nt_op(A, As, B, Bs, output, self.use_deep_gemm_e8m0)
return output
def _fp8_gemm_nt_op(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
fp8_gemm_nt(
(q_input, input_scale),
(weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
)
def _fp8_gemm_nt_op_fake(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
return None
direct_register_custom_op(
"fp8_gemm_nt_op",
_fp8_gemm_nt_op,
mutates_args=["output"],
fake_impl=_fp8_gemm_nt_op_fake,
)
...@@ -2,11 +2,32 @@ ...@@ -2,11 +2,32 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import ClassVar
import torch import torch
import vllm.envs as envs
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer from vllm.utils.flashinfer import (
flashinfer_fp8_blockscale_gemm,
flashinfer_scaled_fp8_mm,
has_flashinfer,
is_flashinfer_fp8_blockscale_gemm_supported,
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledDynamicMMLinearKernel,
Fp8BlockScaledMMLinearKernel,
)
from .deep_gemm import DeepGemmFp8BlockScaledMMKernel, fp8_gemm_nt
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel, FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig, FP8ScaledMMLinearLayerConfig,
...@@ -55,3 +76,256 @@ class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): ...@@ -55,3 +76,256 @@ class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
return flashinfer_scaled_fp8_mm( return flashinfer_scaled_fp8_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
) )
class FlashInferFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
# FlashInfer accepts BF16 input and handles FP8 conversion internally.
apply_input_quant: ClassVar[bool] = False
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
super().__init__(config)
@classmethod
def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
can_implement_base, reason = super().can_implement(config)
if not can_implement_base:
return can_implement_base, reason
act_quant_desc = config.activation_quant_key.scale
if act_quant_desc.group_shape != GroupShape(1, 128):
return (
False,
"Supports only dynamic per token group activation "
"quantization with group_shape=(1,128).",
)
if not should_use_flashinfer_for_blockscale_fp8_gemm(
is_flashinfer_fp8_blockscale_gemm_supported(),
config.out_dtype,
config.input_dtype,
config.weight_quant_key.dtype,
config.weight_shape,
):
return (
False,
"The provided metadata is not supported.",
)
return True, None
@classmethod
def is_supported(cls, compute_capability=None):
if not current_platform.is_cuda():
return False, "only cuda devices are supported."
if not is_flashinfer_fp8_blockscale_gemm_supported():
return False, "FlashInfer block-scale FP8 GEMM is not available."
return True, None
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
# A is BF16 — FlashInfer handles FP8 conversion internally.
# As is a placeholder (apply_input_quant=False) and is not used here.
return torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
A, # BF16 input
B, # FP8 weight
Bs, # Weight scales
)
class FlashInferFp8DeepGEMMDynamicBlockScaledKernel(
Fp8BlockScaledDynamicMMLinearKernel
):
"""
Conditional FlashInfer / DeepGEMM FP8 block-scaled GEMM.
Dispatches between two kernels based on input batch size:
- Small batches (M < 32): FlashInfer's swapAB trick for better utilisation.
- Large batches (M >= 32): DeepGEMM for peak throughput.
apply_input_quant is False because FlashInfer accepts BF16 input and
handles FP8 conversion internally. The DeepGEMM branch therefore
quantises BF16→FP8 inside apply_mm via a closure before dispatching to
the DeepGEMM kernel — keeping both branches compatible with the single
BF16 tensor operand list passed by torch.cond.
"""
base_type: ClassVar[type[FlashInferFp8BlockScaledMMKernel]] = (
FlashInferFp8BlockScaledMMKernel
)
fallback_type: ClassVar[type[DeepGemmFp8BlockScaledMMKernel]] = (
DeepGemmFp8BlockScaledMMKernel
)
apply_input_quant: ClassVar[bool] = False
def __init__(self, config: FP8ScaledMMLinearLayerConfig):
super().__init__(config)
self.base: FlashInferFp8BlockScaledMMKernel
self.fallback: DeepGemmFp8BlockScaledMMKernel
def process_weights_after_loading(self, layer: torch.nn.Module):
# DeepGEMM need post-processing; both kernels share the same
# parameter tensor layout so processing once is sufficient.
self.fallback.process_weights_after_loading(layer)
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
group_size = self.weight_group_shape.col
use_deep_gemm_e8m0 = self.fallback.use_deep_gemm_e8m0
return torch.ops.vllm.dynamic_flashinfer_deepgemm_blockscale_gemm(
A, B, Bs, group_size, use_deep_gemm_e8m0
)
def _flashinfer_fp8_blockscale_gemm_impl(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
return flashinfer_fp8_blockscale_gemm(
input=input,
weight=weight,
weight_scale=weight_scale,
out_dtype=torch.bfloat16,
)
def _flashinfer_fp8_blockscale_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
"""
Required fake/meta implementation for torch.compile graph tracing.
"""
return torch.empty(
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
)
direct_register_custom_op(
"flashinfer_fp8_blockscale_gemm",
_flashinfer_fp8_blockscale_gemm_impl,
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
)
def _dynamic_flashinfer_deepgemm_blockscale_gemm_impl(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
"""
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
This function switches between two optimized kernels based on the input batch size:
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
The conditional logic must use torch.cond() instead of a simple if-else statement
to maintain compatibility with torch.compile graph compilation.
This batch-size-dependent selection is essential for maintaining model accuracy.
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
drop.
Args:
input: Input tensor of shape (batch_size, input_dim) in FP8 format
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
weight_scale: Scale factors for weight quantization (per-group)
group_size: Quantization group size for the weight tensor
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
Returns:
Output tensor of shape (batch_size, output_dim) in bfloat16 format
"""
def run_flashinfer_deepgemm_swapAB(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
return flashinfer_fp8_blockscale_gemm(
input=input,
weight=weight,
weight_scale=weight_scale,
out_dtype=torch.bfloat16,
)
def run_deepgemm(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
q_input, input_scale = per_token_group_quant_fp8(
input,
group_size=group_size,
column_major_scales=True,
use_ue8m0=use_deep_gemm_e8m0,
)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
)
return output
if envs.VLLM_BATCH_INVARIANT:
return run_deepgemm(input, weight, weight_scale)
condition = input.shape[0] < 32
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
# Python conditionals. torch.cond() explicitly registers both code paths in the
# computation graph, allowing torch.compile to capture both branches.
# without torch.cond, the M < 32 condition won't be able to be captured by torch
# compile
return torch.cond(
condition,
run_flashinfer_deepgemm_swapAB,
run_deepgemm,
(input, weight, weight_scale),
)
def _dynamic_flashinfer_deepgemm_blockscale_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
"""
Required fake/meta implementation for torch.compile graph tracing.
"""
return torch.empty(
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
)
direct_register_custom_op(
"dynamic_flashinfer_deepgemm_blockscale_gemm",
_dynamic_flashinfer_deepgemm_blockscale_gemm_impl,
fake_impl=_dynamic_flashinfer_deepgemm_blockscale_gemm_fake,
)
...@@ -13,7 +13,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -13,7 +13,11 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
)
from .cutlass import CutlassInt8ScaledMMLinearKernel from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import ( from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig,
...@@ -150,3 +154,67 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): ...@@ -150,3 +154,67 @@ class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
out -= (x_s * w_s_row * azp_adj).to(x.dtype) out -= (x_s * w_s_row * azp_adj).to(x.dtype)
return out return out
class TritonFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
@classmethod
def is_supported(cls, compute_capability=None):
if not current_platform.is_cuda_alike():
return False, "only cuda like devices are supported."
return True, None
def apply_block_scaled_mm(
self,
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
A,
B,
As,
Bs,
list(self.weight_group_shape),
self.config.out_dtype,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
)
return w8a8_triton_block_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
)
...@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate ...@@ -8,6 +8,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from torch.nn import Parameter from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
...@@ -16,18 +17,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -16,18 +17,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale, create_fp8_input_scale,
create_fp8_scale_parameter, create_fp8_scale_parameter,
create_fp8_weight_parameter, create_fp8_weight_parameter,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_channel_strategy, process_fp8_weight_channel_strategy,
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kFp8StaticTokenSym, kFp8StaticTokenSym,
...@@ -67,6 +66,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -67,6 +66,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.weight_quant = weight_quant self.weight_quant = weight_quant
self.strategy = weight_quant.strategy self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure self.weight_block_size = self.weight_quant.block_structure
...@@ -75,21 +75,18 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -75,21 +75,18 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size), self.weight_quant_key = create_fp8_quant_key(
act_quant_group_shape=self.act_q_group_shape, static=True, group_shape=GroupShape(*self.weight_block_size)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: self.activation_quant_key = create_fp8_quant_key(
activation_quant_key = activation_quant_key_mapping[is_static_input_scheme] static=False, group_shape=self.act_q_group_shape
weight_quant_key = weight_quant_key_mapping[self.strategy]
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
) )
else:
self.activation_quant_key = activation_quant_key_mapping[
is_static_input_scheme
]
self.weight_quant_key = weight_quant_key_mapping[self.strategy]
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -146,6 +143,15 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -146,6 +143,15 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader) input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR: if self.strategy == QuantizationStrategy.TENSOR:
weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy( weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
...@@ -163,10 +169,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -163,10 +169,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
elif self.strategy == QuantizationStrategy.BLOCK: elif self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False assert self.is_static_input_scheme is False
weight, weight_scale = process_fp8_weight_block_strategy( self.fp8_linear.process_weights_after_loading(layer)
layer.weight, layer.weight_scale
) layer.input_scale = None
input_scale = None # fp8_linear.process_weights_after_loading applies the post process
# and reassigns the weight and weight_scale buffers to layer attributes.
return
else: else:
raise ValueError( raise ValueError(
...@@ -185,8 +193,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -185,8 +193,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else: else:
layer.input_scale = None layer.input_scale = None
if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer)
if hasattr(self, "fp8_linear"): if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer) self.fp8_linear.process_weights_after_loading(layer)
...@@ -197,13 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -197,13 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
...@@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -93,12 +94,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel( self.input_dtype = get_current_vllm_config().model_config.dtype
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights( def create_weights(
self, self,
...@@ -149,6 +145,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -149,6 +145,15 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
) )
layer.input_scale_ub = input_scale_ub layer.input_scale_ub = input_scale_ub
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# required by torch.compile # required by torch.compile
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
......
...@@ -10,7 +10,7 @@ from torch.utils._python_dispatch import TorchDispatchMode ...@@ -10,7 +10,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
...@@ -45,13 +45,10 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -45,13 +45,10 @@ from vllm.model_executor.layers.quantization.base_config import (
) )
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale, create_fp8_input_scale,
create_fp8_scale_parameter, create_fp8_scale_parameter,
create_fp8_weight_parameter, create_fp8_weight_parameter,
maybe_post_process_fp8_weight_block,
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
process_fp8_weight_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe,
validate_fp8_block_shape, validate_fp8_block_shape,
...@@ -61,6 +58,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( ...@@ -61,6 +58,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key,
is_layer_skipped, is_layer_skipped,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8DynamicTensorSym, kFp8DynamicTensorSym,
...@@ -273,12 +271,13 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -273,12 +271,13 @@ class Fp8LinearMethod(LinearMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
self.marlin_input_dtype = None self.marlin_input_dtype = None
self.use_marlin = False
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
if self.quant_config.use_deep_gemm is not None: if self.quant_config.use_deep_gemm is not None:
self.use_deep_gemm = self.quant_config.use_deep_gemm self.use_deep_gemm = self.quant_config.use_deep_gemm
else: else:
...@@ -288,37 +287,26 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -288,37 +287,26 @@ class Fp8LinearMethod(LinearMethodBase):
self.block_quant = self.weight_block_size is not None self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static" self.act_q_static = self.quant_config.activation_scheme == "static"
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
if self.block_quant: if self.block_quant:
weight_quant_key = kFp8Static128BlockSym
else:
weight_quant_key = kFp8StaticTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
if self.block_quant and not self.use_marlin:
assert not self.act_q_static assert not self.act_q_static
assert self.weight_block_size is not None assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size), self.activation_quant_key = create_fp8_quant_key(
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]), static=self.act_q_static,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, group_shape=GroupShape(1, self.weight_block_size[0]),
use_aiter_and_is_supported=self.use_aiter_and_is_supported, )
use_deep_gemm=self.use_deep_gemm, self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size)
) )
else:
self.weight_quant_key = kFp8StaticTensorSym
# Use per-token quantization for better perf if dynamic and cutlass
if self.act_q_static:
self.activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
self.activation_quant_key = kFp8DynamicTokenSym
else:
self.activation_quant_key = kFp8DynamicTensorSym
def create_weights( def create_weights(
self, self,
...@@ -384,6 +372,17 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -384,6 +372,17 @@ class Fp8LinearMethod(LinearMethodBase):
set_weight_attrs(scale, {"scale_type": "input_scale"}) set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if self.use_marlin: if self.use_marlin:
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid # Only Marlin kernels support `marlin_input_dtype`; guard to avoid
...@@ -398,13 +397,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -398,13 +397,7 @@ class Fp8LinearMethod(LinearMethodBase):
if self.block_quant: if self.block_quant:
assert not self.act_q_static assert not self.act_q_static
weight, weight_scale_inv = process_fp8_weight_block_strategy( self.fp8_linear.process_weights_after_loading(layer)
layer.weight, layer.weight_scale_inv
)
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
else: else:
...@@ -435,9 +428,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -435,9 +428,6 @@ class Fp8LinearMethod(LinearMethodBase):
else: else:
layer.input_scale = None layer.input_scale = None
if self.block_quant and self.use_deep_gemm:
maybe_post_process_fp8_weight_block(layer)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -449,12 +439,10 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -449,12 +439,10 @@ class Fp8LinearMethod(LinearMethodBase):
if envs.VLLM_BATCH_INVARIANT: if envs.VLLM_BATCH_INVARIANT:
if self.block_quant: if self.block_quant:
assert self.weight_block_size is not None assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply( return self.fp8_linear.apply_weights(
input=x, layer,
weight=layer.weight, x,
weight_scale=layer.weight_scale_inv, bias,
input_scale=layer.input_scale,
bias=bias,
) )
else: else:
# per-tensor/channel: dequant to BF16 and run GEMM # per-tensor/channel: dequant to BF16 and run GEMM
...@@ -483,17 +471,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -483,17 +471,6 @@ class Fp8LinearMethod(LinearMethodBase):
if self.use_marlin: if self.use_marlin:
return self.fp8_linear.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale_inv,
input_scale=layer.input_scale,
bias=bias,
)
return self.fp8_linear.apply_weights(layer, x, bias) return self.fp8_linear.apply_weights(layer, x, bias)
...@@ -538,6 +515,24 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -538,6 +515,24 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
initialize_online_processing(layer) initialize_online_processing(layer)
# TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314
# This check is required because Mxfp8OnlineLinearMethod inherits from
# Fp8OnlineLinearMethod but only calls super().create_weights(), so we must
# skip the fp8_linear kernel creation.
if hasattr(self, "mxfp8_linear"):
return
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
self.use_marlin = isinstance(self.fp8_linear, MarlinFP8ScaledMMLinearKernel)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention import Attention, MLAAttention
...@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -56,7 +57,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31, swap_w13_to_w31,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe, process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe,
) )
...@@ -78,6 +78,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( ...@@ -78,6 +78,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
...@@ -86,7 +87,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -86,7 +87,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kNvfp4Static, kNvfp4Static,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
requantize_with_max_scale, requantize_with_max_scale,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -450,12 +450,8 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel( self.out_dtype = torch.get_default_dtype()
activation_quant_key=kFp8StaticTensorSym, self.input_dtype = get_current_vllm_config().model_config.dtype
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights( def create_weights(
self, self,
...@@ -505,6 +501,15 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -505,6 +501,15 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
scale[:] = torch.finfo(torch.float32).min scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight weight = layer.weight
max_w_scale = layer.weight_scale.max() max_w_scale = layer.weight_scale.max()
...@@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): ...@@ -536,12 +541,8 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = init_fp8_linear_kernel( self.out_dtype = torch.get_default_dtype()
activation_quant_key=kFp8DynamicTokenSym, self.input_dtype = get_current_vllm_config().model_config.dtype
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights( def create_weights(
self, self,
...@@ -587,6 +588,15 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): ...@@ -587,6 +588,15 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
...@@ -616,12 +626,16 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase): ...@@ -616,12 +626,16 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
block_n, block_k = self._WEIGHT_BLOCK_SIZE block_n, block_k = self._WEIGHT_BLOCK_SIZE
self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE) self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE)
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(block_n, block_k), self.activation_quant_key = create_fp8_quant_key(
act_quant_group_shape=GroupShape(1, block_k), static=False, group_shape=GroupShape(1, block_k)
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
) )
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_n, block_k)
)
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights( def create_weights(
self, self,
...@@ -688,8 +702,17 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase): ...@@ -688,8 +702,17 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
weight_scale[:] = torch.finfo(torch.float32).min weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
self.w8a8_block_fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Keep weight in [out, in] layout for W8A8BlockFp8LinearOp. # Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel.
layer.weight = Parameter(layer.weight.data, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False)
scale = layer.weight_scale scale = layer.weight_scale
...@@ -713,13 +736,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase): ...@@ -713,13 +736,7 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.w8a8_block_fp8_linear.apply( return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
class ModelOptFp8MoEMethod(FusedMoEMethodBase): class ModelOptFp8MoEMethod(FusedMoEMethodBase):
......
...@@ -17,7 +17,7 @@ if TYPE_CHECKING: ...@@ -17,7 +17,7 @@ if TYPE_CHECKING:
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config
from vllm.model_executor.kernels.linear import init_fp8_linear_kernel from vllm.model_executor.kernels.linear import init_fp8_linear_kernel
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoEMethodBase, FusedMoEMethodBase,
...@@ -28,13 +28,9 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( ...@@ -28,13 +28,9 @@ from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearMethodBase, LinearMethodBase,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key,
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8DynamicTensorSym, kFp8DynamicTensorSym,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
...@@ -42,7 +38,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -42,7 +38,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
cutlass_fp8_supported, cutlass_fp8_supported,
) )
from vllm.model_executor.model_loader.reload.layerwise import ( from vllm.model_executor.model_loader.reload.layerwise import (
...@@ -51,7 +46,7 @@ from vllm.model_executor.model_loader.reload.layerwise import ( ...@@ -51,7 +46,7 @@ from vllm.model_executor.model_loader.reload.layerwise import (
from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported, per_block_cast_to_fp8 from vllm.utils.deep_gemm import per_block_cast_to_fp8
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Online FP8 Linear Methods # Online FP8 Linear Methods
...@@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase): ...@@ -64,6 +59,10 @@ class _Fp8OnlineLinearBase(LinearMethodBase):
uses_meta_device: bool = True uses_meta_device: bool = True
def __init__(self):
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -103,18 +102,41 @@ class Fp8PerTensorOnlineLinearMethod(_Fp8OnlineLinearBase):
Loads fp16/bf16 weights and quantizes them per-tensor during loading.""" Loads fp16/bf16 weights and quantizes them per-tensor during loading."""
def __init__(self): def __init__(self):
self.out_dtype = torch.get_default_dtype() super().__init__()
self.weight_quant_key = kFp8StaticTensorSym
# Use per-token quantization for better perf if dynamic and cutlass # Use per-token quantization for better perf if dynamic and cutlass
if cutlass_fp8_supported(): if cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym self.activation_quant_key = kFp8DynamicTokenSym
else: else:
activation_quant_key = kFp8DynamicTensorSym self.activation_quant_key = kFp8DynamicTensorSym
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
super().create_weights(
layer,
input_size_per_partition,
output_partition_sizes,
input_size,
output_size,
params_dtype,
**extra_weight_attrs,
)
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key, activation_quant_key=self.activation_quant_key,
weight_quant_key=kFp8StaticTensorSym, weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(), weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
...@@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -166,19 +188,14 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
Loads fp16/bf16 weights and quantizes them per-block during loading.""" Loads fp16/bf16 weights and quantizes them per-block during loading."""
def __init__(self): def __init__(self):
self.out_dtype = torch.get_default_dtype() super().__init__()
self.weight_block_size = [128, 128] self.weight_block_size = [128, 128]
self.activation_quant_key = create_fp8_quant_key(
self.use_deep_gemm = is_deep_gemm_supported() static=False,
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() group_shape=GroupShape(1, self.weight_block_size[0]),
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported() )
self.weight_quant_key = create_fp8_quant_key(
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( static=True, group_shape=GroupShape(*self.weight_block_size)
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
use_deep_gemm=self.use_deep_gemm,
) )
def create_weights( def create_weights(
...@@ -202,6 +219,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -202,6 +219,15 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
) )
layer.weight_block_size = self.weight_block_size layer.weight_block_size = self.weight_block_size
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
return return
...@@ -213,14 +239,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -213,14 +239,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
layer.weight, block_size=block_size, use_ue8m0=False layer.weight, block_size=block_size, use_ue8m0=False
) )
qweight, weight_scale_inv = process_fp8_weight_block_strategy(
qweight, weight_scale_inv
)
replace_parameter(layer, "weight", qweight.data) replace_parameter(layer, "weight", qweight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
maybe_post_process_fp8_weight_block(layer) self.fp8_linear.process_weights_after_loading(layer)
# Prevent duplicate processing (e.g., during weight reload) # Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True layer._already_called_process_weights_after_loading = True
...@@ -234,12 +256,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase): ...@@ -234,12 +256,10 @@ class Fp8PerBlockOnlineLinearMethod(_Fp8OnlineLinearBase):
assert self.weight_block_size is not None assert self.weight_block_size is not None
# Note: batch invariance already handled in the function below # Note: batch invariance already handled in the function below
return self.w8a8_block_fp8_linear.apply( return self.fp8_linear.apply_weights(
input=x, layer,
weight=layer.weight, x,
weight_scale=layer.weight_scale_inv, bias,
input_scale=layer.input_scale,
bias=bias,
) )
......
...@@ -7,6 +7,7 @@ from typing import Any, cast ...@@ -7,6 +7,7 @@ from typing import Any, cast
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
...@@ -57,6 +58,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -57,6 +58,7 @@ class QuarkW8A8Fp8(QuarkScheme):
kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
) )
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -175,7 +177,9 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -175,7 +177,9 @@ class QuarkW8A8Fp8(QuarkScheme):
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key, activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key, weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(), weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
......
...@@ -12,15 +12,11 @@ import torch ...@@ -12,15 +12,11 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max, get_fp8_min_max,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
all_close_1d, all_close_1d,
per_tensor_dequantize, per_tensor_dequantize,
) )
...@@ -29,22 +25,14 @@ from vllm.model_executor.parameter import ( ...@@ -29,22 +25,14 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_tma_aligned_size, get_tma_aligned_size,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
should_use_deepgemm_for_fp8_linear,
transform_sf_into_required_layout, transform_sf_into_required_layout,
) )
from vllm.utils.flashinfer import (
flashinfer_fp8_blockscale_gemm,
is_flashinfer_fp8_blockscale_gemm_supported,
should_use_flashinfer_for_blockscale_fp8_gemm,
)
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool: ...@@ -56,153 +44,6 @@ def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
# We need to pass in the is_hopper flag as argument because the function
# current_platform.is_device_capability() is not supported by Torch compiler.
def cutlass_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return ops.cutlass_scaled_mm(
A,
B.T,
out_dtype=output_dtype,
scale_a=As,
scale_b=Bs.T,
)
# TODO we should be able to change the type of block_size to GroupShape
# after we resolve GroupShape compilation issue
# https://github.com/vllm-project/vllm/issues/25270
def _w8a8_triton_block_scaled_mm_func(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return w8a8_triton_block_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _w8a8_triton_block_scaled_mm_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"w8a8_triton_block_scaled_mm_func",
_w8a8_triton_block_scaled_mm_func,
fake_impl=_w8a8_triton_block_scaled_mm_fake,
)
def _padded_cutlass(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
pad_multiple = 4
dim = qx.shape[0]
padded = (
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
)
has_pad = padded > dim
if has_pad:
padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx)
padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_x_scale = torch.ones(
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
)
return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)
def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)
direct_register_custom_op(
"padded_cutlass",
_padded_cutlass,
fake_impl=_padded_cutlass_fake,
)
def _fp8_gemm_nt_op(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
fp8_gemm_nt(
(q_input, input_scale),
(weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
)
def _fp8_gemm_nt_op_fake(
q_input: torch.Tensor,
input_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
use_deep_gemm_e8m0: bool,
) -> None:
return None
direct_register_custom_op(
"fp8_gemm_nt_op",
_fp8_gemm_nt_op,
mutates_args=["output"],
fake_impl=_fp8_gemm_nt_op_fake,
)
def _triton_per_token_group_quant_fp8_impl( def _triton_per_token_group_quant_fp8_impl(
x: torch.Tensor, x: torch.Tensor,
group_size: int, group_size: int,
...@@ -236,362 +77,6 @@ direct_register_custom_op( ...@@ -236,362 +77,6 @@ direct_register_custom_op(
) )
def _flashinfer_fp8_blockscale_gemm_impl(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
"""
Conditional FlashInfer FP8 blockscale GEMM with batch-size-dependent selection.
This function switches between two optimized kernels based on the input batch size:
- For small batches (M < 32): Uses FlashInfer's DeepGEMM swapAB optimization.
- For larger batches (M >= 32): Uses the official DeepGEMM kernel.
The conditional logic must use torch.cond() instead of a simple if-else statement
to maintain compatibility with torch.compile graph compilation.
This batch-size-dependent selection is essential for maintaining model accuracy.
Benchmarks on GSM8K show a significant accuracy gap (88% vs 95%) for DeepSeek-V3.1
when using FlashInfer's DeepGEMM on M>=32. The M < 32 strategy fixes the accuracy
drop.
Args:
input: Input tensor of shape (batch_size, input_dim) in FP8 format
weight: Weight tensor of shape (output_dim, input_dim) in FP8 format
weight_scale: Scale factors for weight quantization (per-group)
group_size: Quantization group size for the weight tensor
use_deep_gemm_e8m0: Whether to use the E8M0 format in DeepGEMM quantization
Returns:
Output tensor of shape (batch_size, output_dim) in bfloat16 format
"""
def run_flashinfer_deepgemm_swapAB(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
return flashinfer_fp8_blockscale_gemm(
input=input,
weight=weight,
weight_scale=weight_scale,
out_dtype=torch.bfloat16,
)
def run_deepgemm(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
q_input, input_scale = per_token_group_quant_fp8(
input,
group_size=group_size,
column_major_scales=True,
use_ue8m0=use_deep_gemm_e8m0,
)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
fp8_gemm_nt(
(q_input, input_scale),
(weight, weight_scale),
output,
is_deep_gemm_e8m0_used=use_deep_gemm_e8m0,
)
return output
if envs.VLLM_BATCH_INVARIANT:
return run_deepgemm(input, weight, weight_scale)
condition = input.shape[0] < 32
# PyTorch's torch.compile cannot handle input-dependent control flow in standard
# Python conditionals. torch.cond() explicitly registers both code paths in the
# computation graph, allowing torch.compile to capture both branches.
# without torch.cond, the M < 32 condition won't be able to be captured by torch
# compile
return torch.cond(
condition,
run_flashinfer_deepgemm_swapAB,
run_deepgemm,
(input, weight, weight_scale),
)
def _flashinfer_fp8_blockscale_gemm_fake(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
group_size: int,
use_deep_gemm_e8m0: bool,
) -> torch.Tensor:
"""
Required fake/meta implementation for torch.compile graph tracing.
"""
return torch.empty(
input.shape[0], weight.shape[0], dtype=torch.bfloat16, device=input.device
)
direct_register_custom_op(
"flashinfer_fp8_blockscale_gemm",
_flashinfer_fp8_blockscale_gemm_impl,
fake_impl=_flashinfer_fp8_blockscale_gemm_fake,
)
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
class W8A8BlockFp8LinearOp:
"""
This class executes a Blocked FP8 linear layer using cutlass if supported
and torch.scaled_mm otherwise.
"""
def __init__(
self,
weight_group_shape: GroupShape,
act_quant_group_shape: GroupShape,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
use_deep_gemm: bool | None = None,
):
self.weight_group_shape = weight_group_shape
self.act_quant_group_shape = act_quant_group_shape
if use_deep_gemm is not None:
self.is_deep_gemm_supported = use_deep_gemm
else:
self.is_deep_gemm_supported = is_deep_gemm_supported()
self.is_hopper = current_platform.is_device_capability(90)
self.use_deep_gemm_e8m0 = is_deep_gemm_e8m0_used()
self.is_flashinfer_supported = is_flashinfer_fp8_blockscale_gemm_supported()
# Get the correct blockscale mul and input quant operations.
# We can't use _dispatch_w8a8_blockscale_op to figure out if we want
# to use deepgemm because we don't know the shape of weights (and
# whether deepgemm supports it) at the init time.
self.w8a8_blockscale_op, self.input_quant_op = (
self._dispatch_w8a8_blockscale_op(
cutlass_block_fp8_supported, use_aiter_and_is_supported
)
)
self.deepgemm_input_quant_op = (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
tma_aligned_scales=envs.VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES,
use_ue8m0=self.use_deep_gemm_e8m0,
)
if self.is_deep_gemm_supported
else None
)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
output_dtype = input.dtype
if should_use_flashinfer_for_blockscale_fp8_gemm(
self.is_flashinfer_supported, output_dtype, input_2d, weight
) and should_use_deepgemm_for_fp8_linear(
output_dtype, weight, self.is_deep_gemm_supported
):
output = self._run_flashinfer(input_2d, weight, weight_scale)
elif should_use_deepgemm_for_fp8_linear(
output_dtype, weight, self.is_deep_gemm_supported
):
output = self._run_deepgemm(input_2d, weight, weight_scale)
else:
output = self.w8a8_blockscale_op(
input_2d, weight, weight_scale, input_scale
)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def _run_deepgemm(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
assert self.deepgemm_input_quant_op is not None
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
output = torch.empty(
(q_input.shape[0], weight.shape[0]),
dtype=torch.bfloat16,
device=q_input.device,
)
torch.ops.vllm.fp8_gemm_nt_op(
q_input, input_scale, weight, weight_scale, output, self.use_deep_gemm_e8m0
)
return output
def _run_cutlass(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert input_scale is None
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
if self.is_hopper:
return torch.ops.vllm.padded_cutlass(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype,
)
else:
return cutlass_scaled_mm(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype,
)
def _run_aiter(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.act_quant_group_shape == GroupShape(1, 128)
n, k = weight.shape
use_triton = (
not current_platform.is_fp8_fnuz()
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
)
if use_triton:
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
else:
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
if input_scale is not None:
q_input = input_2d
else:
q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton)
return gemm_a8w8_blockscale_op(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
output_dtype=input_2d.dtype,
)
def _run_triton(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: torch.Tensor | None = None,
) -> torch.Tensor:
assert input_scale is None
assert self.input_quant_op is not None
q_input, input_scale = self.input_quant_op(input_2d)
return torch.ops.vllm.w8a8_triton_block_scaled_mm_func(
q_input,
weight,
input_scale,
weight_scale,
list(self.weight_group_shape),
input_2d.dtype,
)
def _run_flashinfer(
self,
input_2d: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
"""
Run FlashInfer FP8 block-scale GEMM.
This backend uses TensorRT-LLM's FP8 block-scale GEMM kernels
and supports FP8+FP8 (W8A8 full quantization) on SM90+ (Hopper).
"""
# Now call FlashInfer with BF16 input + FP8 weight, input will be
# quantized with FlashInfer kernel (W8A8)
output = torch.ops.vllm.flashinfer_fp8_blockscale_gemm(
input=input_2d, # BF16 input
weight=weight, # FP8 weight
weight_scale=weight_scale, # Weight scales
group_size=self.act_quant_group_shape.col,
use_deep_gemm_e8m0=self.use_deep_gemm_e8m0,
)
return output
def _dispatch_w8a8_blockscale_op(
self,
use_cutlass: bool,
use_aiter_and_is_supported: bool,
) -> tuple[
Callable[
[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor | None,
],
torch.Tensor,
],
QuantFP8,
]:
if use_cutlass:
return self._run_cutlass, (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=True,
use_ue8m0=False,
)
)
if use_aiter_and_is_supported:
return self._run_aiter, QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
)
return self._run_triton, (
QuantFP8(
False,
self.act_quant_group_shape,
column_major_scales=False,
use_ue8m0=False,
)
)
def input_to_float8( def input_to_float8(
x: torch.Tensor, dtype: torch.dtype | None = None x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy( ...@@ -1612,34 +1097,6 @@ def process_fp8_weight_block_strategy(
return weight, weight_scale return weight, weight_scale
def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
assert layer.weight_block_size is not None
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
should_use_deepgemm_for_fp8_linear,
)
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
layer.orig_dtype, layer.weight
)
if should_use_deepgemm:
scale_attr = (
"weight_scale_inv" if hasattr(layer, "weight_scale_inv") else "weight_scale"
)
dg_weight, dg_weight_scale = deepgemm_post_process_fp8_weight_block(
wq=layer.weight.data,
ws=getattr(layer, scale_attr).data,
quant_block_shape=tuple(layer.weight_block_size),
use_e8m0=is_deep_gemm_e8m0_used(),
)
replace_parameter(layer, "weight", dg_weight)
replace_parameter(layer, scale_attr, dg_weight_scale)
def process_fp8_weight_tensor_strategy_moe( def process_fp8_weight_tensor_strategy_moe(
weight: torch.Tensor, weight: torch.Tensor,
weight_scales: torch.Tensor, weight_scales: torch.Tensor,
......
...@@ -171,6 +171,16 @@ kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32)) ...@@ -171,6 +171,16 @@ kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True) kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
def create_fp8_quant_key(
static: bool,
group_shape: GroupShape,
symmetric: bool = True,
scale_dtype: torch.dtype = torch.float32,
) -> QuantKey:
scale_desc = ScaleDesc(scale_dtype, static, group_shape)
return QuantKey(FP8_DTYPE, scale_desc, symmetric=symmetric)
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent # -1 means full extent
......
...@@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): ...@@ -413,7 +413,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor):
def should_use_deepgemm_for_fp8_linear( def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype, output_dtype: torch.dtype,
weight: torch.Tensor, weight_shape: tuple[int, int],
supports_deep_gemm: bool | None = None, supports_deep_gemm: bool | None = None,
): ):
if supports_deep_gemm is None: if supports_deep_gemm is None:
...@@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear( ...@@ -428,8 +428,8 @@ def should_use_deepgemm_for_fp8_linear(
return ( return (
supports_deep_gemm supports_deep_gemm
and output_dtype == torch.bfloat16 and output_dtype == torch.bfloat16
and weight.shape[0] % N_MULTIPLE == 0 and weight_shape[0] % N_MULTIPLE == 0
and weight.shape[1] % K_MULTIPLE == 0 and weight_shape[1] % K_MULTIPLE == 0
) )
......
...@@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool: ...@@ -748,8 +748,9 @@ def is_flashinfer_fp8_blockscale_gemm_supported() -> bool:
def should_use_flashinfer_for_blockscale_fp8_gemm( def should_use_flashinfer_for_blockscale_fp8_gemm(
is_flashinfer_supported: bool, is_flashinfer_supported: bool,
output_dtype: torch.dtype, output_dtype: torch.dtype,
input: torch.Tensor, input_dtype: torch.dtype,
weight: torch.Tensor, weight_dtype: torch.dtype,
weight_shape: tuple[int, int],
): ):
if not is_flashinfer_supported: if not is_flashinfer_supported:
return False return False
...@@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( ...@@ -760,15 +761,12 @@ def should_use_flashinfer_for_blockscale_fp8_gemm(
N_MULTIPLE = 64 N_MULTIPLE = 64
K_MULTIPLE = 128 K_MULTIPLE = 128
weight_dtype = weight.dtype
input_dtype = input.dtype
should_use_flashinfer = ( should_use_flashinfer = (
output_dtype == torch.bfloat16 output_dtype == torch.bfloat16
and input_dtype == torch.bfloat16 and input_dtype == torch.bfloat16
and weight_dtype == torch.float8_e4m3fn and weight_dtype == torch.float8_e4m3fn
and weight.shape[0] % N_MULTIPLE == 0 and weight_shape[0] % N_MULTIPLE == 0
and weight.shape[1] % K_MULTIPLE == 0 and weight_shape[1] % K_MULTIPLE == 0
) )
return should_use_flashinfer return should_use_flashinfer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment