Unverified Commit 2e9034c9 authored by Maral's avatar Maral Committed by GitHub
Browse files

[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8...


[W8A8 Block Linear Refactor][2/N] Remove W8A8Fp8BlockLinearOp and adopt Fp8 block linear kernel selections. (#33892)
Signed-off-by: maral <maralbahari.98@gmail.com>
Signed-off-by: Maral <maralbahari.98@gmail.com>
parent 8332078c
......@@ -9,11 +9,12 @@ os.environ["VLLM_USE_DEEP_GEMM"] = "0"
import torch
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
......@@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
weight_group_shape = GroupShape(block_n, block_k)
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=use_cutlass,
use_aiter_and_is_supported=False,
linear_op = init_fp8_linear_kernel(
weight_quant_key=create_fp8_quant_key(
static=True, group_shape=weight_group_shape
),
activation_quant_key=create_fp8_quant_key(
static=False, group_shape=act_quant_group_shape
),
out_dtype=torch.get_default_dtype(),
module_name="build_w8a8_block_fp8_runner",
)
def run():
......
......@@ -39,7 +39,9 @@ from vllm.utils.torch_utils import set_random_seed
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
......@@ -78,7 +80,9 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
......@@ -88,6 +92,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=dtype,
)
for i in range(3)
]
......@@ -127,7 +132,9 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
......@@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, dtype=dtype)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
......
......@@ -109,6 +109,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=self.vllm_config.model_config.dtype,
)
for i in range(3)
]
......
......@@ -23,6 +23,7 @@ from vllm.config import (
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -49,6 +50,7 @@ class TestSiluMul(torch.nn.Module):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, x):
......@@ -92,6 +94,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)
def forward(self, hidden_states, residual):
......
......@@ -9,7 +9,7 @@ import vllm.config
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
from tests.utils import TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
......@@ -28,19 +28,23 @@ from vllm.config import (
VllmConfig,
)
from vllm.model_executor.kernels.linear import (
AiterFp8BlockScaledMMKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel,
_KernelT,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
......@@ -66,9 +70,12 @@ CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
# Blockwise group shapes
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
# ROCm kernels
......@@ -80,8 +87,8 @@ ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]
KERNEL_GROUPSHAPE_COMBINATIONS = (
......@@ -100,8 +107,8 @@ AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
# Blockwise
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]
......@@ -110,8 +117,9 @@ class TestModel(torch.nn.Module):
self,
hidden_size: int,
eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
force_kernel: type[_KernelT] | None,
group_shape: GroupShape,
dtype: torch.dtype,
use_aiter_fusion: bool = False,
use_aiter_quant: bool = False,
*args,
......@@ -129,54 +137,42 @@ class TestModel(torch.nn.Module):
is_blockwise = group_shape.is_per_group()
if is_blockwise:
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
block_size = group_shape.col
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=group_shape
)
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]
self.enable_quant_fp8_custom_op = (
False
if use_aiter_quant
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_size, block_size)
)
else:
is_static = group_shape == GroupShape.PER_TENSOR
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
self.activation_quant_key = create_fp8_quant_key(
is_static, group_shape=group_shape
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=group_shape
)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
)
for _ in range(3)
]
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
transpose_weights=use_aiter_fusion,
input_dtype=dtype,
)
for _ in range(3)
]
# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x):
# avoid having graph input be an arg to a pattern directly
......@@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=False,
use_aiter_quant=False,
)
......@@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
)
......
......@@ -66,6 +66,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
self.attn = Attention(
num_heads=self.num_qo_heads,
......@@ -155,6 +156,7 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")
......
......@@ -74,6 +74,7 @@ class MLAAttentionQuantPatternModel(torch.nn.Module):
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype
# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
......@@ -190,6 +191,7 @@ class TestMLAAttentionFp8StaticQuantPatternModel(MLAAttentionQuantPatternModel):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)
w = kwargs.get("w")
......
......@@ -36,9 +36,9 @@ from vllm.model_executor.kernels.linear import (
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
kFp8Dynamic128Sym,
kFp8StaticTensorSym,
kNvfp4Dynamic,
......@@ -58,7 +58,11 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
self,
hidden_size: int,
force_kernel: FP8ScaledMMLinearKernel,
dtype: torch.dtype,
**kwargs,
):
super().__init__()
self.silu_and_mul = SiluAndMul()
......@@ -68,6 +72,7 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
input_dtype=dtype,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
......@@ -137,14 +142,20 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, **kwargs):
act_quant_key = kFp8Dynamic128Sym
def __init__(self, hidden_size: int, dtype: torch.dtype, **kwargs):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=GroupShape(1, 128),
cutlass_block_fp8_supported=False,
use_aiter_and_is_supported=True,
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(hidden_size, hidden_size)
)
self.w8a8_block_fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
weight_quant_key=self.weight_quant_key,
activation_quant_key=self.act_quant_key,
input_dtype=dtype,
)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
......@@ -157,7 +168,7 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
x2 = self.w8a8_block_fp8_linear(y, self.w, self.wscale)
return x2
def ops_in_model_before(self):
......@@ -324,7 +335,9 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
model = model_class(
hidden_size=hidden_size, force_kernel=force_kernel, x=x, dtype=dtype
)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
......
......@@ -246,8 +246,9 @@ def default_vllm_config():
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
config = VllmConfig()
with set_current_vllm_config(config):
yield config
@pytest.fixture()
......
......@@ -12,8 +12,8 @@ from tests.kernels.quant_utils import (
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import cutlass_scaled_mm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
)
......@@ -202,7 +202,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes are supported by deepgemm
if not should_use_deepgemm_for_fp8_linear(
output_dtype=out_dtype, weight=B_fp32, supports_deep_gemm=True
output_dtype=out_dtype, weight_shape=B_fp32.shape, supports_deep_gemm=True
):
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
......
......@@ -16,6 +16,9 @@ from compressed_tensors.quantization import (
)
from tests.models.utils import check_logprobs_close
from vllm.model_executor.kernels.linear import (
Fp8BlockScaledMMLinearKernel,
)
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig,
......@@ -29,7 +32,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported,
)
......@@ -473,16 +475,14 @@ def test_compressed_tensors_fp8_block_enabled(vllm_runner):
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(
qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
)
assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)
assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2
input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method in (
input_quant_op.forward_cuda,
......
......@@ -13,6 +13,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.config.model import ModelConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
......@@ -406,6 +407,8 @@ def test_fp8_reloading(
"If this is your use case, consider using a restore function like #26327"
)
# Set model config as model_config.dtype is required in Fp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with torch.device("cuda:0"):
config = Fp8Config(
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
......
......@@ -12,6 +12,7 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm.config.model import ModelConfig
@pytest.fixture(scope="function", autouse=True)
......@@ -46,7 +47,7 @@ def _snapshot_download_or_skip(model_id: str) -> str:
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8 checkpoint loading and structure validation."""
# TODO: provide a small publicly available test checkpoint
model_path = (
......@@ -61,6 +62,8 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
"This test requires a local ModelOpt FP8 checkpoint."
)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
......@@ -120,11 +123,13 @@ def test_modelopt_fp8_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pc_pt_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pc-pt"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
......@@ -181,11 +186,13 @@ def test_modelopt_fp8_pc_pt_checkpoint_setup(vllm_runner):
not is_quant_method_supported("modelopt"),
reason="ModelOpt FP8 is not supported on this GPU type.",
)
def test_modelopt_fp8_pb_wo_checkpoint_setup(vllm_runner):
def test_modelopt_fp8_pb_wo_checkpoint_setup(default_vllm_config, vllm_runner):
"""Test ModelOpt FP8_PB_WO checkpoint setup."""
model_id = "CedricHwang/qwen2.5-0.5b-modelopt-fp8-pb-wo"
model_path = _snapshot_download_or_skip(model_id)
# Set model config as model_config.dtype is required in ModelOptFp8LinearMethod.
default_vllm_config.model_config = ModelConfig()
with vllm_runner(model_path, quantization="modelopt", enforce_eager=True) as llm:
def check_model(model):
......
......@@ -43,12 +43,10 @@ from vllm.distributed import (
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.kernels.linear import (
FP8ScaledMMLinearKernel,
_KernelT,
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.model_loader import get_model_loader
......@@ -1811,31 +1809,52 @@ class TestFP8Layer(torch.nn.Module):
weight_shape: tuple[int, int],
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
input_dtype: torch.dtype,
out_dtype: torch.dtype | None = None,
transpose_weights: bool = False,
device: torch.device | None = None,
force_kernel: FP8ScaledMMLinearKernel | None = None,
force_kernel: type[_KernelT] | None = None,
):
super().__init__()
per_tensor_weights = weight_quant_key.scale.group_shape.is_per_tensor()
is_static_activation_scale = activation_quant_key.scale.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
self.input_scale_ub = None
act_scale_desc = activation_quant_key.scale
weight_scale_desc = weight_quant_key.scale
is_block_wise = act_scale_desc.group_shape.is_per_group()
if is_block_wise:
block_size = weight_scale_desc.group_shape.col
weight_scale_shape = weight_shape[0] // block_size
self.weight_scale_inv = torch.rand(
(weight_scale_shape, weight_scale_shape), dtype=torch.float32
)
self.weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
self.input_scale = None
self.weight_scale = None
if transpose_weights:
self.weight = self.weight.t()
else:
per_tensor_weights = weight_scale_desc.group_shape.is_per_tensor()
is_static_activation_scale = act_scale_desc.static
weight_scale_shape = (1,) if per_tensor_weights else (weight_shape[0], 1)
self.weight_scale_inv = None
self.weight_scale = torch.rand(
weight_scale_shape, dtype=torch.float32, device=device
)
self.input_scale = (
torch.rand(1, dtype=torch.float32, device=device)
if is_static_activation_scale
else None
)
self.weight = (
torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE).t()
)
self.input_scale_ub = None
out_dtype = torch.get_default_dtype() if out_dtype is None else out_dtype
self.kernel = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
force_kernel=force_kernel,
)
......@@ -1847,61 +1866,3 @@ class TestFP8Layer(torch.nn.Module):
self, y: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
return self.kernel.apply_weights(self, y, bias)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
    """Blockwise FP8 linear test helper backed by ``W8A8BlockFp8LinearOp``.

    Random weights and weight scales are generated at construction time and
    fed to the wrapped linear op on every call. This helper exists only until
    ``W8A8BlockFp8LinearOp`` is replaced by the kernel abstraction
    (ScaledMMLinearKernel) for blockwise quantization.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        group_shape: Blockwise quantization group shape for activations; its
            second entry is also used as the (square) weight block size.
        cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
        use_aiter_and_is_supported: Whether to use aiter quantization ops.
        transpose_weights: Whether to transpose weights after creation.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        group_shape: GroupShape,
        cutlass_block_fp8_supported: bool = False,
        use_aiter_and_is_supported: bool = False,
        transpose_weights: bool = False,
    ):
        block = group_shape[1]
        # The scale grid is square and derived from the first weight dimension
        # only.
        blocks_per_dim = weight_shape[0] // block

        # Scales are drawn before the weight, keeping the RNG consumption order.
        self.weight_scale = torch.rand(
            (blocks_per_dim, blocks_per_dim), dtype=torch.float32
        )
        fp8_weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
        self.weight = fp8_weight.t() if transpose_weights else fp8_weight
        self.input_scale = None

        self.linear_op = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block, block),
            act_quant_group_shape=group_shape,
            cutlass_block_fp8_supported=cutlass_block_fp8_supported,
            use_aiter_and_is_supported=use_aiter_and_is_supported,
        )

    def __call__(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the wrapped block FP8 linear op to ``y`` with the stored tensors."""
        op_kwargs = dict(
            input=y,
            weight=self.weight,
            weight_scale=self.weight_scale,
            input_scale=self.input_scale,
            bias=bias,
        )
        return self.linear_op.apply(**op_kwargs)

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the linear op's input quantization op reports enabled."""
        quant_op = self.linear_op.input_quant_op
        return quant_op.enabled()
......@@ -1002,11 +1002,11 @@ class VllmBackend:
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
hash_content.append(filepath)
try:
with open(filepath) as f:
hash_content.append(f.read())
......
......@@ -19,6 +19,10 @@ import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.base import (
MMLinearKernel,
MMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision import (
MPLinearKernel,
MPLinearLayerConfig,
......@@ -52,24 +56,30 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel,
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.deep_gemm import (
DeepGemmFp8BlockScaledMMKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.marlin import (
......@@ -84,6 +94,7 @@ from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonFp8BlockScaledMMKernel,
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
......@@ -128,6 +139,23 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
],
}
# Candidate FP8 block-scaled matmul kernels for each platform, listed in
# priority/performance order (when available).
# init_fp8_linear_kernel hands this table to choose_scaled_mm_linear_kernel
# when the activation quant key uses a per-group scale. Only CUDA and ROCm
# have entries here.
_POSSIBLE_FP8_BLOCK_KERNELS: dict[
PlatformEnum, list[type[Fp8BlockScaledMMLinearKernel]]
] = {
PlatformEnum.CUDA: [
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
DeepGemmFp8BlockScaledMMKernel,
CutlassFp8BlockScaledMMKernel,
TritonFp8BlockScaledMMKernel,
],
PlatformEnum.ROCM: [
AiterFp8BlockScaledMMKernel,
TritonFp8BlockScaledMMKernel,
],
}
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
......@@ -152,8 +180,10 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
# TODO make all kernels inherit from MMLinearKernel
# then bound _KernelT only to MMLinearKernel
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel | MMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=MMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
......@@ -243,32 +273,61 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel(
    activation_quant_key: QuantKey,
    weight_quant_key: QuantKey,
    weight_shape: tuple[int, int],
    input_dtype: torch.dtype,
    out_dtype: torch.dtype,
    force_kernel: type[_KernelT] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
    """Select and instantiate the FP8 linear kernel for a layer.

    The activation quant key decides which candidate table is searched: a
    per-group activation scale selects from ``_POSSIBLE_FP8_BLOCK_KERNELS``,
    anything else from ``_POSSIBLE_FP8_KERNELS``.

    Args:
        activation_quant_key: Quantization key describing the activations.
        weight_quant_key: Quantization key describing the weights.
        weight_shape: Shape of the layer weight, stored on the kernel config.
        input_dtype: Input dtype, stored on the kernel config.
        out_dtype: Output dtype, stored on the kernel config.
        force_kernel: Optional kernel class forwarded to
            ``choose_scaled_mm_linear_kernel``.
        module_name: When given, the selected kernel is logged once under
            this name.

    Returns:
        An instance of the selected kernel class, built from the layer config.
    """
    layer_config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        weight_shape=weight_shape,
        input_dtype=input_dtype,
        out_dtype=out_dtype,
    )

    if activation_quant_key.scale.group_shape.is_per_group():
        # Block-scaled kernels are constructed from the config alone.
        candidates = _POSSIBLE_FP8_BLOCK_KERNELS
        ctor_kwargs = {}
    else:
        # Non-block kernels are told which layer attributes hold their tensors.
        candidates = _POSSIBLE_FP8_KERNELS
        ctor_kwargs = {
            "layer_param_names": [
                "weight",
                "weight_scale",
                "input_scale",
                "input_scale_ub",
            ],
        }

    kernel_type = choose_scaled_mm_linear_kernel(
        config=layer_config,
        possible_kernels=candidates,  # type: ignore[misc]
        force_kernel=force_kernel,
    )

    if module_name:
        logger.info_once(
            "Selected %s for %s",
            kernel_type.__name__,
            module_name,
            scope="global",
        )

    return kernel_type(layer_config, **ctor_kwargs)
def init_int8_linear_kernel(
......@@ -433,4 +492,7 @@ __all__ = [
"MarlinLinearKernel",
"XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel",
"_KernelT",
"DeepGemmFp8BlockScaledMMKernel",
"FlashInferFp8DeepGEMMDynamicBlockScaledKernel",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, TypeVar
import torch
from typing_extensions import Self
@dataclass
class MMLinearLayerConfig:
    """Field-less base dataclass for linear-kernel layer configurations."""
@dataclass
class Params:
"""Base class for quantized layer parameters.
This class provides a typed interface for accessing quantized weights and scales
from layer modules. It serves as a parameter container that can be extracted from
layers and passed to kernel implementations.
Attributes:
weight: The quantized weight tensor
weight_scale: weight scaling factors
input_scale: Optional input scaling factors
Class Variables:
WEIGHT: Attribute name for weight tensor on the layer module
WEIGHT_SCALE: Attribute name for weight scale tensor on the layer module
INPUT_SCALE: Attribute name for input scale tensor on the layer module
Important:
The string values of WEIGHT, WEIGHT_SCALE, and INPUT_SCALE class variables
MUST match the attribute names used in the corresponding quantization method's
create_weights() implementation.
For example, if FP8LinearMethod.create_weights()
sets layer.weight and layer.weight_scale,
then WEIGHT="weight" and
WEIGHT_SCALE="weight_scale" must be used here.
Usage:
```python
# Extract parameters from a quantized layer
params = Params.from_layer(layer)
# Access typed parameters
output = func(input, params.weight, params.weight_scale)
```
"""
weight: torch.Tensor
weight_scale: torch.Tensor
input_scale: torch.Tensor | None
# Attribute names on the layer
WEIGHT: ClassVar[str] = "weight"
WEIGHT_SCALE: ClassVar[str] = "weight_scale"
INPUT_SCALE: ClassVar[str] = "input_scale"
@classmethod
def from_layer(cls, layer: torch.nn.Module) -> Self:
return cls(
weight=getattr(layer, cls.WEIGHT),
weight_scale=getattr(layer, cls.WEIGHT_SCALE),
input_scale=getattr(layer, cls.INPUT_SCALE, None),
)
@dataclass
class FP8Params(Params):
    """FP8 layer parameters: the base fields plus an optional input-scale bound."""

    input_scale_ub: torch.Tensor | None

    # Attribute name on the layer
    INPUT_SCALE_UB: ClassVar[str] = "input_scale_ub"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> "FP8Params":
        """Extract the FP8 parameters from ``layer``.

        The weight and weight scale must be present on the layer; the input
        scale and its upper bound default to ``None`` when absent.
        """
        required = {"weight": cls.WEIGHT, "weight_scale": cls.WEIGHT_SCALE}
        optional = {
            "input_scale": cls.INPUT_SCALE,
            "input_scale_ub": cls.INPUT_SCALE_UB,
        }
        # Required attributes are read first, in field order.
        fields = {name: getattr(layer, attr) for name, attr in required.items()}
        fields.update(
            (name, getattr(layer, attr, None)) for name, attr in optional.items()
        )
        return cls(**fields)
@dataclass
class Int8Params(Params):
    """Int8 layer parameters with typed fields.

    Extends ``Params`` with the optional ``input_zero_point`` and ``azp_adj``
    tensors.

    Class Variables:
        INPUT_ZERO_POINT: Name of the layer attribute holding the input
            zero point.
        AZP_ADJ: Name of the layer attribute holding ``azp_adj``.
    """

    input_zero_point: torch.Tensor | None
    azp_adj: torch.Tensor | None

    INPUT_ZERO_POINT: ClassVar[str] = "input_zero_point"
    AZP_ADJ: ClassVar[str] = "azp_adj"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> Self:
        """Extract the Int8 parameters from ``layer``.

        The return type is ``Self`` (matching ``Params.from_layer``) because
        the instance is built through ``cls``; a fixed class name would
        mistype the result when called on a subclass.

        Args:
            layer: Module carrying the quantized tensors as attributes.

        Returns:
            An instance of the class the method is called on. ``input_scale``,
            ``input_zero_point`` and ``azp_adj`` are None when the layer lacks
            them.

        Raises:
            AttributeError: If the weight or weight-scale attribute is missing.
        """
        return cls(
            weight=getattr(layer, cls.WEIGHT),
            weight_scale=getattr(layer, cls.WEIGHT_SCALE),
            input_scale=getattr(layer, cls.INPUT_SCALE, None),
            input_zero_point=getattr(layer, cls.INPUT_ZERO_POINT, None),
            azp_adj=getattr(layer, cls.AZP_ADJ, None),
        )
# Type variables that parameterize MMLinearKernel: the parameter-container
# type (bounded by Params) and the configuration type (bounded by
# MMLinearLayerConfig).
_ParamsT = TypeVar("_ParamsT", bound=Params)
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
class MMLinearKernel(ABC, Generic[_ConfigT, _ParamsT]):
    """Interface implemented by every quantized matrix-multiplication kernel.

    A concrete kernel pairs one quantization strategy (e.g. FP8 or INT8) with
    the compute routine that executes it for a linear layer in vLLM.

    Generic Type Parameters:
        _ConfigT: The kernel's configuration type, a subclass of
            MMLinearLayerConfig carrying kernel-specific settings such as
            quantization keys and dtypes.
        _ParamsT: The kernel's parameter type, a subclass of Params
            describing the quantized weights and scales the kernel needs.

    Writing a new kernel:
        1. Declare a config dataclass deriving from MMLinearLayerConfig.
        2. Declare a params dataclass deriving from Params (or from
           FP8Params / Int8Params).
        3. Subclass MMLinearKernel, binding those config and params types.
        4. Implement every abstract method.
        5. Register the kernel with the quantization method.

    Lifecycle:
        1. Kernel selection: is_supported() and can_implement() check
           compatibility.
        2. Initialization: __init__() creates the instance from a config.
        3. Weight loading: process_weights_after_loading() prepares weights.
        4. Inference: apply_weights() runs the quantized matmul.

    Example:
        ```python
        @dataclass
        class MyKernelConfig(MMLinearLayerConfig):
            static: bool
            output_dtype: torch.dtype

        @dataclass
        class MyKernelParams(FP8Params):
            custom_scale: torch.Tensor
            CUSTOM_SCALE: ClassVar[str] = "custom_scale"

        class MyKernel(MMLinearKernel[MyKernelConfig, MyKernelParams]):
            @classmethod
            def is_supported(cls, compute_capability=None):
                if compute_capability and compute_capability < 90:
                    return False, "Requires compute capability >= 9.0"
                return True, None

            @classmethod
            def can_implement(cls, config):
                if not config.static:
                    return False, "Only static quantization supported"
                return True, None

            def process_weights_after_loading(self, layer):
                # Preprocess weights for the kernel
                params = self._get_layer_params(layer)
                processed = preprocess_weights(params.weight)
                replace_parameter(layer, params.WEIGHT, processed)

            def _get_layer_params(self, layer, **kwargs):
                return MyKernelParams.from_layer(layer)

            def apply_weights(self, layer, x, bias=None, **kwargs):
                params = self._get_layer_params(layer)
                # Call your custom kernel
                output = my_custom_kernel(x, params.weight, params.weight_scale)
                if bias is not None:
                    output += bias
                return output
        ```
    """

    def __init__(self, config: _ConfigT) -> None:
        """Store the configuration this kernel instance operates with.

        Args:
            config: Kernel-specific configuration holding settings such as
                quantization keys and output dtypes.
        """
        self.config = config

    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether the kernel can run on the current hardware.

        This is the hardware-level check (GPU architecture, compute
        capability, available instructions). Kernel selection calls it to
        discard kernels that cannot run on the current device.

        Args:
            compute_capability: GPU compute capability (e.g., 80 for A100,
                90 for H100). When None, implementations should inspect the
                current device.

        Returns:
            A tuple of (is_supported, reason):
                - is_supported: True if the kernel can run on this hardware
                - reason: Why it is unsupported, or None when it is supported
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, config: _ConfigT) -> tuple[bool, str | None]:
        """Report whether the kernel can handle ``config``.

        This is the configuration-level check (quantization scheme, group
        sizes, static vs dynamic quantization). It is called after
        is_supported() to decide whether the kernel fits the specific
        quantization configuration.

        Args:
            config: The kernel configuration to check.

        Returns:
            A tuple of (can_implement, reason):
                - can_implement: True if this kernel supports the config
                - reason: Why it is unsupported, or None when it is supported
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transform the loaded weights into the layout the kernel expects.

        Called once, after the checkpoint weights are loaded and before
        inference. Typical work is reordering, padding or format conversion.
        Changes should be applied in-place through replace_parameter() so the
        layer's parameters are properly updated.

        Args:
            layer: The layer module whose weights are to be processed.

        Example:
            ```python
            def process_weights_after_loading(self, layer):
                params = self._get_layer_params(layer)
                # Reorder weights for better memory access
                weight_reordered = reorder_weights(params.weight)
                replace_parameter(layer, params.WEIGHT, weight_reordered)
            ```
        """
        raise NotImplementedError

    # Subclasses return their own (covariant) parameter type here.
    @abstractmethod
    def _get_layer_params(self, layer: torch.nn.Module, **kwargs: Any) -> _ParamsT:
        """Read the layer's quantized tensors into a typed parameter object.

        Internal helper; implementations should typically delegate to
        ParamsClass.from_layer().

        Args:
            layer: The layer module holding the parameters.
            **kwargs: Additional arguments.

        Returns:
            A typed parameter object with the weights, scales and other
            quantization parameters.

        Example:
            ```python
            def _get_layer_params(self, layer, **kwargs):
                return MyKernelParams.from_layer(layer)
            ```
        """
        raise NotImplementedError

    def get_output_padding(self) -> int | None:
        """Return the number of tokens this kernel wants padded, if any.

        Some kernels need padding for optimal performance; those kernels
        override this method to state their requirement.

        Returns:
            Number of tokens to pad, or None for no padding (the default).
        """
        return None

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Run the quantized matrix multiplication for ``layer`` on ``x``.

        This is the inference entry point. Implementations should quantize
        the input when needed, call the underlying kernel and apply the bias.

        Args:
            layer: The layer module holding the quantized weights.
            x: Input tensor of shape [..., in_features].
            bias: Optional bias tensor of shape [out_features].
            **kwargs: Additional kernel-specific arguments.

        Returns:
            Output tensor of shape [..., out_features].
        """
        raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import ClassVar
import torch
from typing_extensions import Self
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_block_strategy,
)
from vllm.model_executor.utils import replace_parameter
from ..base import (
FP8Params,
MMLinearKernel,
)
from .ScaledMMLinearKernel import FP8ScaledMMLinearLayerConfig
@dataclass
class FP8BlockParams(FP8Params):
    """FP8 parameters for block-scaled kernels.

    The weight scale may be stored on the layer under either
    ``weight_scale`` or ``weight_scale_inv``, so both fields are optional.

    Class Variables:
        WEIGHT_SCALE_INV: Name of the layer attribute holding
            ``weight_scale_inv``.
    """

    weight_scale_inv: torch.Tensor | None
    # Re-declared as optional: the layer may carry only weight_scale_inv.
    weight_scale: torch.Tensor | None

    WEIGHT_SCALE_INV: ClassVar[str] = "weight_scale_inv"

    @classmethod
    def from_layer(cls, layer: torch.nn.Module) -> Self:
        """Build an instance from ``layer``.

        Only the weight is mandatory (a missing weight raises
        AttributeError); every other field falls back to None when the layer
        lacks the attribute.
        """
        weight = getattr(layer, cls.WEIGHT)
        optional_fields = {
            "weight_scale_inv": getattr(layer, cls.WEIGHT_SCALE_INV, None),
            "weight_scale": getattr(layer, cls.WEIGHT_SCALE, None),
            "input_scale": getattr(layer, cls.INPUT_SCALE, None),
            "input_scale_ub": getattr(layer, cls.INPUT_SCALE_UB, None),
        }
        return cls(weight=weight, **optional_fields)
class Fp8BlockScaledMMLinearKernel(
    MMLinearKernel[FP8ScaledMMLinearLayerConfig, FP8BlockParams], ABC
):
    """Common base for FP8 block-scaled matmul kernels.

    This class handles weight post-processing, optional input quantization,
    flattening the input to 2D, bias addition and the final cast/reshape.
    Subclasses provide the matmul itself in ``apply_block_scaled_mm``.
    """

    # Subclasses that take BF16 input directly (e.g. FlashInfer) set this to
    # False so that apply_weights skips the input quantization step.
    apply_input_quant: ClassVar[bool] = True

    def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
        """Create the kernel and its activation quantizer from ``config``."""
        super().__init__(config)
        self.weight_group_shape = config.weight_quant_key.scale.group_shape
        activation_scale_desc = config.activation_quant_key.scale
        self.quant_fp8 = QuantFP8(
            static=activation_scale_desc.static,
            group_shape=activation_scale_desc.group_shape,
            num_token_padding=self.get_output_padding(),
            use_ue8m0=False,
        )
        self.use_triton = False

    @classmethod
    def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
        """Accept only configs whose activation quantization is dynamic.

        Returns:
            ``(True, None)`` when the activation scale is not static,
            otherwise ``(False, reason)``.
        """
        if not config.activation_quant_key.scale.static:
            return True, None
        return (
            False,
            "Only dynamic per token group activation quantization is supported.",
        )

    def _get_layer_params(self, layer: torch.nn.Module, **kwargs) -> FP8BlockParams:
        """Read the layer's tensors into an ``FP8BlockParams`` object."""
        return FP8BlockParams.from_layer(layer)

    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Run the block-strategy weight processing and store the results.

        The processed weight and scale replace the layer parameters they were
        read from.
        """
        params = self._get_layer_params(layer)
        # Fp8LinearMethod registers the weight scale buffer as
        # weight_scale_inv, unlike compressed tensors; prefer it when present.
        if params.weight_scale_inv is not None:
            scale_name = params.WEIGHT_SCALE_INV
            scale = params.weight_scale_inv
        else:
            scale_name = params.WEIGHT_SCALE
            scale = params.weight_scale
        processed_weight, processed_scale = process_fp8_weight_block_strategy(
            params.weight,
            scale,
        )
        replace_parameter(layer, params.WEIGHT, processed_weight.data)
        replace_parameter(layer, scale_name, processed_scale.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Compute the block-scaled linear output for ``x``.

        Args:
            layer: Module holding the weight and scale tensors.
            x: Input tensor; it is flattened to 2D over its last dimension.
            bias: Optional bias added to the matmul result.
            **kwargs: Unused here.

        Returns:
            The result cast to ``config.out_dtype``, with the leading
            dimensions of ``x`` and ``weight.shape[0]`` as the last dimension.
        """
        params = self._get_layer_params(layer)
        weight = params.weight
        # Prefer weight_scale_inv when the layer provides it.
        if params.weight_scale_inv is not None:
            weight_scale = params.weight_scale_inv
        else:
            weight_scale = params.weight_scale
        activation_scale = params.input_scale

        # The fp8 routines work on a 2D matrix.
        flat_input = x.view(-1, x.shape[-1])
        result_shape = [*x.shape[:-1], weight.shape[0]]

        if not self.apply_input_quant:
            mm_input = flat_input
            # Pass a concrete placeholder so apply_block_scaled_mm always
            # receives Tensors. Subclasses with apply_input_quant=False must
            # not use As in apply_block_scaled_mm.
            if activation_scale is None:
                activation_scale = flat_input.new_ones(1)
        else:
            mm_input, activation_scale = self.quant_fp8(
                flat_input,
                activation_scale,
                params.input_scale_ub,
                use_triton=self.use_triton,
            )

        result = self.apply_block_scaled_mm(
            A=mm_input,
            B=weight,
            As=activation_scale,
            Bs=weight_scale,
        )
        if bias is not None:
            result = result + bias
        return result.to(dtype=self.config.out_dtype).view(*result_shape)

    @abstractmethod
    def apply_block_scaled_mm(
        self,
        A: torch.Tensor,
        B: torch.Tensor,
        As: torch.Tensor,
        Bs: torch.Tensor,
    ) -> torch.Tensor:
        """Perform the block-scaled matmul; implemented by subclasses.

        Args:
            A: The 2D input (quantized when ``apply_input_quant`` is True).
            B: The weight tensor.
            As: The input scale, or a placeholder when input quantization is
                skipped.
            Bs: The weight scale.
        """
        raise NotImplementedError
class Fp8BlockScaledDynamicMMLinearKernel(Fp8BlockScaledMMLinearKernel, ABC):
    """Dynamic FP8 block-scaled kernel that dispatches at runtime.

    It derives from Fp8BlockScaledMMLinearKernel so that apply_weights is
    inherited, and overrides apply_block_scaled_mm to dispatch between two
    sub-kernels using torch.cond.

    Subclasses must define:
        base_type: The primary kernel class.
        fallback_type: The fallback kernel class.
    """

    base_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]
    fallback_type: ClassVar[type[Fp8BlockScaledMMLinearKernel]]

    def __init__(self, config: "FP8ScaledMMLinearLayerConfig") -> None:
        """Instantiate both sub-kernels with the same ``config``."""
        super().__init__(config)
        self.base = self.base_type(config)
        self.fallback = self.fallback_type(config)

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Supported only when both the base and the fallback kernel are.

        Returns:
            ``(True, None)`` if both sub-kernels are supported; otherwise
            ``(False, reason)`` where the reason names each unsupported
            sub-kernel, joined by "; ".
        """
        base_ok, base_reason = cls.base_type.is_supported(compute_capability)
        fallback_ok, fallback_reason = cls.fallback_type.is_supported(
            compute_capability
        )
        problems: list[str] = []
        if not base_ok:
            problems.append(f"base is not supported due to {base_reason}")
        if not fallback_ok:
            problems.append(f"fallback is not supported due to {fallback_reason}")
        if problems:
            return False, "; ".join(problems)
        return True, None

    @classmethod
    def can_implement(
        cls, config: "FP8ScaledMMLinearLayerConfig"
    ) -> tuple[bool, str | None]:
        """Implementable only when both sub-kernels accept ``config``.

        Returns:
            ``(True, None)`` if both sub-kernels can implement the config;
            otherwise ``(False, reason)`` where the reason names each
            sub-kernel that cannot, joined by "; ".
        """
        base_ok, base_reason = cls.base_type.can_implement(config)
        fallback_ok, fallback_reason = cls.fallback_type.can_implement(config)
        problems: list[str] = []
        if not base_ok:
            problems.append(f"base cannot implement due to {base_reason}")
        if not fallback_ok:
            problems.append(f"fallback cannot implement due to {fallback_reason}")
        if problems:
            return False, "; ".join(problems)
        return True, None
......@@ -14,14 +14,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
@dataclass
class ScaledMMLinearLayerConfig:
pass
from ..base import MMLinearLayerConfig
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class Int8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool
is_channelwise: bool
......@@ -29,10 +26,12 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class FP8ScaledMMLinearLayerConfig(MMLinearLayerConfig):
weight_quant_key: QuantKey
activation_quant_key: QuantKey
out_dtype: torch.dtype | None
weight_shape: tuple[int, int]
input_dtype: torch.dtype
out_dtype: torch.dtype
_FP8ParamsT = tuple[
......@@ -50,7 +49,7 @@ _Int8ParamsT = tuple[
]
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
_ConfigT = TypeVar("_ConfigT", bound=MMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
......
......@@ -4,6 +4,9 @@
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
......@@ -31,7 +34,6 @@ from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
......@@ -55,4 +57,5 @@ __all__ = [
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
"Fp8BlockScaledMMLinearKernel",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment