Unverified commit 148117ea, authored by vllmellm, committed by GitHub
Browse files

[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent e9c83cdc
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
...@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import ( ...@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
Fp8LinearOp, kFp8StaticTensorSym,
GroupShape,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from ...utils import has_module_attribute, multi_gpu_test from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from ..backend import TestBackend from ..backend import TestBackend
...@@ -76,49 +75,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module): ...@@ -76,49 +75,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, token_num=16, eps=1e-6): def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.fp8_linear_layers = [
self.w = [ TestFP8Layer(
torch.rand(hidden_size, hidden_size) weight_shape=(hidden_size, hidden_size),
.to(dtype=current_platform.fp8_dtype()) activation_quant_key=self.quant_key,
.t() weight_quant_key=self.quant_key,
for _ in range(3) )
for i in range(3)
] ]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states): def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states) z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z) x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x) y = self.norm[0](x)
z2 = self.fp8_linear.apply( z2 = self.fp8_linear_layers[0](y)
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2) x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply( z3 = self.fp8_linear_layers[1](y2)
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3) x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply( z4 = self.fp8_linear_layers[2](y3)
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4) x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here
return y4 return y4
...@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): ...@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return [ return [
torch.ops.vllm.all_reduce.default, torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled() if self.fp8_linear_layers[0].is_quant_fp8_enabled()
else torch.ops.aten.reciprocal.default, else torch.ops.aten.reciprocal.default,
] ]
......
...@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import ( ...@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp kFp8StaticTensorSym,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from ...utils import multi_gpu_test from ...utils import TestFP8Layer, multi_gpu_test
from ..backend import TestBackend from ..backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module): ...@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, eps=1e-6): def __init__(self, hidden_size=16, eps=1e-6):
super().__init__() super().__init__()
self.vllm_config = get_current_vllm_config() self.vllm_config = get_current_vllm_config()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.eps = eps self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.fp8_linear_layers = [
self.w = [ TestFP8Layer(
torch.rand(hidden_size, hidden_size) weight_shape=(hidden_size, hidden_size),
.to(dtype=current_platform.fp8_dtype()) activation_quant_key=self.quant_key,
.t() weight_quant_key=self.quant_key,
for _ in range(3) )
for i in range(3)
] ]
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states): def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states) z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z) x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x) y = self.norm[0](x)
z2 = self.fp8_linear.apply( z2 = self.fp8_linear_layers[0](y)
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
x2 = tensor_model_parallel_all_reduce(z2) x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply( z3 = self.fp8_linear_layers[1](y2)
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
x3 = tensor_model_parallel_all_reduce(z3) x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply( z4 = self.fp8_linear_layers[2](y3)
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
x4 = tensor_model_parallel_all_reduce(z4) x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here
return y4 return y4
...@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module): ...@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return [ return [
torch.ops._C.fused_add_rms_norm.default, torch.ops._C.fused_add_rms_norm.default,
] ]
elif self.fp8_linear.quant_fp8.enabled(): elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
return [ return [
torch.ops._C.static_scaled_fp8_quant.default, torch.ops._C.static_scaled_fp8_quant.default,
] ]
......
...@@ -20,11 +20,13 @@ from vllm.config import ( ...@@ -20,11 +20,13 @@ from vllm.config import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp kFp8StaticTensorSym,
)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import TestFP8Layer
from .backend import TestBackend from .backend import TestBackend
TEST_FP8 = current_platform.supports_fp8() TEST_FP8 = current_platform.supports_fp8()
...@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype() ...@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
class TestSiluMul(torch.nn.Module): class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size: int = 128): def __init__(self, hidden_size: int = 128):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8: if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = TestFP8Layer(
self.fp8_linear = Fp8LinearOp( weight_shape=(hidden_size, hidden_size),
act_quant_static=True, activation_quant_key=self.quant_key,
act_quant_group_shape=GroupShape.PER_TENSOR, weight_quant_key=self.quant_key,
) )
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
if TEST_FP8: if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return self.fp8_linear(y)
return x2
else: else:
return y return y
...@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module): ...@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):
class TestFusedAddRMSNorm(torch.nn.Module): class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym
def __init__(self, hidden_size=16, intermediate_size=32): def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module): ...@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8: if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, intermediate_size),
self.scale = torch.rand(1, dtype=torch.float32) activation_quant_key=self.quant_key,
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() weight_quant_key=self.quant_key,
self.wscale = torch.rand(1, dtype=torch.float32) )
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
# Reshape input # Reshape input
...@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module): ...@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if TEST_FP8: if TEST_FP8:
# scaled_mm with static input quantization # scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply( fp8_linear_result = self.fp8_linear(norm_output)
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output return fp8_linear_result, residual_output
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import pytest import pytest
import torch import torch
import vllm.config
import vllm.plugins import vllm.plugins
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
...@@ -20,8 +21,22 @@ from vllm.config import ( ...@@ -20,8 +21,22 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
W8A8BlockFp8LinearOp, CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
...@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaleDesc, ScaleDesc,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.deep_gemm import (
is_deep_gemm_supported,
)
from ..utils import override_cutlass_fp8_supported from ..utils import TestBlockFP8Layer, TestFP8Layer
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -45,157 +59,195 @@ FP8_DTYPE = current_platform.fp8_dtype() ...@@ -45,157 +59,195 @@ FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP = torch.ops._C.rms_norm.default RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS = [
# FlashInferFP8ScaledMMLinearKernel supports per-tensor only
(FlashInferFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
(CutlassFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
]
# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS = [
# ROCmFP8ScaledMMLinearKernel supports per-tensor only
(ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
]
KERNEL_GROUPSHAPE_COMBINATIONS = (
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
if current_platform.is_cuda()
else ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)
# For Aiter tests we toggle use_aiter_quant_op
AITER_KERNEL_GROUPSHAPE_COMBINATIONS = [
# Per-tensor with ROCmFP8ScaledMMLinearKernel
(ROCmFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR, False),
# Per-token with RowWiseTorchFP8ScaledMMLinearKernel
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(RowWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
]
class TestModel(torch.nn.Module): class TestModel(torch.nn.Module):
def __init__( def __init__(
self, self,
hidden_size: int, hidden_size: int,
eps: float, eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
group_shape: GroupShape, group_shape: GroupShape,
use_aiter: bool = False, use_aiter_fusion: bool = False,
cuda_force_torch: bool = False, use_aiter_quant: bool = False,
use_aiter_quant_op: bool = True,
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.use_aiter = use_aiter self.fp8_linear_layers: list[torch.nn.Module]
self.use_aiter_quant_op = use_aiter_quant_op
self.cuda_force_torch = cuda_force_torch
self.group_shape = group_shape self.group_shape = group_shape
self.enable_quant_fp8_custom_op = None # Will be set later if applicable self.use_aiter_quant_op = use_aiter_quant
self.use_aiter_fusion = use_aiter_fusion
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.enable_rms_norm_custom_op = self.norm[0].enabled()
# Setup quantization scale descriptor # Determine if blockwise based on group_shape
static = group_shape == GroupShape.PER_TENSOR and not use_aiter is_blockwise = group_shape.is_per_group()
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
# Setup scales if is_blockwise:
if static: act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.activation_quant_key = QuantKey(
else: dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
self.scale = [None for _ in range(3)] )
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]
# Setup weights self.enable_quant_fp8_custom_op = (
self.w = [ False
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3) if use_aiter_quant
] else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
if not group_shape.is_per_group() or use_aiter:
self.w = [self.w[0].t() for _ in range(3)]
# Setup weight scales
if group_shape.is_per_group():
scale_size = (
(hidden_size + 128 - 1) // 128
if use_aiter
else hidden_size // group_shape[1]
) )
wscale_shape: tuple[int, ...] = (scale_size, scale_size)
else: else:
wscale_shape = (1,) is_static = group_shape == GroupShape.PER_TENSOR
self.wscale = [torch.rand(wscale_shape, dtype=torch.float32) for _ in range(3)] act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
# Setup FP8 linear operation self.activation_quant_key = QuantKey(
is_per_group = group_shape.is_per_group() dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
if is_per_group and use_aiter:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(128, 128),
act_quant_group_shape=group_shape,
use_aiter_and_is_supported=use_aiter_quant_op,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif is_per_group:
self.fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
act_quant_group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=False,
) )
self.enable_quant_fp8_custom_op = self.fp8_linear.input_quant_op.enabled() self.weight_quant_key = QuantKey(
elif use_aiter: dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
self.fp8_linear = Fp8LinearOp(
act_quant_static=False,
act_quant_group_shape=group_shape,
) )
self.fp8_linear.quant_fp8.use_aiter = use_aiter_quant_op self.fp8_linear_layers = [
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() TestFP8Layer(
else: weight_shape=(hidden_size, hidden_size),
with override_cutlass_fp8_supported(not cuda_force_torch): activation_quant_key=self.activation_quant_key,
self.fp8_linear = Fp8LinearOp( weight_quant_key=self.weight_quant_key,
act_quant_static=static, force_kernel=force_kernel,
act_quant_group_shape=group_shape,
) )
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() for _ in range(3)
]
self.enable_rms_norm_custom_op = self.norm[0].enabled() # Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
def forward(self, x): def forward(self, x):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
x = resid = torch.relu(x) x = resid = torch.relu(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply( x2 = self.fp8_linear_layers[0](y)
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
# make sure resid is used for replacement to work # make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply( x3 = self.fp8_linear_layers[1](y2)
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply( x4 = self.fp8_linear_layers[2](y3)
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here
return y4 return y4
def ops_in_model_before(self): def ops_in_model_before(self):
if ( if self.group_shape.is_per_group():
self.use_aiter # Blockwise path
and self.group_shape.is_per_group() if self.use_aiter_fusion and self.use_aiter_quant_op:
and current_platform.is_fp8_fnuz() return [rocm_aiter_ops.get_group_quant_op()]
): if self.use_aiter_fusion:
return [rocm_aiter_ops.get_group_quant_op()] return [torch.ops.vllm.triton_per_token_group_quant_fp8.default]
if self.use_aiter and self.group_shape.is_per_group(): else:
return [torch.ops.vllm.triton_per_token_group_quant_fp8.default] if self.use_aiter_quant_op:
if self.use_aiter and self.use_aiter_quant_op: return [rocm_aiter_ops.get_per_token_quant_op()]
return [rocm_aiter_ops.get_per_token_quant_op()]
if self.use_aiter: # Common path
return [QUANT_OPS[self.quant_key]] return (
if self.enable_quant_fp8_custom_op: [QUANT_OPS[self.activation_quant_key]]
return [QUANT_OPS[self.quant_key]] if self.enable_quant_fp8_custom_op
return [torch.ops.aten.reciprocal] else [torch.ops.aten.reciprocal]
)
def ops_in_model_after(self): def ops_in_model_after(self):
if self.use_aiter and self.group_shape.is_per_group(): if self.use_aiter_fusion:
from vllm.compilation.rocm_aiter_fusion import ( if self.group_shape.is_per_group():
AiterFusedAddRMSFp8GroupQuantPattern, # Blockwise aiter fusion
AiterRMSFp8GroupQuantPattern, from vllm.compilation.rocm_aiter_fusion import (
) AiterFusedAddRMSFp8GroupQuantPattern,
AiterRMSFp8GroupQuantPattern,
)
return [ return [
AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP, AiterFusedAddRMSFp8GroupQuantPattern.FUSED_OP,
AiterRMSFp8GroupQuantPattern.FUSED_OP, AiterRMSFp8GroupQuantPattern.FUSED_OP,
] ]
if self.use_aiter: else:
from vllm.compilation.rocm_aiter_fusion import ( # Per-token aiter fusion
AiterFusedAddRMSNormDynamicQuantPattern, from vllm.compilation.rocm_aiter_fusion import (
AiterRMSNormDynamicQuantPattern, AiterFusedAddRMSNormDynamicQuantPattern,
) AiterRMSNormDynamicQuantPattern,
)
return [ return [
AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP, AiterFusedAddRMSNormDynamicQuantPattern.FUSED_OP,
AiterRMSNormDynamicQuantPattern.FUSED_OP, AiterRMSNormDynamicQuantPattern.FUSED_OP,
] ]
# Regular fusion
return [ return [
FUSED_OPS[FusedRMSQuantKey(self.quant_key, True)], FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, True)],
FUSED_OPS[FusedRMSQuantKey(self.quant_key, False)], FUSED_OPS[FusedRMSQuantKey(self.activation_quant_key, False)],
] ]
def ops_in_model_before_partial(self): def ops_in_model_before_partial(self):
...@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module): ...@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module):
) )
GROUP_SHAPES = [
GroupShape.PER_TOKEN,
GroupShape.PER_TENSOR,
GroupShape(1, 128),
GroupShape(1, 64),
]
def _run_fusion_test( def _run_fusion_test(
model, model,
fusion_pass, fusion_pass,
...@@ -259,14 +303,9 @@ def _run_fusion_test( ...@@ -259,14 +303,9 @@ def _run_fusion_test(
@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("group_shape", GROUP_SHAPES) @pytest.mark.parametrize("kernel_groupshape", KERNEL_GROUPSHAPE_COMBINATIONS)
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False]) @pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False]) @pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm" not current_platform.is_cuda_alike(), reason="Only test on CUDA and ROCm"
) )
...@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant( ...@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant(
hidden_size, hidden_size,
num_tokens, num_tokens,
eps, eps,
group_shape, kernel_groupshape,
enable_rms_norm_custom_op, enable_rms_norm_custom_op,
enable_quant_fp8_custom_op, enable_quant_fp8_custom_op,
cuda_force_torch,
): ):
force_kernel, group_shape = kernel_groupshape
if not enable_quant_fp8_custom_op and group_shape.is_per_group(): if not enable_quant_fp8_custom_op and group_shape.is_per_group():
pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization") pytest.skip("Unsupported unwrapped quant fp8 op for blockwise quantization")
...@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant( ...@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RMSNormQuantFusionPass(vllm_config) fusion_pass = RMSNormQuantFusionPass(vllm_config)
model = TestModel( model = TestModel(
hidden_size=hidden_size, hidden_size=hidden_size,
eps=eps, eps=eps,
force_kernel=force_kernel,
group_shape=group_shape, group_shape=group_shape,
use_aiter=False, use_aiter_fusion=False,
cuda_force_torch=cuda_force_torch, use_aiter_quant=False,
) )
backend, _ = _run_fusion_test( backend, _ = _run_fusion_test(
...@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant( ...@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant(
assert n_add_nodes(backend.graph_post_pass) == 2 assert n_add_nodes(backend.graph_post_pass) == 2
GROUP_SHAPE_QUANT_OPS_MATCHS = [
(GroupShape.PER_TOKEN, True),
(GroupShape.PER_TOKEN, False),
(GroupShape(1, 128), True),
]
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256]) @pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6]) @pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"group_shape, use_aiter_quant_op", GROUP_SHAPE_QUANT_OPS_MATCHS "kernel_groupshape_quant", AITER_KERNEL_GROUPSHAPE_COMBINATIONS
) )
@pytest.mark.skipif( @pytest.mark.skipif(
(not current_platform.is_rocm() or not IS_AITER_FOUND), (not current_platform.is_rocm() or not IS_AITER_FOUND),
...@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant( ...@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant(
hidden_size: int, hidden_size: int,
num_tokens: int, num_tokens: int,
eps: float, eps: float,
group_shape: GroupShape, kernel_groupshape_quant: tuple,
use_aiter_quant_op: bool,
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
): ):
force_kernel, group_shape, use_aiter_quant_op = kernel_groupshape_quant
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=ModelConfig(dtype=dtype), model_config=ModelConfig(dtype=dtype),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
...@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant( ...@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant(
from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass from vllm.compilation.rocm_aiter_fusion import RocmAiterRMSNormFusionPass
m.setenv("VLLM_ROCM_USE_AITER", "1") m.setenv("VLLM_ROCM_USE_AITER", "1")
rocm_aiter_ops.refresh_env_variables() rocm_aiter_ops.refresh_env_variables()
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
torch.manual_seed(1) torch.manual_seed(1)
maybe_create_device_identity()
fusion_pass = RocmAiterRMSNormFusionPass(vllm_config) fusion_pass = RocmAiterRMSNormFusionPass(vllm_config)
model = TestModel( model = TestModel(
hidden_size=hidden_size, hidden_size=hidden_size,
eps=eps, eps=eps,
force_kernel=force_kernel,
group_shape=group_shape, group_shape=group_shape,
use_aiter=True, use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant_op=use_aiter_quant_op, use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
) )
_run_fusion_test( _run_fusion_test(
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata ...@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from ..utils import TestFP8Layer
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
...@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): ...@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape,
)
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.fp8_linear = TestFP8Layer(
"w", weight_shape=(hidden_size, hidden_size),
{ activation_quant_key=self.quant_key,
"weight": torch.randn(hidden_size, hidden_size) weight_quant_key=self.quant_key,
.to(dtype=FP8_DTYPE, device=self.device) device=self.device,
.t(),
"wscale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
},
) )
w = kwargs.get("w")
if w is not None:
self.fp8_linear.weight = w["weight"]
self.fp8_linear.weight_scale = w["wscale"]
self.fp8_linear.input_scale = w["scale"]
self.w = {
"weight": self.fp8_linear.weight,
"wscale": self.fp8_linear.weight_scale,
"scale": self.fp8_linear.input_scale,
}
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
return self.fp8_linear.apply( return self.fp8_linear(attn_output)
input=attn_output,
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
......
...@@ -25,19 +25,30 @@ from vllm.config import ( ...@@ -25,19 +25,30 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
PerTensorTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported from ..utils import TestFP8Layer
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
...@@ -49,25 +60,27 @@ def is_nvfp4_supported(): ...@@ -49,25 +60,27 @@ def is_nvfp4_supported():
class TestSiluMulFp8QuantModel(torch.nn.Module): class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): quant_key = kFp8StaticTensorSym
def __init__(
self, hidden_size: int, force_kernel: FP8ScaledMMLinearKernel, **kwargs
):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
force_kernel=force_kernel,
)
with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) x2 = self.fp8_linear(y)
return x2 return x2
def ops_in_model_before(self): def ops_in_model_before(self):
...@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module): ...@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant] return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
# FP8 scaled-mm kernel classes that the fusion test forces, grouped by platform.
# Each entry of the selected list is used as a `force_kernel` parametrization
# for TestSiluMulFp8QuantModel in test_fusion_silu_and_mul_quant.
ROCM_KERNELS = [ROCmFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel]
CUDA_KERNELS = [
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
]
# Choose the list matching the platform the test session runs on.
TEST_KERNELS = ROCM_KERNELS if current_platform.is_rocm() else CUDA_KERNELS
@pytest.mark.parametrize("num_tokens", [32, 64]) @pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256]) @pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False]) @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_class, enable_quant_fp8_custom_op, cuda_force_torch", "model_class, enable_quant_fp8_custom_op, force_kernel",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])) list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS))
+ [ + [
(TestSiluMulNvfp4QuantModel, False, False), (TestSiluMulNvfp4QuantModel, False, None),
(TestSiluMulGroupFp8QuantModel, False, False), (TestSiluMulGroupFp8QuantModel, False, None),
], ],
) )
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.skipif( @pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm" envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
) )
...@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant( ...@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
], ],
enable_silu_mul_custom_op: bool, enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool, enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool, force_kernel: FP8ScaledMMLinearKernel | None,
): ):
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.") pytest.skip("NVFP4 is not supported on this GPU.")
...@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant( ...@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.set_default_dtype(dtype) torch.set_default_dtype(dtype)
maybe_create_device_identity()
x = torch.rand(num_tokens, hidden_size * 2) x = torch.rand(num_tokens, hidden_size * 2)
...@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant( ...@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
backend = TestBackend(*passes) backend = TestBackend(*passes)
model = model_class( model = model_class(hidden_size=hidden_size, force_kernel=force_kernel, x=x)
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)
# First dimension dynamic # First dimension dynamic
torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(x, 0)
......
...@@ -11,13 +11,13 @@ from abc import ABC ...@@ -11,13 +11,13 @@ from abc import ABC
import pytest import pytest
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, Int8ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearKernel,
...@@ -33,36 +33,38 @@ def test_is_supported_is_abstract(): ...@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
def test_cpu_kernel_implements_is_supported(): def test_cpu_kernel_implements_is_supported():
"""Test that CPUScaledMMLinearKernel implements is_supported() method.""" """Test that CPUInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), ( assert hasattr(CPUInt8ScaledMMLinearKernel, "is_supported"), (
"CPUScaledMMLinearKernel missing is_supported() method" "CPUInt8ScaledMMLinearKernel missing is_supported() method"
) )
# Verify it's a classmethod by checking if it can be called with the class # Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type # and by checking the method type
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction( assert inspect.ismethod(
CPUScaledMMLinearKernel.is_supported CPUInt8ScaledMMLinearKernel.is_supported
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod" ) or inspect.isfunction(CPUInt8ScaledMMLinearKernel.is_supported), (
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod # Verify it can be called as a classmethod
result, reason = CPUScaledMMLinearKernel.is_supported() result, reason = CPUInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool" assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None" assert reason is None or isinstance(reason, str), "reason should be str or None"
def test_aiter_kernel_implements_is_supported(): def test_aiter_kernel_implements_is_supported():
"""Test that AiterScaledMMLinearKernel implements is_supported() method.""" """Test that AiterInt8ScaledMMLinearKernel implements is_supported() method."""
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), ( assert hasattr(AiterInt8ScaledMMLinearKernel, "is_supported"), (
"AiterScaledMMLinearKernel missing is_supported() method" "AiterInt8ScaledMMLinearKernel missing is_supported() method"
) )
# Verify it's a classmethod by checking if it can be called with the class # Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type # and by checking the method type
assert inspect.ismethod( assert inspect.ismethod(
AiterScaledMMLinearKernel.is_supported AiterInt8ScaledMMLinearKernel.is_supported
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), ( ) or inspect.isfunction(AiterInt8ScaledMMLinearKernel.is_supported), (
"AiterScaledMMLinearKernel.is_supported() should be a classmethod" "AiterInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
) )
# Verify it can be called as a classmethod # Verify it can be called as a classmethod
# (will return False on CPU, which is expected) # (will return False on CPU, which is expected)
result, reason = AiterScaledMMLinearKernel.is_supported() result, reason = AiterInt8ScaledMMLinearKernel.is_supported()
assert isinstance(result, bool), "is_supported() should return a bool" assert isinstance(result, bool), "is_supported() should return a bool"
assert reason is None or isinstance(reason, str), "reason should be str or None" assert reason is None or isinstance(reason, str), "reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm # On CPU, it should return False with a reason about requiring ROCm
...@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported(): ...@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
def test_cpu_kernel_accepts_all_configs(): def test_cpu_kernel_accepts_all_configs():
"""Test that CPUScaledMMLinearKernel accepts all config combinations.""" """Test that CPUInt8ScaledMMLinearKernel accepts all config combinations."""
configs = [ configs = [
ScaledMMLinearLayerConfig( Int8ScaledMMLinearLayerConfig(
is_channelwise=False, is_channelwise=False,
is_static_input_scheme=True, is_static_input_scheme=True,
input_symmetric=True, input_symmetric=True,
), ),
ScaledMMLinearLayerConfig( Int8ScaledMMLinearLayerConfig(
is_channelwise=True, is_channelwise=True,
is_static_input_scheme=False, is_static_input_scheme=False,
input_symmetric=False, input_symmetric=False,
...@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs(): ...@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
] ]
for config in configs: for config in configs:
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config) can_impl, reason = CPUInt8ScaledMMLinearKernel.can_implement(config)
assert can_impl, ( assert can_impl, (
f"CPUScaledMMLinearKernel should accept config {config}: {reason}" f"CPUInt8ScaledMMLinearKernel should accept config {config}: {reason}"
) )
...@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [ ...@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
] ]
# TritonScaledMMLinearKernel only supports symmetric quantization. # TritonInt8ScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [ ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
......
...@@ -42,6 +42,17 @@ from vllm.distributed import ( ...@@ -42,6 +42,17 @@ from vllm.distributed import (
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.cli.serve import ServeSubcommand from vllm.entrypoints.cli.serve import ServeSubcommand
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
)
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
...@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes ...@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
FP8_DTYPE = current_platform.fp8_dtype()
if current_platform.is_rocm(): if current_platform.is_rocm():
from amdsmi import ( from amdsmi import (
amdsmi_get_gpu_vram_usage, amdsmi_get_gpu_vram_usage,
...@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]): ...@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
for element in itertools.product(*iterables): for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element) normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized)) yield tuple(itertools.chain(*normalized))
class TestFP8Layer(torch.nn.Module):
    """
    Minimal FP8 linear layer for tests, backed by the scaled-mm kernel
    abstraction.

    Random weight and scale tensors are stored as plain attributes on the
    module, and the module itself is handed to the kernel as the "layer"
    when the forward pass runs.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
            The stored ``weight`` is the transpose of a tensor of this shape.
        activation_quant_key: Activation quantization configuration. An
            ``input_scale`` tensor is only created when its scale is static.
        weight_quant_key: Weight quantization configuration. Per-tensor
            weights get a ``(1,)`` scale, otherwise ``(out_features, 1)``.
        out_dtype: Output dtype. Defaults to the current default dtype.
        device: Device for the created tensors. ``None`` defers to torch's
            default device.
        force_kernel: Optional kernel forwarded to ``init_fp8_linear_kernel``
            to force a specific implementation (callers in the tests pass
            kernel classes).
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        activation_quant_key: QuantKey,
        weight_quant_key: QuantKey,
        out_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        force_kernel: FP8ScaledMMLinearKernel | None = None,
    ):
        super().__init__()

        if weight_quant_key.scale.group_shape.is_per_tensor():
            scale_shape = (1,)
        else:
            scale_shape = (weight_shape[0], 1)
        has_static_input_scale = activation_quant_key.scale.static

        # The random tensors are drawn in a fixed order (weight scale, input
        # scale, weight) so that seeded runs stay reproducible.
        self.weight_scale = torch.rand(
            scale_shape, dtype=torch.float32, device=device
        )
        if has_static_input_scale:
            self.input_scale = torch.rand(1, dtype=torch.float32, device=device)
        else:
            self.input_scale = None
        fp8_weight = torch.rand(weight_shape, device=device).to(dtype=FP8_DTYPE)
        self.weight = fp8_weight.t()
        self.input_scale_ub = None

        if out_dtype is None:
            out_dtype = torch.get_default_dtype()
        self.kernel = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=weight_quant_key,
            out_dtype=out_dtype,
            force_kernel=force_kernel,
        )

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the kernel's ``quant_fp8`` op reports itself enabled."""
        quant_op = self.kernel.quant_fp8
        return quant_op.enabled()

    def forward(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the selected kernel to ``y``, passing this module as the layer."""
        output = self.kernel.apply_weights(self, y, bias)
        return output
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
    """
    Test helper for blockwise FP8 linear operations. Holds random weights
    and scales and routes calls through W8A8BlockFp8LinearOp.

    This is a workaround until W8A8BlockFp8LinearOp implements the kernel
    abstraction (ScaledMMLinearKernel) for blockwise quantization.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        group_shape: Blockwise quantization group shape. Its second entry is
            used as the block size for both weight-group dimensions.
        cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
        use_aiter_and_is_supported: Whether to use aiter quantization ops.
        transpose_weights: Whether to transpose weights after creation.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        group_shape: GroupShape,
        cutlass_block_fp8_supported: bool = False,
        use_aiter_and_is_supported: bool = False,
        transpose_weights: bool = False,
    ):
        block_size = group_shape[1]
        # NOTE(review): the scale grid is square and derived from
        # weight_shape[0] only, so this presumably expects square weights --
        # confirm before using with out_features != in_features.
        blocks_per_dim = weight_shape[0] // block_size
        self.weight_scale = torch.rand(
            (blocks_per_dim, blocks_per_dim), dtype=torch.float32
        )
        fp8_weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
        self.weight = fp8_weight.t() if transpose_weights else fp8_weight
        self.input_scale = None

        self.linear_op = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block_size, block_size),
            act_quant_group_shape=group_shape,
            cutlass_block_fp8_supported=cutlass_block_fp8_supported,
            use_aiter_and_is_supported=use_aiter_and_is_supported,
        )

    def __call__(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the blockwise FP8 linear op on ``y`` with the stored tensors."""
        op_kwargs = {
            "input": y,
            "weight": self.weight,
            "weight_scale": self.weight_scale,
            "input_scale": self.input_scale,
            "bias": bias,
        }
        return self.linear_op.apply(**op_kwargs)

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the linear op's input quantization op is enabled."""
        quant_op = self.linear_op.input_quant_op
        return quant_op.enabled()
...@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl( ...@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl(
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K] # a to be [M, K]
# b to be [N, K] # b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
......
...@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate ...@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from torch.nn import Parameter from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, W8A8BlockFp8LinearOp,
create_fp8_input_scale, create_fp8_input_scale,
...@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( ...@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy, process_fp8_weight_tensor_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
maybe_create_device_identity,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BlockQuantScaleParameter, BlockQuantScaleParameter,
...@@ -42,6 +49,18 @@ strategy_to_parameter_type = { ...@@ -42,6 +49,18 @@ strategy_to_parameter_type = {
QuantizationStrategy.TENSOR: PerTensorScaleParameter, QuantizationStrategy.TENSOR: PerTensorScaleParameter,
} }
# Readable aliases for the boolean `is_static_input_scheme` flag, which is the
# key used to index `activation_quant_key_mapping`.
STATIC_QUANT = True
DYNAMIC_QUANT = False
# Activation QuantKey chosen by whether the input scheme is static or dynamic.
activation_quant_key_mapping = {
STATIC_QUANT: kFp8StaticTensorSym,
DYNAMIC_QUANT: kFp8DynamicTokenSym,
}
# Weight QuantKey chosen by the compressed-tensors quantization strategy.
# Only CHANNEL and TENSOR are listed.
# NOTE(review): the block-quantized path appears to go through
# W8A8BlockFp8LinearOp and never index this mapping -- confirm.
weight_quant_key_mapping = {
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
}
# Module-level logger.
logger = init_logger(__name__)
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
...@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.strategy = weight_quant.strategy self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = (
GroupShape.PER_TENSOR
if is_static_input_scheme
else GroupShape.PER_TOKEN
)
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
if self.weight_block_size is not None: if self.weight_block_size is not None:
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size), weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape, act_quant_group_shape=self.act_q_group_shape,
...@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
self.fp8_linear = Fp8LinearOp( activation_quant_key = activation_quant_key_mapping[is_static_input_scheme]
act_quant_static=self.is_static_input_scheme, weight_quant_key = weight_quant_key_mapping[self.strategy]
act_quant_group_shape=self.act_q_group_shape, self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=weight_quant_key,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
) )
@classmethod @classmethod
...@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.weight_block_size = None layer.weight_block_size = None
...@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
getattr(layer, "input_scale", None), getattr(layer, "input_scale", None),
) )
weight = weight.t() weight = weight.t()
elif self.strategy == QuantizationStrategy.CHANNEL: elif self.strategy == QuantizationStrategy.CHANNEL:
weight, weight_scale, input_scale = process_fp8_weight_channel_strategy( weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
...@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else: else:
layer.input_scale = None layer.input_scale = None
if self.strategy == QuantizationStrategy.BLOCK: if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer) maybe_post_process_fp8_weight_block(layer)
...@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias=bias, bias=bias,
) )
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
...@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( ...@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, init_int8_linear_kernel,
choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BasevLLMParameter, BasevLLMParameter,
...@@ -25,8 +24,6 @@ logger = init_logger(__name__) ...@@ -25,8 +24,6 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme): class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__( def __init__(
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
): ):
...@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
): ):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( self.kernel = init_int8_linear_kernel(
is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
is_static_input_scheme=self.is_static_input_scheme, is_static_input_scheme=self.is_static_input_scheme,
input_symmetric=self.input_symmetric, input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
) )
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT # WEIGHT
weight = ModelWeightParameter( weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
...@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = BasevLLMParameter( input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
) )
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric: if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype # Note: compressed-tensors stores the zp using the same dtype
# as the weights # as the weights
...@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point = BasevLLMParameter( input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
) )
layer.register_parameter("input_zero_point", input_zero_point)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type( layer.register_parameter("input_scale", input_scale)
c=scaled_mm_linear_kernel_config, if not hasattr(layer, "azp_adj"):
w_q_param_name="weight", layer.register_parameter("azp_adj", None)
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
)
# Checkpoints are serialized in compressed-tensors format, which is # Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here. # different from the format the kernel may want. Handle repacking here.
......
...@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig): ...@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
class FBGEMMFp8LinearMethod(LinearMethodBase): class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def create_weights( def create_weights(
self, self,
...@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
del input_size, output_size del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
...@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
)
...@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
apply_fi_trtllm_fp8_per_tensor_moe, apply_fi_trtllm_fp8_per_tensor_moe,
...@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( ...@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTensorSym,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
cutlass_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size = self.quant_config.weight_block_size self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None self.block_quant = self.weight_block_size is not None
self.act_q_static = self.quant_config.activation_scheme == "static" self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
if self.block_quant: if self.block_quant:
assert not self.act_q_static assert not self.act_q_static
assert self.weight_block_size is not None assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size), weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape, act_quant_group_shape=GroupShape(1, self.weight_block_size[0]),
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported, cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported, use_aiter_and_is_supported=self.use_aiter_and_is_supported,
) )
else: else:
self.fp8_linear = Fp8LinearOp( # Use per-token quantization for better perf if dynamic and cutlass
act_quant_static=self.act_q_static, if self.act_q_static:
act_quant_group_shape=self.act_q_group_shape, activation_quant_key = kFp8StaticTensorSym
elif cutlass_fp8_supported():
activation_quant_key = kFp8DynamicTokenSym
else:
activation_quant_key = kFp8DynamicTensorSym
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=activation_quant_key,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
def create_weights( def create_weights(
...@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
...@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase):
scale = create_fp8_input_scale(output_partition_sizes, weight_loader) scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"}) set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale) layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False): if getattr(layer, "_already_called_process_weights_after_loading", False):
...@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod(FusedMoEMethodBase):
......
...@@ -2,19 +2,58 @@ ...@@ -2,19 +2,58 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Generic, TypeVar
import torch import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
@dataclass @dataclass
class ScaledMMLinearLayerConfig: class ScaledMMLinearLayerConfig:
is_channelwise: bool pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool is_static_input_scheme: bool
is_channelwise: bool
input_symmetric: bool input_symmetric: bool
class ScaledMMLinearKernel(ABC): @dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_key: QuantKey
activation_quant_key: QuantKey
out_dtype: torch.dtype | None
_FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_scale_ub,
]
_Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def is_supported( def is_supported(
...@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC): ...@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC):
@classmethod @classmethod
@abstractmethod @abstractmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
self, assert self.can_implement(c)[0]
c: ScaledMMLinearLayerConfig, assert self.is_supported()[0]
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
assert self.is_supported()
self.config = c self.config = c
self.w_q_name = w_q_param_name self.layer_param_names = layer_param_names
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name
@abstractmethod @abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC): ...@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def _get_weight_params( # return a covariant type in the subclass
self, layer: torch.nn.Module @abstractmethod
) -> tuple[ def _get_layer_params(self, layer) -> _ParamsT:
torch.Tensor, # weight raise NotImplementedError
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp class FP8ScaledMMLinearKernel(
torch.Tensor | None, # azp_adj ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
]: ):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
act_scale_descriptor = c.activation_quant_key.scale
self.quant_fp8 = QuantFP8(
static=act_scale_descriptor.static,
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_output_padding(),
)
self.fp8_dtype = current_platform.fp8_dtype()
super().__init__(c, layer_param_names)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def _get_layer_params(self, layer) -> _FP8ParamsT:
w, w_s, x_s, x_s_ub = self.layer_param_names
return (
getattr(layer, w),
getattr(layer, w_s),
getattr(layer, x_s, None),
getattr(layer, x_s_ub, None),
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
fp8_dtype = self.fp8_dtype
maybe_out_dtype = self.config.out_dtype
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], w.shape[1]]
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != fp8_dtype:
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
x_s_ub,
)
return self.apply_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
@abstractmethod
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
raise NotImplementedError
def get_output_padding(self) -> int | None:
return None
class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
def _get_layer_params(self, layer) -> _Int8ParamsT:
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
return ( return (
getattr(layer, self.w_q_name), getattr(layer, w_q),
getattr(layer, self.w_s_name), getattr(layer, w_s),
getattr(layer, self.i_s_name), getattr(layer, i_s, None),
getattr(layer, self.i_zp_name), getattr(layer, i_zp, None),
getattr(layer, self.azp_adj_name), getattr(layer, azp_adj, None),
) )
...@@ -2,76 +2,229 @@ ...@@ -2,76 +2,229 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from typing import TypeVar
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel, CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel, ScaledMMLinearKernel,
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
PlatformEnum.CUDA: [
CutlassInt8ScaledMMLinearKernel,
TritonInt8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], FlashInferFP8ScaledMMLinearKernel,
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
ROCmFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.CPU: [
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
} }
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
# TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
return False, f" {kernel.__name__} is disabled by environment variable"
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]
is_supported, failure_reason = kernel.is_supported(compute_capability)
if not is_supported:
return False, f"{kernel.__name__} {failure_reason}."
can_implement, failure_reason = kernel.can_implement(config)
if not can_implement:
return (
False,
f"{kernel.__name__} {failure_reason}.",
)
return True, ""
def choose_scaled_mm_linear_kernel( def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None config: _KernelConfigT,
) -> type[ScaledMMLinearKernel]: possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
compute_capability: int | None = None,
force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
""" """
Choose an ScaledMMLinearKernel that can implement the given config for the Choose a _KernelT that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of given compute capability. Attempts to choose the best kernel in terms of
performance. performance.
Args: Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer config (_KernelConfigT): Description of the linear layer
to be implemented. to be implemented.
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
dictionary of platforms and their list of possible kernels.
compute_capability (Optional[int], optional): The compute capability of compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the the target device, if None uses `current_platform` to get the
compute capability. Defaults to None. compute capability. Defaults to None.
force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override
the possible_kernels if it can be implemented. If None, it will only try the
possible kernels.
Raises: Raises:
ValueError: If no kernel can implement the given config. ValueError: If no kernel can implement the given config.
Returns: Returns:
type[ScaledMMLinearKernel]: Chosen kernel. _KernelT: Chosen kernel.
""" """
failure_reasons = [] failure_reason_list = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
continue
# If the current platform uses compute_capability, if force_kernel is not None:
# make sure the kernel supports the compute capability. can_implement, failure_reason = is_supported_and_can_implement_kernel(
is_supported, reason = kernel.is_supported(compute_capability) force_kernel, config, compute_capability
if not is_supported: )
failure_reasons.append(f"{kernel.__name__}: {reason}") if can_implement:
continue return force_kernel
can_implement, reason = kernel.can_implement(config) logger.info_once(
if not can_implement: "Tried to force %s, but the kernel couldn't be implemented",
failure_reasons.append(f"{kernel.__name__}: {reason}") force_kernel.__name__,
continue scope="global",
)
return kernel for kernel in possible_kernels[current_platform._enum]:
is_supported_and_can_implement, failure_reason = (
is_supported_and_can_implement_kernel(kernel, config, compute_capability)
)
if is_supported_and_can_implement:
return kernel
failure_reason_list.append(failure_reason)
raise ValueError( raise ValueError(
"Failed to find a kernel that can implement the " "Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons) "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
)
def init_fp8_linear_kernel(
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
out_dtype: torch.dtype,
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
out_dtype=out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
def init_int8_linear_kernel(
is_channelwise: bool,
is_static_input_scheme: bool,
input_symmetric: bool,
module_name: str,
) -> Int8ScaledMMLinearKernel:
config = Int8ScaledMMLinearLayerConfig(
is_channelwise=is_channelwise,
is_static_input_scheme=is_static_input_scheme,
input_symmetric=input_symmetric,
)
kernel_type = choose_scaled_mm_linear_kernel(
config,
_POSSIBLE_INT8_KERNELS,
)
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
],
) )
...@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops ...@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod @classmethod
def is_supported( def is_supported(
cls, compute_capability: int | None = None cls, compute_capability: int | None = None
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
if not current_platform.is_rocm(): if not current_platform.is_rocm():
return ( return False, "Requires ROCm."
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 90: if compute_capability is not None and compute_capability < 90:
return False, f"requires capability 90, got {compute_capability}" return False, "requires compute capability 90 and above."
try: try:
import aiter # noqa: F401 # deliberately attempt to import aiter import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception: except Exception:
return ( return False, "requires `aiter` to be installed."
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
if not rocm_aiter_ops.is_linear_enabled(): if not rocm_aiter_ops.is_linear_enabled():
return ( return (
False, False,
"AiterScaledMMLinearKernel is disabled. " "requires setting `VLLM_ROCM_USE_AITER=1` "
+ "Enable by setting `VLLM_ROCM_USE_AITER=1` "
+ "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
) )
return True, None return True, None
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric: if not c.input_symmetric:
return ( return False, "supports symmetric quantization only."
False,
"AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
`AiterScaledMMLinearKernel` implements a fused version of `AiterInt8ScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting. broadcasting.
Currently only support per-tensor-per-tensor GEMM Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support
AITER block scaled GEMM and mixed-precision GEMM. AITER block scaled GEMM and mixed-precision GEMM.
""" """
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant: # ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s. # * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None symmetric = azp_adj is None
assert symmetric, ( assert symmetric, (
"AiterScaledMMLinearKernel only supports symmetric quantization." "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
) )
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric) x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
assert x_zp is None, ( assert x_zp is None, (
"AiterScaledMMLinearKernel only supports symmetric quantization." "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
) )
out_dtype = x.dtype out_dtype = x.dtype
...@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): ...@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
), ( ), (
"Currently only support per-tensor-per-tensor GEMM " "Currently only support per-tensor-per-tensor GEMM "
+ " and per-token-per-channel GEMM through AITER" + " and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " " w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
+ "does not support AITER block scaled GEMM." + "does not support AITER block scaled GEMM."
) )
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K] # a to be [M, K]
# b to be [N, K] # b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
...@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel ...@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CPUScaledMMLinearKernel(ScaledMMLinearKernel): class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod @classmethod
def is_supported( def is_supported(
cls, compute_capability: int | None = None cls, compute_capability: int | None = None
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
if not current_platform.is_cpu(): if not current_platform.is_cpu():
return False, "Requires CPU." return False, "requires CPU."
return True, None return True, None
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name) w_q_name, _, _, _, _ = self.layer_param_names
weight = getattr(layer, w_q_name)
dtype = weight.dtype dtype = weight.dtype
N, K = weight.size() N, K = weight.size()
if ( if (
...@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None: def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT # WEIGHT
# Transpose to [K, N] for convenience # Transpose to [K, N] for convenience
weight = getattr(layer, self.w_q_name) w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
weight = getattr(layer, w_q_name)
replace_parameter( replace_parameter(
layer, layer,
self.w_q_name, w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False), torch.nn.Parameter(weight.t().data, requires_grad=False),
) )
...@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N # If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case. # scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1 is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name) weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise: if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter( replace_parameter(
layer, layer,
self.w_s_name, w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False), torch.nn.Parameter(weight_scale.data, requires_grad=False),
) )
# INPUT SCALE # INPUT SCALE
if self.config.is_static_input_scheme: if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name) input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric: if self.config.input_symmetric:
replace_parameter( replace_parameter(
layer, layer,
self.i_s_name, i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False), torch.nn.Parameter(input_scale.max(), requires_grad=False),
) )
setattr(layer, self.i_zp_name, None)
else: else:
input_zero_point = getattr(layer, self.i_zp_name) input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges # reconstruct the ranges
int8_traits = torch.iinfo(torch.int8) int8_traits = torch.iinfo(torch.int8)
...@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter( replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
) )
azp = ( azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32) (int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
) )
replace_parameter( replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
) )
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# Different from cutlass, oneDNN kernels only need the AZP adjustment # Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the # term for dynamic quantization. And s_b should be folded into the
# term. Such as: # term. Such as:
...@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias = # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias # s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric and self.config.is_static_input_scheme): if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, self.w_q_name) weight = getattr(layer, w_q_name)
weight_scale = getattr(layer, self.w_s_name) weight_scale = getattr(layer, w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze() azp_adj = azp_adj * weight_scale.squeeze()
setattr( setattr(
layer, layer,
self.azp_adj_name, azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False), torch.nn.Parameter(azp_adj, requires_grad=False),
) )
else:
setattr(layer, self.azp_adj_name, None)
weight = getattr(layer, self.w_q_name) weight = getattr(layer, w_q_name)
self.dnnl_handler = ops.create_onednn_scaled_mm( self.dnnl_handler = ops.create_onednn_scaled_mm(
weight, weight,
getattr(layer, self.w_s_name), getattr(layer, w_s_name),
torch.get_default_dtype(), torch.get_default_dtype(),
getattr(layer, self.i_s_name) is None, getattr(layer, i_s_name) is None,
not self.config.input_symmetric, not self.config.input_symmetric,
32, 32,
) )
# weight is prepacked and maintained by the dnnl_handler, # weight is prepacked and maintained by the dnnl_handler,
# release the original weight # release the original weight
setattr(layer, self.w_q_name, None) setattr(layer, w_q_name, None)
del weight del weight
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None: def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, _, _, _ = self.layer_param_names
# WEIGHT # WEIGHT
weight = getattr(layer, self.w_q_name) weight = getattr(layer, w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight) packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter( replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False) layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
) )
if layer.bias is not None: if layer.bias is not None:
...@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT SCALE # WEIGHT SCALE
# CPU SGL kernels only support per-channel. # CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case. # For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, self.w_s_name) weight_scale = getattr(layer, w_s_name)
if not self.config.is_channelwise: if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter( replace_parameter(
layer, layer,
self.w_s_name, w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False), torch.nn.Parameter(weight_scale.data, requires_grad=False),
) )
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant: # ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.
...@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer) w_q, w_s, _, _, _ = self._get_layer_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant( return torch.ops._C.int8_scaled_mm_with_quant(
x, x,
w_q, w_q,
......
...@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod @classmethod
def is_supported( def is_supported(
cls, compute_capability: int | None = None cls, compute_capability: int | None = None
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return False, "Requires CUDA." return False, "requires CUDA."
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 75:
return False, f"requires capability 75, got {compute_capability}"
return True, None return True, None
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
config = self.config
# WEIGHT # WEIGHT
# Cutlass kernels need transposed weight. # Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name) weight = getattr(layer, w_q_name)
replace_parameter( replace_parameter(
layer, layer,
self.w_q_name, w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False), torch.nn.Parameter(weight.t().data, requires_grad=False),
) )
...@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N # If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case. # scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1 is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name) weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise: if is_fused_module and not config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter( replace_parameter(
layer, layer,
self.w_s_name, w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False), torch.nn.Parameter(weight_scale.data, requires_grad=False),
) )
# INPUT SCALE # INPUT SCALE
if self.config.is_static_input_scheme: if config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name) input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric: if config.input_symmetric:
replace_parameter( replace_parameter(
layer, layer,
self.i_s_name, i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False), torch.nn.Parameter(input_scale.max(), requires_grad=False),
) )
setattr(layer, self.i_zp_name, None) setattr(layer, i_zp_name, None)
else: else:
input_zero_point = getattr(layer, self.i_zp_name) input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges # reconstruct the ranges
int8_traits = torch.iinfo(torch.int8) int8_traits = torch.iinfo(torch.int8)
...@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min) scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter( replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False) layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
) )
# AZP loaded as int8 but used as int32 # AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32) azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter( replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False) layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
) )
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights. # azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for # It does not depend on scales or azp, so it is the same for
# static and dynamic quantization. # static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not self.config.input_symmetric: if not config.input_symmetric:
weight = getattr(layer, self.w_q_name) weight = getattr(layer, w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme: if config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj # cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case # in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj azp_adj = getattr(layer, i_zp_name) * azp_adj
setattr( setattr(
layer, layer,
self.azp_adj_name, azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False), torch.nn.Parameter(azp_adj, requires_grad=False),
) )
else:
setattr(layer, self.azp_adj_name, None)
def apply_weights( def apply_weights(
self, self,
...@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant: # ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.
...@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return ops.cutlass_scaled_mm( return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
) )
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
return output.view(*output_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment