Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
148117ea
Unverified
Commit
148117ea
authored
Jan 20, 2026
by
vllmellm
Committed by
GitHub
Jan 20, 2026
Browse files
[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
e9c83cdc
Changes
30
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
971 additions
and
560 deletions
+971
-560
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
+5
-0
tests/compile/distributed/test_fusion_all_reduce.py
tests/compile/distributed/test_fusion_all_reduce.py
+17
-27
tests/compile/distributed/test_sequence_parallelism.py
tests/compile/distributed/test_sequence_parallelism.py
+17
-26
tests/compile/test_functionalization.py
tests/compile/test_functionalization.py
+19
-22
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+182
-146
tests/compile/test_fusion_attn.py
tests/compile/test_fusion_attn.py
+20
-21
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+44
-27
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
...s/kernels/quantization/test_scaled_mm_kernel_selection.py
+24
-22
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+1
-1
tests/utils.py
tests/utils.py
+127
-0
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+33
-30
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+10
-23
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+12
-16
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+21
-26
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
...rs/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+144
-33
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
...xecutor/layers/quantization/kernels/scaled_mm/__init__.py
+183
-30
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
...l_executor/layers/quantization/kernels/scaled_mm/aiter.py
+17
-36
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
...del_executor/layers/quantization/kernels/scaled_mm/cpu.py
+33
-38
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
...executor/layers/quantization/kernels/scaled_mm/cutlass.py
+61
-35
No files found.
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
0 → 100644
View file @
148117ea
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
tests/compile/distributed/test_fusion_all_reduce.py
View file @
148117ea
...
...
@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
GroupShape
,
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.torch_utils
import
set_random_seed
from
...utils
import
has_module_attribute
,
multi_gpu_test
from
...utils
import
TestFP8Layer
,
has_module_attribute
,
multi_gpu_test
from
..backend
import
TestBackend
...
...
@@ -76,25 +75,21 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
)
.
to
(
dtype
=
current_platform
.
fp8_dtype
())
.
t
()
for
_
in
range
(
3
)
]
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
self
.
fp8_linear_layers
=
[
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)
]
for
i
in
range
(
3
)
]
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
...
...
@@ -102,23 +97,18 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
z2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
z2
=
self
.
fp8_linear_layers
[
0
](
y
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
z3
=
self
.
fp8_linear_layers
[
1
](
y2
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
z4
=
self
.
fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
z4
=
self
.
fp8_linear_layers
[
2
](
y3
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
...
...
@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
if
self
.
fp8_linear
.
quant_fp8
.
enabled
()
if
self
.
fp8_linear
_layers
[
0
].
is_
quant_fp8
_
enabled
()
else
torch
.
ops
.
aten
.
reciprocal
.
default
,
]
...
...
tests/compile/distributed/test_sequence_parallelism.py
View file @
148117ea
...
...
@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.torch_utils
import
set_random_seed
from
...utils
import
multi_gpu_test
from
...utils
import
TestFP8Layer
,
multi_gpu_test
from
..backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -94,26 +95,22 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
eps
=
1e-6
):
super
().
__init__
()
self
.
vllm_config
=
get_current_vllm_config
()
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
)
.
to
(
dtype
=
current_platform
.
fp8_dtype
())
.
t
()
for
_
in
range
(
3
)
]
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
self
.
fp8_linear_layers
=
[
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)
]
for
i
in
range
(
3
)
]
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
...
...
@@ -121,23 +118,17 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
z2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
z2
=
self
.
fp8_linear_layers
[
0
](
y
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
z3
=
self
.
fp8_linear_layers
[
1
](
y2
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
z4
=
self
.
fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
z4
=
self
.
fp8_linear_layers
[
2
](
y3
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
...
...
@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
]
elif
self
.
fp8_linear
.
quant_fp8
.
enabled
():
elif
any
(
layer
.
is_
quant_fp8
_
enabled
()
for
layer
in
self
.
fp8_linear_layers
)
:
return
[
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
]
...
...
tests/compile/test_functionalization.py
View file @
148117ea
...
...
@@ -20,11 +20,13 @@ from vllm.config import (
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
kFp8StaticTensorSym
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
..utils
import
TestFP8Layer
from
.backend
import
TestBackend
TEST_FP8
=
current_platform
.
supports_fp8
()
...
...
@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
class
TestSiluMul
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
:
int
=
128
):
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
if
TEST_FP8
:
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
fp8_linear
=
Fp8LinearOp
(
act
_quant_static
=
True
,
ac
t_quant_
group_shape
=
GroupShape
.
PER_TENSOR
,
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
act
ivation_quant_key
=
self
.
quant_key
,
weigh
t_quant_
key
=
self
.
quant_key
,
)
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
if
TEST_FP8
:
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
wscale
)
return
x2
return
self
.
fp8_linear
(
y
)
else
:
return
y
...
...
@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):
class
TestFusedAddRMSNorm
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
...
...
@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
if
TEST_FP8
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
w
=
torch
.
rand
(
hidden_size
,
intermediate_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
intermediate_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
)
def
forward
(
self
,
hidden_states
,
residual
):
# Reshape input
...
...
@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if
TEST_FP8
:
# scaled_mm with static input quantization
fp8_linear_result
=
self
.
fp8_linear
.
apply
(
norm_output
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
scale
.
to
(
norm_output
.
device
),
)
fp8_linear_result
=
self
.
fp8_linear
(
norm_output
)
return
fp8_linear_result
,
residual_output
...
...
tests/compile/test_fusion.py
View file @
148117ea
...
...
@@ -5,6 +5,7 @@
import
pytest
import
torch
import
vllm.config
import
vllm.plugins
from
vllm._aiter_ops
import
IS_AITER_FOUND
,
rocm_aiter_ops
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
...
...
@@ -20,8 +21,22 @@ from vllm.config import (
VllmConfig
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
...
...
@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaleDesc
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_supported
,
)
from
..utils
import
override_cutlass_fp8_supported
from
..utils
import
TestBlockFP8Layer
,
TestFP8Layer
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -45,135 +59,170 @@ FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# FlashInferFP8ScaledMMLinearKernel supports per-tensor only
(
FlashInferFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
(
CutlassFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
(
CutlassFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
(
PerTensorTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# Blockwise group shapes (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
)),
(
None
,
GroupShape
(
1
,
64
)),
]
# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# ROCmFP8ScaledMMLinearKernel supports per-tensor only
(
ROCmFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# Blockwise group shapes (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
)),
(
None
,
GroupShape
(
1
,
64
)),
]
KERNEL_GROUPSHAPE_COMBINATIONS
=
(
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
if
current_platform
.
is_cuda
()
else
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)
# For Aiter tests we toggle use_aiter_quant_op
AITER_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# Per-tensor with ROCmFP8ScaledMMLinearKernel
(
ROCmFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
,
False
),
# Per-token with RowWiseTorchFP8ScaledMMLinearKernel
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
True
),
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
False
),
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
True
),
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
False
),
# Blockwise (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
),
True
),
]
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
,
force_kernel
:
FP8ScaledMMLinearKernel
|
None
,
group_shape
:
GroupShape
,
use_aiter
:
bool
=
False
,
cuda_force_torch
:
bool
=
False
,
use_aiter_quant_op
:
bool
=
True
,
use_aiter_fusion
:
bool
=
False
,
use_aiter_quant
:
bool
=
False
,
*
args
,
**
kwargs
,
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
use_aiter
=
use_aiter
self
.
use_aiter_quant_op
=
use_aiter_quant_op
self
.
cuda_force_torch
=
cuda_force_torch
self
.
fp8_linear_layers
:
list
[
torch
.
nn
.
Module
]
self
.
group_shape
=
group_shape
self
.
enable_quant_fp8_custom_op
=
None
# Will be set later if applicable
self
.
use_aiter_quant_op
=
use_aiter_quant
self
.
use_aiter_fusion
=
use_aiter_fusion
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
# Setup quantization scale descriptor
static
=
group_shape
==
GroupShape
.
PER_TENSOR
and
not
use_aiter
quant_scale
=
ScaleDesc
(
torch
.
float32
,
static
,
group_shape
)
self
.
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
# Setup scales
if
static
:
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
else
:
self
.
scale
=
[
None
for
_
in
range
(
3
)]
# Determine if blockwise based on group_shape
is_blockwise
=
group_shape
.
is_per_group
()
# Setup weights
self
.
w
=
[
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
)
for
_
in
range
(
3
)
]
if
not
group_shape
.
is_per_group
()
or
use_aiter
:
self
.
w
=
[
self
.
w
[
0
].
t
()
for
_
in
range
(
3
)]
# Setup weight scales
if
group_shape
.
is_per_group
():
scale_size
=
(
(
hidden_size
+
128
-
1
)
//
128
if
use_aiter
else
hidden_size
//
group_shape
[
1
]
)
wscale_shape
:
tuple
[
int
,
...]
=
(
scale_size
,
scale_size
)
else
:
wscale_shape
=
(
1
,)
self
.
wscale
=
[
torch
.
rand
(
wscale_shape
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
# Setup FP8 linear operation
is_per_group
=
group_shape
.
is_per_group
()
if
is_per_group
and
use_aiter
:
self
.
fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
128
,
128
),
act_quant_group_shape
=
group_shape
,
use_aiter_and_is_supported
=
use_aiter_quant_op
,
if
is_blockwise
:
act_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
self
.
activation_quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
act_quant_scale_desc
,
symmetric
=
True
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif
is_per_group
:
self
.
fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
group_shape
[
1
],
group_shape
[
1
]),
act_quant_group_shape
=
group_shape
,
self
.
fp8_linear_layers
=
[
TestBlockFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
group_shape
=
group_shape
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
(),
use_aiter_and_is_supported
=
False
,
use_aiter_and_is_supported
=
use_aiter_quant
,
transpose_weights
=
use_aiter_fusion
,
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
input_quant_op
.
enabled
()
elif
use_aiter
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
False
,
act_quant_group_shape
=
group_shape
,
for
_
in
range
(
3
)
]
self
.
enable_quant_fp8_custom_op
=
(
False
if
use_aiter_quant
else
self
.
fp8_linear_layers
[
0
].
linear_op
.
input_quant_op
.
enabled
()
)
self
.
fp8_linear
.
quant_fp8
.
use_aiter
=
use_aiter_quant_op
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
else
:
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
static
,
act_quant_group_shape
=
group_shape
,
is_static
=
group_shape
==
GroupShape
.
PER_TENSOR
act_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
is_static
,
group_shape
)
w_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
True
,
group_shape
)
self
.
activation_quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
act_quant_scale_desc
,
symmetric
=
True
)
self
.
weight_quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
w_quant_scale_desc
,
symmetric
=
True
)
self
.
fp8_linear_layers
=
[
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
activation_quant_key
,
weight_quant_key
=
self
.
weight_quant_key
,
force_kernel
=
force_kernel
,
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
for
_
in
range
(
3
)
]
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
# Enable aiter quantization if requested
for
layer
in
self
.
fp8_linear_layers
:
layer
.
kernel
.
quant_fp8
.
use_aiter
=
use_aiter_quant
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear_layers
[
0
].
is_quant_fp8_enabled
()
def
forward
(
self
,
x
):
# avoid having graph input be an arg to a pattern directly
x
=
resid
=
torch
.
relu
(
x
)
y
=
self
.
norm
[
0
](
x
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
x2
=
self
.
fp8_linear_layers
[
0
](
y
)
# make sure resid is used for replacement to work
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
x3
=
self
.
fp8_linear
.
apply
(
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
x3
=
self
.
fp8_linear_layers
[
1
](
y2
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
x4
=
self
.
fp8_linear
.
apply
(
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
x4
=
self
.
fp8_linear_layers
[
2
](
y3
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
def
ops_in_model_before
(
self
):
if
(
self
.
use_aiter
and
self
.
group_shape
.
is_per_group
()
and
current_platform
.
is_fp8_fnuz
()
):
if
self
.
group_shape
.
is_per_group
():
# Blockwise path
if
self
.
use_aiter_fusion
and
self
.
use_aiter_quant_op
:
return
[
rocm_aiter_ops
.
get_group_quant_op
()]
if
self
.
use_aiter
and
self
.
group_shape
.
is_per_group
()
:
if
self
.
use_aiter
_fusion
:
return
[
torch
.
ops
.
vllm
.
triton_per_token_group_quant_fp8
.
default
]
if
self
.
use_aiter
and
self
.
use_aiter_quant_op
:
else
:
if
self
.
use_aiter_quant_op
:
return
[
rocm_aiter_ops
.
get_per_token_quant_op
()]
if
self
.
use_aiter
:
return
[
QUANT_OPS
[
self
.
quant_key
]]
if
self
.
enable_quant_fp8_custom_op
:
return
[
QUANT_OPS
[
self
.
quant_key
]]
return
[
torch
.
ops
.
aten
.
reciprocal
]
# Common path
return
(
[
QUANT_OPS
[
self
.
activation_quant_key
]]
if
self
.
enable_quant_fp8_custom_op
else
[
torch
.
ops
.
aten
.
reciprocal
]
)
def
ops_in_model_after
(
self
):
if
self
.
use_aiter
and
self
.
group_shape
.
is_per_group
():
if
self
.
use_aiter_fusion
:
if
self
.
group_shape
.
is_per_group
():
# Blockwise aiter fusion
from
vllm.compilation.rocm_aiter_fusion
import
(
AiterFusedAddRMSFp8GroupQuantPattern
,
AiterRMSFp8GroupQuantPattern
,
...
...
@@ -183,7 +232,8 @@ class TestModel(torch.nn.Module):
AiterFusedAddRMSFp8GroupQuantPattern
.
FUSED_OP
,
AiterRMSFp8GroupQuantPattern
.
FUSED_OP
,
]
if
self
.
use_aiter
:
else
:
# Per-token aiter fusion
from
vllm.compilation.rocm_aiter_fusion
import
(
AiterFusedAddRMSNormDynamicQuantPattern
,
AiterRMSNormDynamicQuantPattern
,
...
...
@@ -193,9 +243,11 @@ class TestModel(torch.nn.Module):
AiterFusedAddRMSNormDynamicQuantPattern
.
FUSED_OP
,
AiterRMSNormDynamicQuantPattern
.
FUSED_OP
,
]
# Regular fusion
return
[
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_key
,
True
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_key
,
False
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
activation_
quant_key
,
True
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
activation_
quant_key
,
False
)],
]
def
ops_in_model_before_partial
(
self
):
...
...
@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module):
)
GROUP_SHAPES
=
[
GroupShape
.
PER_TOKEN
,
GroupShape
.
PER_TENSOR
,
GroupShape
(
1
,
128
),
GroupShape
(
1
,
64
),
]
def
_run_fusion_test
(
model
,
fusion_pass
,
...
...
@@ -259,14 +303,9 @@ def _run_fusion_test(
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"group
_
shape"
,
GROUP
_
SHAPES
)
@
pytest
.
mark
.
parametrize
(
"
kernel_
groupshape"
,
KERNEL_
GROUPSHAPE
_COMBINATION
S
)
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_quant_fp8_custom_op"
,
[
True
,
False
])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
parametrize
(
"cuda_force_torch"
,
[
True
,
False
]
if
cutlass_fp8_supported
()
else
[
True
]
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test on CUDA and ROCm"
)
...
...
@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant(
hidden_size
,
num_tokens
,
eps
,
group
_
shape
,
kernel_
groupshape
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
cuda_force_torch
,
):
force_kernel
,
group_shape
=
kernel_groupshape
if
not
enable_quant_fp8_custom_op
and
group_shape
.
is_per_group
():
pytest
.
skip
(
"Unsupported unwrapped quant fp8 op for blockwise quantization"
)
...
...
@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
model
=
TestModel
(
hidden_size
=
hidden_size
,
eps
=
eps
,
force_kernel
=
force_kernel
,
group_shape
=
group_shape
,
use_aiter
=
False
,
cuda_force_torch
=
cuda_force_torch
,
use_aiter
_fusion
=
False
,
use_aiter_quant
=
False
,
)
backend
,
_
=
_run_fusion_test
(
...
...
@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant(
assert
n_add_nodes
(
backend
.
graph_post_pass
)
==
2
GROUP_SHAPE_QUANT_OPS_MATCHS
=
[
(
GroupShape
.
PER_TOKEN
,
True
),
(
GroupShape
.
PER_TOKEN
,
False
),
(
GroupShape
(
1
,
128
),
True
),
]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"group
_
shape
, use_aiter_quant_op"
,
GROUP
_
SHAPE_
QUANT_OPS_MATCH
S
"
kernel_
groupshape
_quant"
,
AITER_KERNEL_
GROUPSHAPE_
COMBINATION
S
)
@
pytest
.
mark
.
skipif
(
(
not
current_platform
.
is_rocm
()
or
not
IS_AITER_FOUND
),
...
...
@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant(
hidden_size
:
int
,
num_tokens
:
int
,
eps
:
float
,
group_shape
:
GroupShape
,
use_aiter_quant_op
:
bool
,
kernel_groupshape_quant
:
tuple
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
force_kernel
,
group_shape
,
use_aiter_quant_op
=
kernel_groupshape_quant
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
dtype
),
compilation_config
=
CompilationConfig
(
...
...
@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant(
from
vllm.compilation.rocm_aiter_fusion
import
RocmAiterRMSNormFusionPass
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
rocm_aiter_ops
.
refresh_env_variables
()
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
fusion_pass
=
RocmAiterRMSNormFusionPass
(
vllm_config
)
model
=
TestModel
(
hidden_size
=
hidden_size
,
eps
=
eps
,
force_kernel
=
force_kernel
,
group_shape
=
group_shape
,
use_aiter
=
True
,
use_aiter_quant
_op
=
use_aiter_quant_op
,
use_aiter
_fusion
=
True
,
# Always use aiter fusion ops in aiter test
use_aiter_quant
=
use_aiter_quant_op
,
# Toggle aiter quantization
)
_run_fusion_test
(
...
...
tests/compile/test_fusion_attn.py
View file @
148117ea
...
...
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym
,
kNvfp4Quant
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
..utils
import
TestFP8Layer
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
...
...
@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
self
.
quant_key
.
scale
.
static
,
act_quant_group_shape
=
self
.
quant_key
.
scale
.
group_shape
,
)
hidden_size
=
self
.
num_qo_heads
*
self
.
head_size
self
.
w
=
kwargs
.
get
(
"w"
,
{
"weight"
:
torch
.
randn
(
hidden_size
,
hidden_size
)
.
to
(
dtype
=
FP8_DTYPE
,
device
=
self
.
device
)
.
t
(),
"wscale"
:
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
"scale"
:
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
},
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
device
=
self
.
device
,
)
w
=
kwargs
.
get
(
"w"
)
if
w
is
not
None
:
self
.
fp8_linear
.
weight
=
w
[
"weight"
]
self
.
fp8_linear
.
weight_scale
=
w
[
"wscale"
]
self
.
fp8_linear
.
input_scale
=
w
[
"scale"
]
self
.
w
=
{
"weight"
:
self
.
fp8_linear
.
weight
,
"wscale"
:
self
.
fp8_linear
.
weight_scale
,
"scale"
:
self
.
fp8_linear
.
input_scale
,
}
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
):
"""Forward pass that creates the pattern to be fused."""
attn_output
=
self
.
attn
(
q
,
k
,
v
)
return
self
.
fp8_linear
.
apply
(
input
=
attn_output
,
weight
=
self
.
w
[
"weight"
],
weight_scale
=
self
.
w
[
"wscale"
],
input_scale
=
self
.
w
[
"scale"
],
)
return
self
.
fp8_linear
(
attn_output
)
class
TestAttentionNvfp4QuantPatternModel
(
AttentionQuantPatternModel
):
...
...
tests/compile/test_silu_mul_quant_fusion.py
View file @
148117ea
...
...
@@ -25,19 +25,30 @@ from vllm.config import (
set_current_vllm_config
,
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
PerTensorTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
)
from
vllm.platforms
import
current_platform
from
..utils
import
override_cutlass_fp8_supported
from
..utils
import
TestFP8Layer
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
...
@@ -49,25 +60,27 @@ def is_nvfp4_supported():
class
TestSiluMulFp8QuantModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
cuda_force_torch
:
bool
,
**
kwargs
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
:
int
,
force_kernel
:
FP8ScaledMMLinearKernel
,
**
kwargs
):
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
force_kernel
=
force_kernel
,
)
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
is_
quant_fp8
_
enabled
()
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
wscale
)
x2
=
self
.
fp8_linear
(
y
)
return
x2
def
ops_in_model_before
(
self
):
...
...
@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return
[
torch
.
ops
.
vllm
.
rocm_aiter_act_mul_and_fp8_group_quant
]
ROCM_KERNELS
=
[
ROCmFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
]
CUDA_KERNELS
=
[
FlashInferFP8ScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
]
TEST_KERNELS
=
ROCM_KERNELS
if
current_platform
.
is_rocm
()
else
CUDA_KERNELS
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"enable_silu_mul_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"model_class, enable_quant_fp8_custom_op,
cuda_
force_
torch
"
,
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
[
True
,
False
]
))
"model_class, enable_quant_fp8_custom_op, force_
kernel
"
,
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
TEST_KERNELS
))
+
[
(
TestSiluMulNvfp4QuantModel
,
False
,
Fals
e
),
(
TestSiluMulGroupFp8QuantModel
,
False
,
Fals
e
),
(
TestSiluMulNvfp4QuantModel
,
False
,
Non
e
),
(
TestSiluMulGroupFp8QuantModel
,
False
,
Non
e
),
],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
,
"rocm"
],
reason
=
"Only test on CUDA and ROCm"
)
...
...
@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
],
enable_silu_mul_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
cuda_
force_
torch
:
bool
,
force_
kernel
:
FP8ScaledMMLinearKernel
|
None
,
):
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
():
pytest
.
skip
(
"NVFP4 is not supported on this GPU."
)
...
...
@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
maybe_create_device_identity
()
x
=
torch
.
rand
(
num_tokens
,
hidden_size
*
2
)
...
...
@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
passes
=
[
NoOpEliminationPass
(
config
),
*
fusion_passes
,
PostCleanupPass
(
config
)]
backend
=
TestBackend
(
*
passes
)
model
=
model_class
(
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
)
model
=
model_class
(
hidden_size
=
hidden_size
,
force_kernel
=
force_kernel
,
x
=
x
)
# First dimension dynamic
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
...
...
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
View file @
148117ea
...
...
@@ -11,13 +11,13 @@ from abc import ABC
import
pytest
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
ScaledMMLinearLayerConfig
,
Int8
ScaledMMLinearLayerConfig
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
,
Aiter
Int8
ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
CPUScaledMMLinearKernel
,
CPU
Int8
ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
ScaledMMLinearKernel
,
...
...
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
def
test_cpu_kernel_implements_is_supported
():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
CPUScaledMMLinearKernel
,
"is_supported"
),
(
"CPUScaledMMLinearKernel missing is_supported() method"
"""Test that CPU
Int8
ScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
CPU
Int8
ScaledMMLinearKernel
,
"is_supported"
),
(
"CPU
Int8
ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert
inspect
.
ismethod
(
CPUScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
CPUScaledMMLinearKernel
.
is_supported
),
"CPUScaledMMLinearKernel.is_supported() should be a classmethod"
assert
inspect
.
ismethod
(
CPUInt8ScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
CPUInt8ScaledMMLinearKernel
.
is_supported
),
(
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
result
,
reason
=
CPUScaledMMLinearKernel
.
is_supported
()
result
,
reason
=
CPU
Int8
ScaledMMLinearKernel
.
is_supported
()
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
def
test_aiter_kernel_implements_is_supported
():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
AiterScaledMMLinearKernel
,
"is_supported"
),
(
"AiterScaledMMLinearKernel missing is_supported() method"
"""Test that Aiter
Int8
ScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
Aiter
Int8
ScaledMMLinearKernel
,
"is_supported"
),
(
"Aiter
Int8
ScaledMMLinearKernel missing is_supported() method"
)
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
assert
inspect
.
ismethod
(
AiterScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
AiterScaledMMLinearKernel
.
is_supported
),
(
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
),
(
"Aiter
Int8
ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
result
,
reason
=
AiterScaledMMLinearKernel
.
is_supported
()
result
,
reason
=
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
()
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
...
...
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
def
test_cpu_kernel_accepts_all_configs
():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
"""Test that CPU
Int8
ScaledMMLinearKernel accepts all config combinations."""
configs
=
[
ScaledMMLinearLayerConfig
(
Int8
ScaledMMLinearLayerConfig
(
is_channelwise
=
False
,
is_static_input_scheme
=
True
,
input_symmetric
=
True
,
),
ScaledMMLinearLayerConfig
(
Int8
ScaledMMLinearLayerConfig
(
is_channelwise
=
True
,
is_static_input_scheme
=
False
,
input_symmetric
=
False
,
...
...
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
]
for
config
in
configs
:
can_impl
,
reason
=
CPUScaledMMLinearKernel
.
can_implement
(
config
)
can_impl
,
reason
=
CPU
Int8
ScaledMMLinearKernel
.
can_implement
(
config
)
assert
can_impl
,
(
f
"CPUScaledMMLinearKernel should accept config
{
config
}
:
{
reason
}
"
f
"CPU
Int8
ScaledMMLinearKernel should accept config
{
config
}
:
{
reason
}
"
)
tests/quantization/test_compressed_tensors.py
View file @
148117ea
...
...
@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
]
# TritonScaledMMLinearKernel only supports symmetric quantization.
# Triton
Int8
ScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
=
[
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
...
...
tests/utils.py
View file @
148117ea
...
...
@@ -42,6 +42,17 @@ from vllm.distributed import (
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.cli.serve
import
ServeSubcommand
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
QuantKey
,
)
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.tokenizers
import
get_tokenizer
...
...
@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
if
current_platform
.
is_rocm
():
from
amdsmi
import
(
amdsmi_get_gpu_vram_usage
,
...
...
@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
for
element
in
itertools
.
product
(
*
iterables
):
normalized
=
(
e
if
isinstance
(
e
,
tuple
)
else
(
e
,)
for
e
in
element
)
yield
tuple
(
itertools
.
chain
(
*
normalized
))
class TestFP8Layer(torch.nn.Module):
    """
    Minimal FP8 linear layer for tests.

    Creates random weight and scale tensors (stored as plain attributes)
    according to the given quantization keys and delegates the matmul to the
    kernel object returned by ``init_fp8_linear_kernel``.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        activation_quant_key: Activation quantization configuration.
        weight_quant_key: Weight quantization configuration.
        out_dtype: Output dtype. Defaults to current default dtype.
        device: Device passed to ``torch.rand`` when creating the tensors.
        force_kernel: Optional kernel to force use of specific implementation.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        activation_quant_key: QuantKey,
        weight_quant_key: QuantKey,
        out_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        force_kernel: FP8ScaledMMLinearKernel | None = None,
    ):
        super().__init__()

        weights_are_per_tensor = weight_quant_key.scale.group_shape.is_per_tensor()
        activation_is_static = activation_quant_key.scale.static

        # One scale for the whole weight, or one per row of `weight_shape`.
        if weights_are_per_tensor:
            scale_shape = (1,)
        else:
            scale_shape = (weight_shape[0], 1)

        # NOTE: the order of the torch.rand calls below is kept stable so that
        # seeded runs produce the same tensors.
        self.weight_scale = torch.rand(scale_shape, dtype=torch.float32, device=device)

        # A static activation scheme needs a pre-computed input scale;
        # otherwise the attribute is left as None.
        if activation_is_static:
            self.input_scale = torch.rand(1, dtype=torch.float32, device=device)
        else:
            self.input_scale = None

        # Random weight cast to the FP8 dtype and stored transposed.
        random_weight = torch.rand(weight_shape, device=device)
        self.weight = random_weight.to(dtype=FP8_DTYPE).t()
        self.input_scale_ub = None

        if out_dtype is None:
            out_dtype = torch.get_default_dtype()

        self.kernel = init_fp8_linear_kernel(
            activation_quant_key=activation_quant_key,
            weight_quant_key=weight_quant_key,
            out_dtype=out_dtype,
            force_kernel=force_kernel,
        )

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the kernel's ``quant_fp8`` op reports itself enabled."""
        quant_op = self.kernel.quant_fp8
        return quant_op.enabled()

    def forward(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the kernel to ``y``, passing this module as the layer."""
        kernel = self.kernel
        return kernel.apply_weights(self, y, bias)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class TestBlockFP8Layer:
    """
    Test helper for blockwise FP8 linear operations. Builds random weights
    and scales and wraps a W8A8BlockFp8LinearOp.

    This is a workaround until W8A8BlockFp8LinearOp implements the kernel
    abstraction (ScaledMMLinearKernel) for blockwise quantization.

    Args:
        weight_shape: Shape of the weight tensor (out_features, in_features).
        group_shape: Blockwise quantization group shape.
        cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
        use_aiter_and_is_supported: Whether to use aiter quantization ops.
        transpose_weights: Whether to transpose weights after creation.
    """

    def __init__(
        self,
        weight_shape: tuple[int, int],
        group_shape: GroupShape,
        cutlass_block_fp8_supported: bool = False,
        use_aiter_and_is_supported: bool = False,
        transpose_weights: bool = False,
    ):
        block = group_shape[1]

        # NOTE(review): the scale grid is square and derived only from
        # weight_shape[0]; this presumably assumes square weights - confirm.
        blocks_per_dim = weight_shape[0] // block
        self.weight_scale = torch.rand(
            (blocks_per_dim, blocks_per_dim), dtype=torch.float32
        )

        fp8_weight = torch.rand(weight_shape).to(dtype=FP8_DTYPE)
        self.weight = fp8_weight.t() if transpose_weights else fp8_weight
        self.input_scale = None

        self.linear_op = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(block, block),
            act_quant_group_shape=group_shape,
            cutlass_block_fp8_supported=cutlass_block_fp8_supported,
            use_aiter_and_is_supported=use_aiter_and_is_supported,
        )

    def __call__(
        self, y: torch.Tensor, bias: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the blockwise linear op on ``y`` with this helper's tensors."""
        op = self.linear_op
        return op.apply(
            input=y,
            weight=self.weight,
            weight_scale=self.weight_scale,
            input_scale=self.input_scale,
            bias=bias,
        )

    def is_quant_fp8_enabled(self) -> bool:
        """Return whether the op's ``input_quant_op`` reports itself enabled."""
        quant_op = self.linear_op.input_quant_op
        return quant_op.enabled()
vllm/_aiter_ops.py
View file @
148117ea
...
...
@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl(
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# Cutlass
Int8
ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
148117ea
...
...
@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from
torch.nn
import
Parameter
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
create_fp8_input_scale
,
...
...
@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy
,
validate_fp8_block_shape
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
maybe_create_device_identity
,
)
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
...
...
@@ -42,6 +49,18 @@ strategy_to_parameter_type = {
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
}
STATIC_QUANT
=
True
DYNAMIC_QUANT
=
False
activation_quant_key_mapping
=
{
STATIC_QUANT
:
kFp8StaticTensorSym
,
DYNAMIC_QUANT
:
kFp8DynamicTokenSym
,
}
weight_quant_key_mapping
=
{
QuantizationStrategy
.
CHANNEL
:
kFp8StaticTokenSym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
}
logger
=
init_logger
(
__name__
)
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
...
...
@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
strategy
=
weight_quant
.
strategy
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
if
self
.
weight_block_size
is
not
None
:
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
else
:
self
.
act_q_group_shape
=
(
GroupShape
.
PER_TENSOR
if
is_static_input_scheme
else
GroupShape
.
PER_TOKEN
)
if
self
.
weight_block_size
is
not
None
:
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
if
self
.
weight_block_size
is
not
None
:
assert
not
self
.
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
self
.
act_q_group_shape
,
...
...
@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
else
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
self
.
is_static_input_scheme
,
act_quant_group_shape
=
self
.
act_q_group_shape
,
activation_quant_key
=
activation_quant_key_mapping
[
is_static_input_scheme
]
weight_quant_key
=
weight_quant_key_mapping
[
self
.
strategy
]
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
weight_quant_key
,
out_dtype
=
self
.
out_dtype
,
module_name
=
self
.
__class__
.
__name__
,
)
@
classmethod
...
...
@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader
:
Callable
,
**
kwargs
,
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
weight_block_size
=
None
...
...
@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
getattr
(
layer
,
"input_scale"
,
None
),
)
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
...
...
@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
maybe_post_process_fp8_weight_block
(
layer
)
...
...
@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
148117ea
...
...
@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
ScaledMMLinearLayerConfig
,
choose_scaled_mm_linear_kernel
,
init_int8_linear_kernel
,
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
...
...
@@ -25,8 +24,6 @@ logger = init_logger(__name__)
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
_kernel_backends_being_used
:
set
[
str
]
=
set
()
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
):
...
...
@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
):
layer
.
logical_widths
=
output_partition_sizes
s
caled_mm_linear_kernel_config
=
ScaledMMLinearLayerConfig
(
s
elf
.
kernel
=
init_int8_linear_kernel
(
is_channelwise
=
(
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
),
is_static_input_scheme
=
self
.
is_static_input_scheme
,
input_symmetric
=
self
.
input_symmetric
,
module_name
=
self
.
__class__
.
__name__
,
)
kernel_type
=
choose_scaled_mm_linear_kernel
(
scaled_mm_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for CompressedTensorsW8A8Int8"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# WEIGHT
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
...
...
@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
input_zero_point
=
None
input_scale
=
None
if
self
.
is_static_input_scheme
:
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
self
.
input_symmetric
:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
...
...
@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
self
.
kernel
=
kernel_type
(
c
=
scaled_mm_linear_kernel_config
,
w_q_param_name
=
"weight"
,
w_s_param_name
=
"weight_scale"
,
i_s_param_name
=
"input_scale"
,
i_zp_param_name
=
"input_zero_point"
,
azp_adj_param_name
=
"azp_adj"
,
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
hasattr
(
layer
,
"azp_adj"
):
layer
.
register_parameter
(
"azp_adj"
,
None
)
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
148117ea
...
...
@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
,
kFp8DynamicTokenSym
,
kFp8StaticTokenSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
vllm.model_executor.parameter
import
(
...
...
@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
class
FBGEMMFp8LinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
False
,
act_quant_group_shape
=
GroupShape
.
PER_TOKEN
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
kFp8DynamicTokenSym
,
weight_quant_key
=
kFp8StaticTokenSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
def
create_weights
(
self
,
...
...
@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
maybe_create_device_identity
()
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
vllm/model_executor/layers/quantization/fp8.py
View file @
148117ea
...
...
@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
...
...
@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
vllm.model_executor.parameter
import
(
...
...
@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
if
self
.
weight_block_size
:
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
else
:
# Use per-token quantization for better perf if dynamic and cutlass
if
not
self
.
act_q_static
and
cutlass_fp8_supported
():
self
.
act_q_group_shape
=
GroupShape
.
PER_TOKEN
else
:
self
.
act_q_group_shape
=
GroupShape
.
PER_TENSOR
if
self
.
block_quant
:
assert
not
self
.
act_q_static
assert
self
.
weight_block_size
is
not
None
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
self
.
act_q_group_shape
,
act_quant_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
else
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
self
.
act_q_static
,
act_quant_group_shape
=
self
.
act_q_group_shape
,
# Use per-token quantization for better perf if dynamic and cutlass
if
self
.
act_q_static
:
activation_quant_key
=
kFp8StaticTensorSym
elif
cutlass_fp8_supported
():
activation_quant_key
=
kFp8DynamicTokenSym
else
:
activation_quant_key
=
kFp8DynamicTensorSym
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
kFp8StaticTensorSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
def
create_weights
(
...
...
@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
...
...
@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase):
scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"input_scale"
})
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
...
...
@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
View file @
148117ea
...
...
@@ -2,19 +2,58 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
typing
import
Generic
,
TypeVar
import
torch
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
)
from
vllm.platforms
import
current_platform
@
dataclass
class
ScaledMMLinearLayerConfig
:
is_channelwise
:
bool
pass
@
dataclass
class
Int8ScaledMMLinearLayerConfig
(
ScaledMMLinearLayerConfig
):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme
:
bool
is_channelwise
:
bool
input_symmetric
:
bool
class
ScaledMMLinearKernel
(
ABC
):
@
dataclass
class
FP8ScaledMMLinearLayerConfig
(
ScaledMMLinearLayerConfig
):
weight_quant_key
:
QuantKey
activation_quant_key
:
QuantKey
out_dtype
:
torch
.
dtype
|
None
_FP8ParamsT
=
tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_scale_ub,
]
_Int8ParamsT
=
tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_zp
torch
.
Tensor
|
None
,
# azp_adj
]
_ParamsT
=
TypeVar
(
"_ParamsT"
,
_Int8ParamsT
,
_FP8ParamsT
)
_ConfigT
=
TypeVar
(
"_ConfigT"
,
bound
=
ScaledMMLinearLayerConfig
)
class
ScaledMMLinearKernel
(
Generic
[
_ConfigT
,
_ParamsT
],
ABC
):
@
classmethod
@
abstractmethod
def
is_supported
(
...
...
@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC):
@
classmethod
@
abstractmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayer
Config
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
_
Config
T
)
->
tuple
[
bool
,
str
|
None
]:
raise
NotImplementedError
def
__init__
(
self
,
c
:
ScaledMMLinearLayerConfig
,
w_q_param_name
:
str
,
w_s_param_name
:
str
,
i_s_param_name
:
str
,
i_zp_param_name
:
str
,
azp_adj_param_name
:
str
,
)
->
None
:
assert
self
.
can_implement
(
c
)
assert
self
.
is_supported
()
def
__init__
(
self
,
c
:
_ConfigT
,
layer_param_names
:
Sequence
[
str
])
->
None
:
assert
self
.
can_implement
(
c
)[
0
]
assert
self
.
is_supported
()[
0
]
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
w_s_name
=
w_s_param_name
self
.
i_s_name
=
i_s_param_name
self
.
i_zp_name
=
i_zp_param_name
self
.
azp_adj_name
=
azp_adj_param_name
self
.
layer_param_names
=
layer_param_names
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
...
@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
_get_weight_params
(
self
,
layer
:
torch
.
nn
.
Module
)
->
tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_zp
torch
.
Tensor
|
None
,
# azp_adj
]:
# return a covariant type in the subclass
@
abstractmethod
def
_get_layer_params
(
self
,
layer
)
->
_ParamsT
:
raise
NotImplementedError
class FP8ScaledMMLinearKernel(
    ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
):
    """Abstract base for FP8 scaled-matmul linear kernels.

    Takes care of looking up the layer parameters and quantizing the input;
    concrete subclasses provide the actual GEMM via `apply_scaled_mm`.
    """

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        """Build the activation quantizer, then run the base initialisation.

        Args:
            c: FP8 layer config; its activation quant key's scale descriptor
                configures the quantizer.
            layer_param_names: Attribute names, on the layer, of the weight,
                weight scale, input scale and input scale upper bound
                (in that order).
        """
        scale_descriptor = c.activation_quant_key.scale
        self.quant_fp8 = QuantFP8(
            static=scale_descriptor.static,
            group_shape=scale_descriptor.group_shape,
            num_token_padding=self.get_output_padding(),
        )
        self.fp8_dtype = current_platform.fp8_dtype()
        super().__init__(c, layer_param_names)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """No post-load processing is done at this level."""
        return None

    def _get_layer_params(self, layer) -> _FP8ParamsT:
        """Fetch (weight, weight_scale, input_scale, input_scale_ub).

        The weight and weight scale must exist on `layer`; the two input
        scale entries fall back to None when the attribute is absent.
        """
        (
            weight_name,
            weight_scale_name,
            input_scale_name,
            input_scale_ub_name,
        ) = self.layer_param_names
        weight = getattr(layer, weight_name)
        weight_scale = getattr(layer, weight_scale_name)
        input_scale = getattr(layer, input_scale_name, None)
        input_scale_ub = getattr(layer, input_scale_ub_name, None)
        return (weight, weight_scale, input_scale, input_scale_ub)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize `x` if needed and run the subclass's scaled matmul.

        Args:
            layer: Module holding the parameters named at construction.
            x: Input activations; flattened to 2D over the last dimension.
            bias: Optional bias forwarded to `apply_scaled_mm`.

        Returns:
            Whatever `apply_scaled_mm` returns for the given operands.
        """
        target_fp8_dtype = self.fp8_dtype
        configured_out_dtype = self.config.out_dtype
        weight, weight_scale, input_scale, input_scale_ub = self._get_layer_params(
            layer
        )

        # The quantizer covers both schemes: with no stored input scale it
        # derives one from the activations (dynamic); otherwise the stored
        # scale is used (static).

        # FP8 paths operate on a 2D matrix.
        flat_input = x.view(-1, x.shape[-1])
        # Leading dims of x are kept; the last one becomes weight.shape[1].
        output_shape = [*x.shape[:-1], weight.shape[1]]

        if configured_out_dtype is None:
            out_dtype = x.dtype
        else:
            out_dtype = configured_out_dtype

        # Input may arrive already quantized; only quantize when it is not.
        # TODO(luka) remove this path if not used anymore
        if x.dtype == target_fp8_dtype:
            quantized_input = flat_input
        else:
            quantized_input, input_scale = self.quant_fp8(
                flat_input,
                input_scale,
                input_scale_ub,
            )

        return self.apply_scaled_mm(
            A=quantized_input,
            B=weight,
            out_dtype=out_dtype,
            As=input_scale,
            Bs=weight_scale,
            bias=bias,
            output_shape=output_shape,
        )

    @abstractmethod
    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the backend GEMM on quantized input A and weight B.

        As / Bs are the scales for A / B. `output_shape` is the shape
        computed by `apply_weights` from the original input and the weight.
        """
        raise NotImplementedError

    def get_output_padding(self) -> int | None:
        """Value passed as `num_token_padding` to the quantizer (None here)."""
        return None
class Int8ScaledMMLinearKernel(
    ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
    """Abstract base for INT8 scaled-matmul linear kernels."""

    def _get_layer_params(self, layer) -> _Int8ParamsT:
        """Fetch (weight, weight_scale, input_scale, input_zp, azp_adj).

        Attribute names come from `self.layer_param_names`, which is the
        only place the constructor stores them. The weight and weight scale
        must exist on `layer`; the remaining three fall back to None when
        the attribute is absent.
        """
        w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
        # Exactly five entries, matching _Int8ParamsT.
        return (
            getattr(layer, w_q),
            getattr(layer, w_s),
            getattr(layer, i_s, None),
            getattr(layer, i_zp, None),
            getattr(layer, azp_adj, None),
        )
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
View file @
148117ea
...
...
@@ -2,76 +2,229 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
typing
import
TypeVar
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
,
Aiter
Int8
ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
CPUScaledMMLinearKernel
,
CPU
Int8
ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
CutlassInt8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.triton
import
(
TritonScaledMMLinearKernel
,
Triton
Int8
ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
from
vllm.platforms
import
PlatformEnum
,
current_platform
logger
=
init_logger
(
__name__
)
# INT8 kernel candidates per platform.
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
    PlatformEnum.CUDA: [
        CutlassInt8ScaledMMLinearKernel,
        TritonInt8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}

# FP8 kernel candidates per platform.
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferFP8ScaledMMLinearKernel,
        CutlassFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ROCmFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        RowWiseTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.CPU: [
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
}

# Type variables used by the kernel-selection helpers below.
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
    kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
    """Check whether `kernel` is usable for `config` on this device.

    A kernel is rejected when it is listed in the comma-separated
    `VLLM_DISABLED_KERNELS` environment variable, when `kernel.is_supported`
    fails, or when `kernel.can_implement` fails.

    Args:
        kernel: Kernel class to check.
        config: Layer config handed to `kernel.can_implement`.
        compute_capability: Capability as `major * 10 + minor`. When None it
            is read from `current_platform` (and stays None if the platform
            reports none).

    Returns:
        `(True, "")` when the kernel can be used, otherwise `(False, reason)`.
    """
    # TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
    # Strip each entry so that "A, B" disables B as well as A.
    disabled_kernels = {
        name.strip()
        for name in os.environ.get("VLLM_DISABLED_KERNELS", "").split(",")
    }
    if kernel.__name__ in disabled_kernels:
        return False, f"{kernel.__name__} is disabled by environment variable"

    def _describe(reason: str | None) -> str:
        # Kernel-provided reasons may already end with a period (and may be
        # None); normalise so the message ends with exactly one.
        detail = (reason or "gave no reason").rstrip(". ")
        return f"{kernel.__name__} {detail}."

    if compute_capability is None:
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    is_supported, failure_reason = kernel.is_supported(compute_capability)
    if not is_supported:
        return False, _describe(failure_reason)

    can_implement, failure_reason = kernel.can_implement(config)
    if not can_implement:
        return False, _describe(failure_reason)

    return True, ""
def choose_scaled_mm_linear_kernel(
    config: _KernelConfigT,
    possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
    compute_capability: int | None = None,
    force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
    """
    Choose a _KernelT that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (_KernelConfigT): Description of the linear layer
            to be implemented.
        possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
            dictionary of platforms and their list of possible kernels,
            each list in priority order.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get the
            compute capability. Defaults to None.
        force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to
            override the possible_kernels if it can be implemented. If None,
            it will only try the possible kernels.

    Raises:
        ValueError: If no kernel can implement the given config (including
            when the current platform has no entry in `possible_kernels`).

    Returns:
        _KernelT: Chosen kernel.
    """
    failure_reason_list = []

    if force_kernel is not None:
        can_implement, failure_reason = is_supported_and_can_implement_kernel(
            force_kernel, config, compute_capability
        )
        if can_implement:
            return force_kernel

        logger.info_once(
            "Tried to force %s, but the kernel couldn't be implemented",
            force_kernel.__name__,
            scope="global",
        )
        # Keep the reason so it shows up if every fallback fails as well.
        failure_reason_list.append(failure_reason)

    # A platform without an entry yields no candidates, so the documented
    # ValueError is raised below rather than a KeyError here.
    for kernel in possible_kernels.get(current_platform._enum, []):
        is_supported_and_can_implement, failure_reason = (
            is_supported_and_can_implement_kernel(kernel, config, compute_capability)
        )
        if is_supported_and_can_implement:
            return kernel
        failure_reason_list.append(failure_reason)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
    )
def init_fp8_linear_kernel(
    activation_quant_key: QuantKey,
    weight_quant_key: QuantKey,
    out_dtype: torch.dtype,
    force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
    """Select and instantiate an FP8 scaled-matmul kernel.

    Args:
        activation_quant_key: Quantization key for the activations.
        weight_quant_key: Quantization key for the weights.
        out_dtype: Output dtype stored in the layer config.
        force_kernel: Kernel class to try before the platform candidates.
        module_name: When truthy, the selection is logged under this name.

    Returns:
        An instance of the chosen kernel class, bound to the layer attributes
        "weight", "weight_scale", "input_scale" and "input_scale_ub".
    """
    layer_config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        out_dtype=out_dtype,
    )

    selected_kernel = choose_scaled_mm_linear_kernel(
        layer_config,
        _POSSIBLE_FP8_KERNELS,
        force_kernel=force_kernel,
    )

    if module_name:
        logger.info_once(
            "Selected %s for %s",
            selected_kernel.__name__,
            module_name,
            scope="global",
        )

    # Order matters: FP8 kernels unpack these names positionally.
    fp8_param_names = ["weight", "weight_scale", "input_scale", "input_scale_ub"]
    return selected_kernel(layer_config, layer_param_names=fp8_param_names)
def init_int8_linear_kernel(
    is_channelwise: bool,
    is_static_input_scheme: bool,
    input_symmetric: bool,
    module_name: str,
) -> Int8ScaledMMLinearKernel:
    """Select and instantiate an INT8 scaled-matmul kernel.

    Args:
        is_channelwise: Stored in the layer config as `is_channelwise`.
        is_static_input_scheme: Stored in the layer config as
            `is_static_input_scheme`.
        input_symmetric: Stored in the layer config as `input_symmetric`.
        module_name: Name used when logging the selection.

    Returns:
        An instance of the chosen kernel class, bound to the layer attributes
        "weight", "weight_scale", "input_scale", "input_zero_point" and
        "azp_adj".
    """
    layer_config = Int8ScaledMMLinearLayerConfig(
        is_channelwise=is_channelwise,
        is_static_input_scheme=is_static_input_scheme,
        input_symmetric=input_symmetric,
    )

    selected_kernel = choose_scaled_mm_linear_kernel(
        layer_config, _POSSIBLE_INT8_KERNELS
    )

    logger.info_once(
        "Selected %s for %s",
        selected_kernel.__name__,
        module_name,
        scope="global",
    )

    # Order matters: INT8 kernels unpack these names positionally.
    int8_param_names = [
        "weight",
        "weight_scale",
        "input_scale",
        "input_zero_point",
        "azp_adj",
    ]
    return selected_kernel(layer_config, layer_param_names=int8_param_names)
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
View file @
148117ea
...
...
@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.platforms
import
current_platform
from
.cutlass
import
CutlassScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
ScaledMMLinearLayerConfig
from
.cutlass
import
Cutlass
Int8
ScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
Int8
ScaledMMLinearLayerConfig
class
AiterScaledMMLinearKernel
(
CutlassScaledMMLinearKernel
):
class
Aiter
Int8
ScaledMMLinearKernel
(
Cutlass
Int8
ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_rocm
():
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"currently supported on non-ROCm platform."
,
)
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
return
False
,
"Requires ROCm."
if
compute_capability
is
not
None
and
compute_capability
<
90
:
return
False
,
f
"requires
capability 90, got
{
compute
_
capability
}
"
return
False
,
"requires compute
capability
90 and above.
"
try
:
import
aiter
# noqa: F401 # deliberately attempt to import aiter
except
Exception
:
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"installed on ROCm."
,
)
return
False
,
"requires `aiter` to be installed."
if
not
rocm_aiter_ops
.
is_linear_enabled
():
return
(
False
,
"AiterScaledMMLinearKernel is disabled. "
+
"Enable by setting `VLLM_ROCM_USE_AITER=1` "
"requires setting `VLLM_ROCM_USE_AITER=1` "
+
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
)
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
c
.
input_symmetric
:
return
(
False
,
"AiterScaledMMLinearKernel only supports symmetric "
+
"quantization."
,
)
return
False
,
"supports symmetric quantization only."
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`Aiter
Int8
ScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
w8a8 scaled gemm. `Aiter
Int8
ScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric
=
azp_adj
is
None
assert
symmetric
,
(
"AiterScaledMMLinearKernel only supports symmetric quantization."
"Aiter
Int8
ScaledMMLinearKernel only supports symmetric quantization."
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
i_s
,
i_zp
,
symmetric
=
symmetric
)
assert
x_zp
is
None
,
(
"AiterScaledMMLinearKernel only supports symmetric quantization."
"Aiter
Int8
ScaledMMLinearKernel only supports symmetric quantization."
)
out_dtype
=
x
.
dtype
...
...
@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
),
(
"Currently only support per-tensor-per-tensor GEMM "
+
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
" w8a8 scaled gemm. `Aiter
Int8
ScaledMMLinearKernel` "
+
"does not support AITER block scaled GEMM."
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# Cutlass
Int8
ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
rocm_aiter_ops
.
gemm_a8w8
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
View file @
148117ea
...
...
@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
)
class
CPUScaledMMLinearKernel
(
ScaledMMLinearKernel
):
class
CPU
Int8
ScaledMMLinearKernel
(
Int8
ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cpu
():
return
False
,
"
R
equires CPU."
return
False
,
"
r
equires CPU."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
getattr
(
layer
,
self
.
w_q_name
)
w_q_name
,
_
,
_
,
_
,
_
=
self
.
layer_param_names
weight
=
getattr
(
layer
,
w_q_name
)
dtype
=
weight
.
dtype
N
,
K
=
weight
.
size
()
if
(
...
...
@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
def
process_weights_for_onednn
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# WEIGHT
# Transpose to [K, N] for convenience
weight
=
getattr
(
layer
,
self
.
w_q_name
)
w_q_name
,
w_s_name
,
i_s_name
,
i_zp_name
,
azp_adj_name
=
self
.
layer_param_names
weight
=
getattr
(
layer
,
w_q_name
)
replace_parameter
(
layer
,
self
.
w_q_name
,
w_q_name
,
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
)
...
...
@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
input_scale
=
getattr
(
layer
,
i_s_name
)
if
self
.
config
.
input_symmetric
:
replace_parameter
(
layer
,
self
.
i_s_name
,
i_s_name
,
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
else
:
input_zero_point
=
getattr
(
layer
,
self
.
i_zp_name
)
input_zero_point
=
getattr
(
layer
,
i_zp_name
)
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
...
...
@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
replace_parameter
(
layer
,
self
.
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
layer
,
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
)
azp
=
(
(
int8_traits
.
min
-
range_min
/
scale
).
round
().
to
(
dtype
=
torch
.
int32
)
)
replace_parameter
(
layer
,
self
.
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
layer
,
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
)
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
...
...
@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if
not
(
self
.
config
.
input_symmetric
and
self
.
config
.
is_static_input_scheme
):
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight
=
getattr
(
layer
,
w_q_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
float32
)
azp_adj
=
azp_adj
*
weight_scale
.
squeeze
()
setattr
(
layer
,
self
.
azp_adj_name
,
azp_adj_name
,
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
)
else
:
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
self
.
dnnl_handler
=
ops
.
create_onednn_scaled_mm
(
weight
,
getattr
(
layer
,
self
.
w_s_name
),
getattr
(
layer
,
w_s_name
),
torch
.
get_default_dtype
(),
getattr
(
layer
,
self
.
i_s_name
)
is
None
,
getattr
(
layer
,
i_s_name
)
is
None
,
not
self
.
config
.
input_symmetric
,
32
,
)
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
setattr
(
layer
,
self
.
w_q_name
,
None
)
setattr
(
layer
,
w_q_name
,
None
)
del
weight
def
process_weights_for_sgl
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_q_name
,
w_s_name
,
_
,
_
,
_
=
self
.
layer_param_names
# WEIGHT
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
weight
)
replace_parameter
(
layer
,
self
.
w_q_name
,
torch
.
nn
.
Parameter
(
packed_weight
,
requires_grad
=
False
)
layer
,
w_q_name
,
torch
.
nn
.
Parameter
(
packed_weight
,
requires_grad
=
False
)
)
if
layer
.
bias
is
not
None
:
...
...
@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
not
self
.
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
...
...
@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
w_q
,
w_s
,
_
,
_
,
_
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
_
,
_
,
_
=
self
.
_get_
layer
_params
(
layer
)
return
torch
.
ops
.
_C
.
int8_scaled_mm_with_quant
(
x
,
w_q
,
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
View file @
148117ea
...
...
@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from
vllm.platforms
import
current_platform
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
)
class
CutlassScaledMMLinearKernel
(
ScaledMMLinearKernel
):
class
Cutlass
Int8
ScaledMMLinearKernel
(
Int8
ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cuda
():
return
False
,
"Requires CUDA."
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
if
compute_capability
is
not
None
and
compute_capability
<
75
:
return
False
,
f
"requires capability 75, got
{
compute_capability
}
"
return
False
,
"requires CUDA."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_q_name
,
w_s_name
,
i_s_name
,
i_zp_name
,
azp_adj_name
=
self
.
layer_param_names
config
=
self
.
config
# WEIGHT
# Cutlass kernels need transposed weight.
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
replace_parameter
(
layer
,
self
.
w_q_name
,
w_q_name
,
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
)
...
...
@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
is_fused_module
and
not
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
if
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
i_s_name
)
if
self
.
config
.
input_symmetric
:
if
config
.
input_symmetric
:
replace_parameter
(
layer
,
self
.
i_s_name
,
i_s_name
,
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
i_zp_name
,
None
)
else
:
input_zero_point
=
getattr
(
layer
,
self
.
i_zp_name
)
input_zero_point
=
getattr
(
layer
,
i_zp_name
)
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
...
...
@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
replace_parameter
(
layer
,
self
.
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
layer
,
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
)
# AZP loaded as int8 but used as int32
azp
=
(
int8_traits
.
min
-
range_min
/
scale
).
to
(
dtype
=
torch
.
int32
)
replace_parameter
(
layer
,
self
.
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
layer
,
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
)
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if
not
self
.
config
.
input_symmetric
:
weight
=
getattr
(
layer
,
self
.
w_q_name
)
if
not
config
.
input_symmetric
:
weight
=
getattr
(
layer
,
w_q_name
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
if
self
.
config
.
is_static_input_scheme
:
if
config
.
is_static_input_scheme
:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj
=
getattr
(
layer
,
self
.
i_zp_name
)
*
azp_adj
azp_adj
=
getattr
(
layer
,
i_zp_name
)
*
azp_adj
setattr
(
layer
,
self
.
azp_adj_name
,
azp_adj_name
,
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
)
else
:
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
def
apply_weights
(
self
,
...
...
@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
...
...
@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return
ops
.
cutlass_scaled_mm
(
x_q
,
w_q
,
scale_a
=
x_s
,
scale_b
=
w_s
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-matmul kernel that dispatches to `ops.cutlass_scaled_mm`."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Supported only on CUDA; `compute_capability` is not inspected."""
        if current_platform.is_cuda():
            return True, None
        return False, "requires CUDA."

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every FP8 layer config."""
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the fused GEMM + dequant and view the result as `output_shape`."""
        # Fused GEMM_DQ
        gemm_result = ops.cutlass_scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
        reshaped = gemm_result.view(*output_shape)
        return reshaped
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment