Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
148117ea
Unverified
Commit
148117ea
authored
Jan 20, 2026
by
vllmellm
Committed by
GitHub
Jan 20, 2026
Browse files
[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
e9c83cdc
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
971 additions
and
560 deletions
+971
-560
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
+5
-0
tests/compile/distributed/test_fusion_all_reduce.py
tests/compile/distributed/test_fusion_all_reduce.py
+17
-27
tests/compile/distributed/test_sequence_parallelism.py
tests/compile/distributed/test_sequence_parallelism.py
+17
-26
tests/compile/test_functionalization.py
tests/compile/test_functionalization.py
+19
-22
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+182
-146
tests/compile/test_fusion_attn.py
tests/compile/test_fusion_attn.py
+20
-21
tests/compile/test_silu_mul_quant_fusion.py
tests/compile/test_silu_mul_quant_fusion.py
+44
-27
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
...s/kernels/quantization/test_scaled_mm_kernel_selection.py
+24
-22
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+1
-1
tests/utils.py
tests/utils.py
+127
-0
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+33
-30
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+10
-23
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+12
-16
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+21
-26
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
...rs/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
+144
-33
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
...xecutor/layers/quantization/kernels/scaled_mm/__init__.py
+183
-30
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
...l_executor/layers/quantization/kernels/scaled_mm/aiter.py
+17
-36
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
...del_executor/layers/quantization/kernels/scaled_mm/cpu.py
+33
-38
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
...executor/layers/quantization/kernels/scaled_mm/cutlass.py
+61
-35
No files found.
.buildkite/lm-eval-harness/configs/models-small-rocm.txt
0 → 100644
View file @
148117ea
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
tests/compile/distributed/test_fusion_all_reduce.py
View file @
148117ea
...
@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import (
...
@@ -26,15 +26,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel
,
initialize_model_parallel
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
Fp8LinearOp
,
kFp8StaticTensorSym
,
GroupShape
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
...utils
import
has_module_attribute
,
multi_gpu_test
from
...utils
import
TestFP8Layer
,
has_module_attribute
,
multi_gpu_test
from
..backend
import
TestBackend
from
..backend
import
TestBackend
...
@@ -76,49 +75,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
...
@@ -76,49 +75,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
def
__init__
(
self
,
hidden_size
=
16
,
token_num
=
16
,
eps
=
1e-6
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
eps
=
eps
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
fp8_linear_layers
=
[
self
.
w
=
[
TestFP8Layer
(
torch
.
rand
(
hidden_size
,
hidden_size
)
weight_shape
=
(
hidden_size
,
hidden_size
),
.
to
(
dtype
=
current_platform
.
fp8_dtype
())
activation_quant_key
=
self
.
quant_key
,
.
t
()
weight_quant_key
=
self
.
quant_key
,
for
_
in
range
(
3
)
)
for
i
in
range
(
3
)
]
]
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
# avoid having graph input be an arg to a pattern directly
z
=
torch
.
relu
(
hidden_states
)
z
=
torch
.
relu
(
hidden_states
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
y
=
self
.
norm
[
0
](
x
)
z2
=
self
.
fp8_linear
.
apply
(
z2
=
self
.
fp8_linear_layers
[
0
](
y
)
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
self
.
fp8_linear
.
apply
(
z3
=
self
.
fp8_linear_layers
[
1
](
y2
)
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
z4
=
self
.
fp8_linear
.
apply
(
z4
=
self
.
fp8_linear_layers
[
2
](
y3
)
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
return
y4
...
@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
...
@@ -130,7 +120,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return
[
return
[
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
vllm
.
all_reduce
.
default
,
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
if
self
.
fp8_linear
.
quant_fp8
.
enabled
()
if
self
.
fp8_linear
_layers
[
0
].
is_
quant_fp8
_
enabled
()
else
torch
.
ops
.
aten
.
reciprocal
.
default
,
else
torch
.
ops
.
aten
.
reciprocal
.
default
,
]
]
...
...
tests/compile/distributed/test_sequence_parallelism.py
View file @
148117ea
...
@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
...
@@ -27,13 +27,14 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel
,
initialize_model_parallel
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.system_utils
import
update_environment_variables
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
from
...utils
import
multi_gpu_test
from
...utils
import
TestFP8Layer
,
multi_gpu_test
from
..backend
import
TestBackend
from
..backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
...
@@ -94,50 +95,40 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
class
TestAllReduceRMSNormStaticQuantFP8Model
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
eps
=
1e-6
):
def
__init__
(
self
,
hidden_size
=
16
,
eps
=
1e-6
):
super
().
__init__
()
super
().
__init__
()
self
.
vllm_config
=
get_current_vllm_config
()
self
.
vllm_config
=
get_current_vllm_config
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
eps
=
eps
self
.
eps
=
eps
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
i
in
range
(
4
)]
self
.
wscale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
fp8_linear_layers
=
[
self
.
w
=
[
TestFP8Layer
(
torch
.
rand
(
hidden_size
,
hidden_size
)
weight_shape
=
(
hidden_size
,
hidden_size
),
.
to
(
dtype
=
current_platform
.
fp8_dtype
())
activation_quant_key
=
self
.
quant_key
,
.
t
()
weight_quant_key
=
self
.
quant_key
,
for
_
in
range
(
3
)
)
for
i
in
range
(
3
)
]
]
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
def
forward
(
self
,
hidden_states
):
def
forward
(
self
,
hidden_states
):
# avoid having graph input be an arg to a pattern directly
# avoid having graph input be an arg to a pattern directly
z
=
torch
.
relu
(
hidden_states
)
z
=
torch
.
relu
(
hidden_states
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
x
=
resid
=
tensor_model_parallel_all_reduce
(
z
)
y
=
self
.
norm
[
0
](
x
)
y
=
self
.
norm
[
0
](
x
)
z2
=
self
.
fp8_linear
.
apply
(
z2
=
self
.
fp8_linear_layers
[
0
](
y
)
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
x2
=
tensor_model_parallel_all_reduce
(
z2
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
z3
=
self
.
fp8_linear
.
apply
(
z3
=
self
.
fp8_linear_layers
[
1
](
y2
)
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
x3
=
tensor_model_parallel_all_reduce
(
z3
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
z4
=
self
.
fp8_linear
.
apply
(
z4
=
self
.
fp8_linear_layers
[
2
](
y3
)
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
x4
=
tensor_model_parallel_all_reduce
(
z4
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
return
y4
...
@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
...
@@ -160,7 +151,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
return
[
return
[
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
,
]
]
elif
self
.
fp8_linear
.
quant_fp8
.
enabled
():
elif
any
(
layer
.
is_
quant_fp8
_
enabled
()
for
layer
in
self
.
fp8_linear_layers
)
:
return
[
return
[
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
torch
.
ops
.
_C
.
static_scaled_fp8_quant
.
default
,
]
]
...
...
tests/compile/test_functionalization.py
View file @
148117ea
...
@@ -20,11 +20,13 @@ from vllm.config import (
...
@@ -20,11 +20,13 @@ from vllm.config import (
)
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
kFp8StaticTensorSym
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
..utils
import
TestFP8Layer
from
.backend
import
TestBackend
from
.backend
import
TestBackend
TEST_FP8
=
current_platform
.
supports_fp8
()
TEST_FP8
=
current_platform
.
supports_fp8
()
...
@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
...
@@ -32,24 +34,22 @@ FP8_DTYPE = current_platform.fp8_dtype()
class
TestSiluMul
(
torch
.
nn
.
Module
):
class
TestSiluMul
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
:
int
=
128
):
def
__init__
(
self
,
hidden_size
:
int
=
128
):
super
().
__init__
()
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
if
TEST_FP8
:
if
TEST_FP8
:
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
fp8_linear
=
TestFP8Layer
(
self
.
fp8_linear
=
Fp8LinearOp
(
weight_shape
=
(
hidden_size
,
hidden_size
),
act
_quant_static
=
True
,
act
ivation_quant_key
=
self
.
quant_key
,
ac
t_quant_
group_shape
=
GroupShape
.
PER_TENSOR
,
weigh
t_quant_
key
=
self
.
quant_key
,
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
y
=
self
.
silu_and_mul
(
x
)
if
TEST_FP8
:
if
TEST_FP8
:
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
wscale
)
return
self
.
fp8_linear
(
y
)
return
x2
else
:
else
:
return
y
return
y
...
@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):
...
@@ -67,6 +67,8 @@ class TestSiluMul(torch.nn.Module):
class
TestFusedAddRMSNorm
(
torch
.
nn
.
Module
):
class
TestFusedAddRMSNorm
(
torch
.
nn
.
Module
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
):
def
__init__
(
self
,
hidden_size
=
16
,
intermediate_size
=
32
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
...
@@ -81,11 +83,11 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
torch
.
nn
.
init
.
normal_
(
self
.
gate_proj
,
std
=
0.02
)
if
TEST_FP8
:
if
TEST_FP8
:
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
)
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
intermediate_size
),
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
activation_quant_key
=
self
.
quant_key
,
self
.
w
=
torch
.
rand
(
hidden_size
,
intermediate_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
weight_quant_key
=
self
.
quant_key
,
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
)
def
forward
(
self
,
hidden_states
,
residual
):
def
forward
(
self
,
hidden_states
,
residual
):
# Reshape input
# Reshape input
...
@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
...
@@ -100,12 +102,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
if
TEST_FP8
:
if
TEST_FP8
:
# scaled_mm with static input quantization
# scaled_mm with static input quantization
fp8_linear_result
=
self
.
fp8_linear
.
apply
(
fp8_linear_result
=
self
.
fp8_linear
(
norm_output
)
norm_output
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
scale
.
to
(
norm_output
.
device
),
)
return
fp8_linear_result
,
residual_output
return
fp8_linear_result
,
residual_output
...
...
tests/compile/test_fusion.py
View file @
148117ea
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
import
pytest
import
pytest
import
torch
import
torch
import
vllm.config
import
vllm.plugins
import
vllm.plugins
from
vllm._aiter_ops
import
IS_AITER_FOUND
,
rocm_aiter_ops
from
vllm._aiter_ops
import
IS_AITER_FOUND
,
rocm_aiter_ops
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
from
vllm.compilation.fusion
import
FUSED_OPS
,
FusedRMSQuantKey
,
RMSNormQuantFusionPass
...
@@ -20,8 +21,22 @@ from vllm.config import (
...
@@ -20,8 +21,22 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
)
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
W8A8BlockFp8LinearOp
,
CutlassFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
...
@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -29,15 +44,14 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ScaleDesc
,
ScaleDesc
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.deep_gemm
import
is_deep_gemm_supported
from
vllm.utils.deep_gemm
import
(
is_deep_gemm_supported
,
)
from
..utils
import
override_cutlass_fp8_supported
from
..utils
import
TestBlockFP8Layer
,
TestFP8Layer
from
.backend
import
TestBackend
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -45,157 +59,195 @@ FP8_DTYPE = current_platform.fp8_dtype()
...
@@ -45,157 +59,195 @@ FP8_DTYPE = current_platform.fp8_dtype()
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_OP
=
torch
.
ops
.
_C
.
rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
RMS_ADD_OP
=
torch
.
ops
.
_C
.
fused_add_rms_norm
.
default
# Kernel and group_shape combinations: (kernel, group_shape)
# CUDA kernels
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# FlashInferFP8ScaledMMLinearKernel supports both per-tensor only
(
FlashInferFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# CutlassFP8ScaledMMLinearKernel supports both per-tensor and per-token
(
CutlassFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
(
CutlassFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# PerTensorTorchFP8ScaledMMLinearKernel only supports per-tensor
(
PerTensorTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# Blockwise group shapes (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
)),
(
None
,
GroupShape
(
1
,
64
)),
]
# ROCm kernels
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# ROCmFP8ScaledMMLinearKernel supports per-tensor only
(
ROCmFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
),
# RowWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
),
# Blockwise group shapes (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
)),
(
None
,
GroupShape
(
1
,
64
)),
]
KERNEL_GROUPSHAPE_COMBINATIONS
=
(
CUDA_KERNEL_GROUPSHAPE_COMBINATIONS
if
current_platform
.
is_cuda
()
else
ROCM_KERNEL_GROUPSHAPE_COMBINATIONS
)
# For Aiter tests we toggle use_aiter_quant_op
AITER_KERNEL_GROUPSHAPE_COMBINATIONS
=
[
# Per-token with ROCmFP8ScaledMMLinearKernel
(
ROCmFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TENSOR
,
False
),
# Per-token with RowWiseTorchFP8ScaledMMLinearKernel
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
True
),
(
RowWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
False
),
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
True
),
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
GroupShape
.
PER_TOKEN
,
False
),
# Blockwise (no kernel abstraction)
(
None
,
GroupShape
(
1
,
128
),
True
),
]
class
TestModel
(
torch
.
nn
.
Module
):
class
TestModel
(
torch
.
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
eps
:
float
,
eps
:
float
,
force_kernel
:
FP8ScaledMMLinearKernel
|
None
,
group_shape
:
GroupShape
,
group_shape
:
GroupShape
,
use_aiter
:
bool
=
False
,
use_aiter_fusion
:
bool
=
False
,
cuda_force_torch
:
bool
=
False
,
use_aiter_quant
:
bool
=
False
,
use_aiter_quant_op
:
bool
=
True
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
):
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
use_aiter
=
use_aiter
self
.
fp8_linear_layers
:
list
[
torch
.
nn
.
Module
]
self
.
use_aiter_quant_op
=
use_aiter_quant_op
self
.
cuda_force_torch
=
cuda_force_torch
self
.
group_shape
=
group_shape
self
.
group_shape
=
group_shape
self
.
enable_quant_fp8_custom_op
=
None
# Will be set later if applicable
self
.
use_aiter_quant_op
=
use_aiter_quant
self
.
use_aiter_fusion
=
use_aiter_fusion
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
norm
=
[
RMSNorm
(
hidden_size
,
eps
)
for
_
in
range
(
4
)]
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
# Setup quantization scale descriptor
# Determine if blockwise based on group_shape
static
=
group_shape
==
GroupShape
.
PER_TENSOR
and
not
use_aiter
is_blockwise
=
group_shape
.
is_per_group
()
quant_scale
=
ScaleDesc
(
torch
.
float32
,
static
,
group_shape
)
self
.
quant_key
=
QuantKey
(
dtype
=
FP8_DTYPE
,
scale
=
quant_scale
,
symmetric
=
True
)
# Setup scales
if
is_blockwise
:
if
static
:
act_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
False
,
group_shape
)
self
.
scale
=
[
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
self
.
activation_quant_key
=
QuantKey
(
else
:
dtype
=
FP8_DTYPE
,
scale
=
act_quant_scale_desc
,
symmetric
=
True
self
.
scale
=
[
None
for
_
in
range
(
3
)]
)
self
.
fp8_linear_layers
=
[
TestBlockFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
group_shape
=
group_shape
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
(),
use_aiter_and_is_supported
=
use_aiter_quant
,
transpose_weights
=
use_aiter_fusion
,
)
for
_
in
range
(
3
)
]
# Setup weights
self
.
enable_quant_fp8_custom_op
=
(
self
.
w
=
[
False
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
)
for
_
in
range
(
3
)
if
use_aiter_quant
]
else
self
.
fp8_linear_layers
[
0
].
linear_op
.
input_quant_op
.
enabled
()
if
not
group_shape
.
is_per_group
()
or
use_aiter
:
self
.
w
=
[
self
.
w
[
0
].
t
()
for
_
in
range
(
3
)]
# Setup weight scales
if
group_shape
.
is_per_group
():
scale_size
=
(
(
hidden_size
+
128
-
1
)
//
128
if
use_aiter
else
hidden_size
//
group_shape
[
1
]
)
)
wscale_shape
:
tuple
[
int
,
...]
=
(
scale_size
,
scale_size
)
else
:
else
:
wscale_shape
=
(
1
,)
is_static
=
group_shape
==
GroupShape
.
PER_TENSOR
self
.
wscale
=
[
torch
.
rand
(
wscale_shape
,
dtype
=
torch
.
float32
)
for
_
in
range
(
3
)]
act_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
is_static
,
group_shape
)
w_quant_scale_desc
=
ScaleDesc
(
torch
.
float32
,
True
,
group_shape
)
# Setup FP8 linear operation
self
.
activation_quant_key
=
QuantKey
(
is_per_group
=
group_shape
.
is_per_group
()
dtype
=
FP8_DTYPE
,
scale
=
act_quant_scale_desc
,
symmetric
=
True
if
is_per_group
and
use_aiter
:
self
.
fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
128
,
128
),
act_quant_group_shape
=
group_shape
,
use_aiter_and_is_supported
=
use_aiter_quant_op
,
)
# AITER blockwise doesn't use enable_quant_fp8_custom_op
elif
is_per_group
:
self
.
fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
group_shape
[
1
],
group_shape
[
1
]),
act_quant_group_shape
=
group_shape
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
(),
use_aiter_and_is_supported
=
False
,
)
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
input_quant_op
.
enabled
()
self
.
weight_quant_key
=
QuantKey
(
elif
use_aiter
:
dtype
=
FP8_DTYPE
,
scale
=
w_quant_scale_desc
,
symmetric
=
True
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
False
,
act_quant_group_shape
=
group_shape
,
)
)
self
.
fp8_linear
.
quant_fp8
.
use_aiter
=
use_aiter_quant_op
self
.
fp8_linear_layers
=
[
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
TestFP8Layer
(
else
:
weight_shape
=
(
hidden_size
,
hidden_size
),
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
activation_quant_key
=
self
.
activation_quant_key
,
self
.
fp8_linear
=
Fp8LinearOp
(
weight_quant_key
=
self
.
weight_quant_key
,
act_quant_static
=
static
,
force_kernel
=
force_kernel
,
act_quant_group_shape
=
group_shape
,
)
)
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
for
_
in
range
(
3
)
]
self
.
enable_rms_norm_custom_op
=
self
.
norm
[
0
].
enabled
()
# Enable aiter quantization if requested
for
layer
in
self
.
fp8_linear_layers
:
layer
.
kernel
.
quant_fp8
.
use_aiter
=
use_aiter_quant
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear_layers
[
0
].
is_quant_fp8_enabled
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
# avoid having graph input be an arg to a pattern directly
# avoid having graph input be an arg to a pattern directly
x
=
resid
=
torch
.
relu
(
x
)
x
=
resid
=
torch
.
relu
(
x
)
y
=
self
.
norm
[
0
](
x
)
y
=
self
.
norm
[
0
](
x
)
x2
=
self
.
fp8_linear
.
apply
(
x2
=
self
.
fp8_linear_layers
[
0
](
y
)
y
,
self
.
w
[
0
],
self
.
wscale
[
0
],
input_scale
=
self
.
scale
[
0
]
)
# make sure resid is used for replacement to work
# make sure resid is used for replacement to work
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
y2
,
resid
=
self
.
norm
[
1
](
x2
,
resid
)
x3
=
self
.
fp8_linear
.
apply
(
x3
=
self
.
fp8_linear_layers
[
1
](
y2
)
y2
,
self
.
w
[
1
],
self
.
wscale
[
1
],
input_scale
=
self
.
scale
[
1
]
)
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
y3
,
resid
=
self
.
norm
[
2
](
x3
,
resid
)
# use resid here
x4
=
self
.
fp8_linear
.
apply
(
x4
=
self
.
fp8_linear_layers
[
2
](
y3
)
y3
,
self
.
w
[
2
],
self
.
wscale
[
2
],
input_scale
=
self
.
scale
[
2
]
)
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
y4
,
resid
=
self
.
norm
[
3
](
x4
,
resid
)
# use resid here
return
y4
return
y4
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
if
(
if
self
.
group_shape
.
is_per_group
():
self
.
use_aiter
# Blockwise path
and
self
.
group_shape
.
is_per_group
()
if
self
.
use_aiter_fusion
and
self
.
use_aiter_quant_op
:
and
current_platform
.
is_fp8_fnuz
()
return
[
rocm_aiter_ops
.
get_group_quant_op
()]
):
if
self
.
use_aiter_fusion
:
return
[
rocm_aiter_ops
.
get_group_quant_op
()]
return
[
torch
.
ops
.
vllm
.
triton_per_token_group_quant_fp8
.
default
]
if
self
.
use_aiter
and
self
.
group_shape
.
is_per_group
():
else
:
return
[
torch
.
ops
.
vllm
.
triton_per_token_group_quant_fp8
.
default
]
if
self
.
use_aiter_quant_op
:
if
self
.
use_aiter
and
self
.
use_aiter_quant_op
:
return
[
rocm_aiter_ops
.
get_per_token_quant_op
()]
return
[
rocm_aiter_ops
.
get_per_token_quant_op
()]
if
self
.
use_aiter
:
# Common path
return
[
QUANT_OPS
[
self
.
quant_key
]]
return
(
if
self
.
enable_quant_fp8_custom_op
:
[
QUANT_OPS
[
self
.
activation_quant_key
]]
return
[
QUANT_OPS
[
self
.
quant_key
]]
if
self
.
enable_quant_fp8_custom_op
return
[
torch
.
ops
.
aten
.
reciprocal
]
else
[
torch
.
ops
.
aten
.
reciprocal
]
)
def
ops_in_model_after
(
self
):
def
ops_in_model_after
(
self
):
if
self
.
use_aiter
and
self
.
group_shape
.
is_per_group
():
if
self
.
use_aiter_fusion
:
from
vllm.compilation.rocm_aiter_fusion
import
(
if
self
.
group_shape
.
is_per_group
():
AiterFusedAddRMSFp8GroupQuantPattern
,
# Blockwise aiter fusion
AiterRMSFp8GroupQuantPattern
,
from
vllm.compilation.rocm_aiter_fusion
import
(
)
AiterFusedAddRMSFp8GroupQuantPattern
,
AiterRMSFp8GroupQuantPattern
,
)
return
[
return
[
AiterFusedAddRMSFp8GroupQuantPattern
.
FUSED_OP
,
AiterFusedAddRMSFp8GroupQuantPattern
.
FUSED_OP
,
AiterRMSFp8GroupQuantPattern
.
FUSED_OP
,
AiterRMSFp8GroupQuantPattern
.
FUSED_OP
,
]
]
if
self
.
use_aiter
:
else
:
from
vllm.compilation.rocm_aiter_fusion
import
(
# Per-token aiter fusion
AiterFusedAddRMSNormDynamicQuantPattern
,
from
vllm.compilation.rocm_aiter_fusion
import
(
AiterRMSNormDynamicQuantPattern
,
AiterFusedAddRMSNormDynamicQuantPattern
,
)
AiterRMSNormDynamicQuantPattern
,
)
return
[
return
[
AiterFusedAddRMSNormDynamicQuantPattern
.
FUSED_OP
,
AiterFusedAddRMSNormDynamicQuantPattern
.
FUSED_OP
,
AiterRMSNormDynamicQuantPattern
.
FUSED_OP
,
AiterRMSNormDynamicQuantPattern
.
FUSED_OP
,
]
]
# Regular fusion
return
[
return
[
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_key
,
True
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
activation_
quant_key
,
True
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
quant_key
,
False
)],
FUSED_OPS
[
FusedRMSQuantKey
(
self
.
activation_
quant_key
,
False
)],
]
]
def
ops_in_model_before_partial
(
self
):
def
ops_in_model_before_partial
(
self
):
...
@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module):
...
@@ -206,14 +258,6 @@ class TestModel(torch.nn.Module):
)
)
GROUP_SHAPES
=
[
GroupShape
.
PER_TOKEN
,
GroupShape
.
PER_TENSOR
,
GroupShape
(
1
,
128
),
GroupShape
(
1
,
64
),
]
def
_run_fusion_test
(
def
_run_fusion_test
(
model
,
model
,
fusion_pass
,
fusion_pass
,
...
@@ -259,14 +303,9 @@ def _run_fusion_test(
...
@@ -259,14 +303,9 @@ def _run_fusion_test(
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"group
_
shape"
,
GROUP
_
SHAPES
)
@
pytest
.
mark
.
parametrize
(
"
kernel_
groupshape"
,
KERNEL_
GROUPSHAPE
_COMBINATION
S
)
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_rms_norm_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_quant_fp8_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_quant_fp8_custom_op"
,
[
True
,
False
])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
parametrize
(
"cuda_force_torch"
,
[
True
,
False
]
if
cutlass_fp8_supported
()
else
[
True
]
)
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test on CUDA and ROCm"
not
current_platform
.
is_cuda_alike
(),
reason
=
"Only test on CUDA and ROCm"
)
)
...
@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant(
...
@@ -275,11 +314,12 @@ def test_fusion_rmsnorm_quant(
hidden_size
,
hidden_size
,
num_tokens
,
num_tokens
,
eps
,
eps
,
group
_
shape
,
kernel_
groupshape
,
enable_rms_norm_custom_op
,
enable_rms_norm_custom_op
,
enable_quant_fp8_custom_op
,
enable_quant_fp8_custom_op
,
cuda_force_torch
,
):
):
force_kernel
,
group_shape
=
kernel_groupshape
if
not
enable_quant_fp8_custom_op
and
group_shape
.
is_per_group
():
if
not
enable_quant_fp8_custom_op
and
group_shape
.
is_per_group
():
pytest
.
skip
(
"Unsupported unwrapped quant fp8 op for blockwise quantization"
)
pytest
.
skip
(
"Unsupported unwrapped quant fp8 op for blockwise quantization"
)
...
@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant(
...
@@ -310,15 +350,16 @@ def test_fusion_rmsnorm_quant(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
fusion_pass
=
RMSNormQuantFusionPass
(
vllm_config
)
model
=
TestModel
(
model
=
TestModel
(
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
eps
=
eps
,
eps
=
eps
,
force_kernel
=
force_kernel
,
group_shape
=
group_shape
,
group_shape
=
group_shape
,
use_aiter
=
False
,
use_aiter
_fusion
=
False
,
cuda_force_torch
=
cuda_force_torch
,
use_aiter_quant
=
False
,
)
)
backend
,
_
=
_run_fusion_test
(
backend
,
_
=
_run_fusion_test
(
...
@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant(
...
@@ -339,19 +380,12 @@ def test_fusion_rmsnorm_quant(
assert
n_add_nodes
(
backend
.
graph_post_pass
)
==
2
assert
n_add_nodes
(
backend
.
graph_post_pass
)
==
2
GROUP_SHAPE_QUANT_OPS_MATCHS
=
[
(
GroupShape
.
PER_TOKEN
,
True
),
(
GroupShape
.
PER_TOKEN
,
False
),
(
GroupShape
(
1
,
128
),
True
),
]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
256
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
257
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
"eps"
,
[
1e-5
,
1e-6
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"group
_
shape
, use_aiter_quant_op"
,
GROUP
_
SHAPE_
QUANT_OPS_MATCH
S
"
kernel_
groupshape
_quant"
,
AITER_KERNEL_
GROUPSHAPE_
COMBINATION
S
)
)
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
(
not
current_platform
.
is_rocm
()
or
not
IS_AITER_FOUND
),
(
not
current_platform
.
is_rocm
()
or
not
IS_AITER_FOUND
),
...
@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant(
...
@@ -362,10 +396,10 @@ def test_aiter_fusion_rmsnorm_quant(
hidden_size
:
int
,
hidden_size
:
int
,
num_tokens
:
int
,
num_tokens
:
int
,
eps
:
float
,
eps
:
float
,
group_shape
:
GroupShape
,
kernel_groupshape_quant
:
tuple
,
use_aiter_quant_op
:
bool
,
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
):
):
force_kernel
,
group_shape
,
use_aiter_quant_op
=
kernel_groupshape_quant
vllm_config
=
VllmConfig
(
vllm_config
=
VllmConfig
(
model_config
=
ModelConfig
(
dtype
=
dtype
),
model_config
=
ModelConfig
(
dtype
=
dtype
),
compilation_config
=
CompilationConfig
(
compilation_config
=
CompilationConfig
(
...
@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant(
...
@@ -379,20 +413,22 @@ def test_aiter_fusion_rmsnorm_quant(
from
vllm.compilation.rocm_aiter_fusion
import
RocmAiterRMSNormFusionPass
from
vllm.compilation.rocm_aiter_fusion
import
RocmAiterRMSNormFusionPass
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
m
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
rocm_aiter_ops
.
refresh_env_variables
()
rocm_aiter_ops
.
refresh_env_variables
()
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
torch
.
manual_seed
(
1
)
torch
.
manual_seed
(
1
)
maybe_create_device_identity
()
fusion_pass
=
RocmAiterRMSNormFusionPass
(
vllm_config
)
fusion_pass
=
RocmAiterRMSNormFusionPass
(
vllm_config
)
model
=
TestModel
(
model
=
TestModel
(
hidden_size
=
hidden_size
,
hidden_size
=
hidden_size
,
eps
=
eps
,
eps
=
eps
,
force_kernel
=
force_kernel
,
group_shape
=
group_shape
,
group_shape
=
group_shape
,
use_aiter
=
True
,
use_aiter
_fusion
=
True
,
# Always use aiter fusion ops in aiter test
use_aiter_quant
_op
=
use_aiter_quant_op
,
use_aiter_quant
=
use_aiter_quant_op
,
# Toggle aiter quantization
)
)
_run_fusion_test
(
_run_fusion_test
(
...
...
tests/compile/test_fusion_attn.py
View file @
148117ea
...
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -45,7 +45,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kNvfp4Quant
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
Fp8LinearOp
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.flashinfer
import
has_flashinfer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata
...
@@ -53,6 +52,8 @@ from vllm.v1.attention.backend import AttentionMetadata
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
..utils
import
TestFP8Layer
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP4_DTYPE
=
torch
.
uint8
FP4_DTYPE
=
torch
.
uint8
...
@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
...
@@ -185,32 +186,30 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
self
.
quant_key
.
scale
.
static
,
act_quant_group_shape
=
self
.
quant_key
.
scale
.
group_shape
,
)
hidden_size
=
self
.
num_qo_heads
*
self
.
head_size
hidden_size
=
self
.
num_qo_heads
*
self
.
head_size
self
.
w
=
kwargs
.
get
(
self
.
fp8_linear
=
TestFP8Layer
(
"w"
,
weight_shape
=
(
hidden_size
,
hidden_size
),
{
activation_quant_key
=
self
.
quant_key
,
"weight"
:
torch
.
randn
(
hidden_size
,
hidden_size
)
weight_quant_key
=
self
.
quant_key
,
.
to
(
dtype
=
FP8_DTYPE
,
device
=
self
.
device
)
device
=
self
.
device
,
.
t
(),
"wscale"
:
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
"scale"
:
torch
.
tensor
([
1.0
],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
},
)
)
w
=
kwargs
.
get
(
"w"
)
if
w
is
not
None
:
self
.
fp8_linear
.
weight
=
w
[
"weight"
]
self
.
fp8_linear
.
weight_scale
=
w
[
"wscale"
]
self
.
fp8_linear
.
input_scale
=
w
[
"scale"
]
self
.
w
=
{
"weight"
:
self
.
fp8_linear
.
weight
,
"wscale"
:
self
.
fp8_linear
.
weight_scale
,
"scale"
:
self
.
fp8_linear
.
input_scale
,
}
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
):
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
):
"""Forward pass that creates the pattern to be fused."""
"""Forward pass that creates the pattern to be fused."""
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
(
attn_output
)
input
=
attn_output
,
weight
=
self
.
w
[
"weight"
],
weight_scale
=
self
.
w
[
"wscale"
],
input_scale
=
self
.
w
[
"scale"
],
)
class
TestAttentionNvfp4QuantPatternModel
(
AttentionQuantPatternModel
):
class
TestAttentionNvfp4QuantPatternModel
(
AttentionQuantPatternModel
):
...
...
tests/compile/test_silu_mul_quant_fusion.py
View file @
148117ea
...
@@ -25,19 +25,30 @@ from vllm.config import (
...
@@ -25,19 +25,30 @@ from vllm.config import (
set_current_vllm_config
,
set_current_vllm_config
,
)
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
PerTensorTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kNvfp4Quant
,
kNvfp4Quant
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
..utils
import
override_cutlass_fp8_supported
from
..utils
import
TestFP8Layer
from
.backend
import
TestBackend
from
.backend
import
TestBackend
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
...
@@ -49,25 +60,27 @@ def is_nvfp4_supported():
...
@@ -49,25 +60,27 @@ def is_nvfp4_supported():
class
TestSiluMulFp8QuantModel
(
torch
.
nn
.
Module
):
class
TestSiluMulFp8QuantModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
cuda_force_torch
:
bool
,
**
kwargs
):
quant_key
=
kFp8StaticTensorSym
def
__init__
(
self
,
hidden_size
:
int
,
force_kernel
:
FP8ScaledMMLinearKernel
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
silu_and_mul
=
SiluAndMul
()
self
.
wscale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
scale
=
torch
.
rand
(
1
,
dtype
=
torch
.
float32
)
self
.
w
=
torch
.
rand
(
hidden_size
,
hidden_size
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
fp8_linear
=
TestFP8Layer
(
weight_shape
=
(
hidden_size
,
hidden_size
),
activation_quant_key
=
self
.
quant_key
,
weight_quant_key
=
self
.
quant_key
,
force_kernel
=
force_kernel
,
)
with
override_cutlass_fp8_supported
(
not
cuda_force_torch
):
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
True
,
act_quant_group_shape
=
GroupShape
.
PER_TENSOR
,
)
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
self
.
enable_silu_mul_custom_op
=
self
.
silu_and_mul
.
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
quant_fp8
.
enabled
()
self
.
enable_quant_fp8_custom_op
=
self
.
fp8_linear
.
is_
quant_fp8
_
enabled
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
y
=
self
.
silu_and_mul
(
x
)
y
=
self
.
silu_and_mul
(
x
)
x2
=
self
.
fp8_linear
.
apply
(
y
,
self
.
w
,
self
.
wscale
,
input_scale
=
self
.
wscale
)
x2
=
self
.
fp8_linear
(
y
)
return
x2
return
x2
def
ops_in_model_before
(
self
):
def
ops_in_model_before
(
self
):
...
@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
...
@@ -161,20 +174,27 @@ class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
return
[
torch
.
ops
.
vllm
.
rocm_aiter_act_mul_and_fp8_group_quant
]
return
[
torch
.
ops
.
vllm
.
rocm_aiter_act_mul_and_fp8_group_quant
]
ROCM_KERNELS
=
[
ROCmFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
]
CUDA_KERNELS
=
[
FlashInferFP8ScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
]
TEST_KERNELS
=
ROCM_KERNELS
if
current_platform
.
is_rocm
()
else
CUDA_KERNELS
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"enable_silu_mul_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_silu_mul_custom_op"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model_class, enable_quant_fp8_custom_op,
cuda_
force_
torch
"
,
"model_class, enable_quant_fp8_custom_op, force_
kernel
"
,
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
[
True
,
False
]
))
list
(
itertools
.
product
([
TestSiluMulFp8QuantModel
],
[
True
,
False
],
TEST_KERNELS
))
+
[
+
[
(
TestSiluMulNvfp4QuantModel
,
False
,
Fals
e
),
(
TestSiluMulNvfp4QuantModel
,
False
,
Non
e
),
(
TestSiluMulGroupFp8QuantModel
,
False
,
Fals
e
),
(
TestSiluMulGroupFp8QuantModel
,
False
,
Non
e
),
],
],
)
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
,
"rocm"
],
reason
=
"Only test on CUDA and ROCm"
envs
.
VLLM_TARGET_DEVICE
not
in
[
"cuda"
,
"rocm"
],
reason
=
"Only test on CUDA and ROCm"
)
)
...
@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
...
@@ -189,7 +209,7 @@ def test_fusion_silu_and_mul_quant(
],
],
enable_silu_mul_custom_op
:
bool
,
enable_silu_mul_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
enable_quant_fp8_custom_op
:
bool
,
cuda_
force_
torch
:
bool
,
force_
kernel
:
FP8ScaledMMLinearKernel
|
None
,
):
):
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
():
if
model_class
is
TestSiluMulNvfp4QuantModel
and
not
is_nvfp4_supported
():
pytest
.
skip
(
"NVFP4 is not supported on this GPU."
)
pytest
.
skip
(
"NVFP4 is not supported on this GPU."
)
...
@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
...
@@ -198,7 +218,6 @@ def test_fusion_silu_and_mul_quant(
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
dtype
)
maybe_create_device_identity
()
x
=
torch
.
rand
(
num_tokens
,
hidden_size
*
2
)
x
=
torch
.
rand
(
num_tokens
,
hidden_size
*
2
)
...
@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
...
@@ -227,9 +246,7 @@ def test_fusion_silu_and_mul_quant(
passes
=
[
NoOpEliminationPass
(
config
),
*
fusion_passes
,
PostCleanupPass
(
config
)]
passes
=
[
NoOpEliminationPass
(
config
),
*
fusion_passes
,
PostCleanupPass
(
config
)]
backend
=
TestBackend
(
*
passes
)
backend
=
TestBackend
(
*
passes
)
model
=
model_class
(
model
=
model_class
(
hidden_size
=
hidden_size
,
force_kernel
=
force_kernel
,
x
=
x
)
hidden_size
=
hidden_size
,
cuda_force_torch
=
cuda_force_torch
,
x
=
x
)
# First dimension dynamic
# First dimension dynamic
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
x
,
0
)
...
...
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
View file @
148117ea
...
@@ -11,13 +11,13 @@ from abc import ABC
...
@@ -11,13 +11,13 @@ from abc import ABC
import
pytest
import
pytest
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
ScaledMMLinearLayerConfig
,
Int8
ScaledMMLinearLayerConfig
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
,
Aiter
Int8
ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
CPUScaledMMLinearKernel
,
CPU
Int8
ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
ScaledMMLinearKernel
,
ScaledMMLinearKernel
,
...
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
...
@@ -33,36 +33,38 @@ def test_is_supported_is_abstract():
def
test_cpu_kernel_implements_is_supported
():
def
test_cpu_kernel_implements_is_supported
():
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
"""Test that CPU
Int8
ScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
CPUScaledMMLinearKernel
,
"is_supported"
),
(
assert
hasattr
(
CPU
Int8
ScaledMMLinearKernel
,
"is_supported"
),
(
"CPUScaledMMLinearKernel missing is_supported() method"
"CPU
Int8
ScaledMMLinearKernel missing is_supported() method"
)
)
# Verify it's a classmethod by checking if it can be called with the class
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
# and by checking the method type
assert
inspect
.
ismethod
(
CPUScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
assert
inspect
.
ismethod
(
CPUScaledMMLinearKernel
.
is_supported
CPUInt8ScaledMMLinearKernel
.
is_supported
),
"CPUScaledMMLinearKernel.is_supported() should be a classmethod"
)
or
inspect
.
isfunction
(
CPUInt8ScaledMMLinearKernel
.
is_supported
),
(
"CPUInt8ScaledMMLinearKernel.is_supported() should be a classmethod"
)
# Verify it can be called as a classmethod
# Verify it can be called as a classmethod
result
,
reason
=
CPUScaledMMLinearKernel
.
is_supported
()
result
,
reason
=
CPU
Int8
ScaledMMLinearKernel
.
is_supported
()
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
def
test_aiter_kernel_implements_is_supported
():
def
test_aiter_kernel_implements_is_supported
():
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
"""Test that Aiter
Int8
ScaledMMLinearKernel implements is_supported() method."""
assert
hasattr
(
AiterScaledMMLinearKernel
,
"is_supported"
),
(
assert
hasattr
(
Aiter
Int8
ScaledMMLinearKernel
,
"is_supported"
),
(
"AiterScaledMMLinearKernel missing is_supported() method"
"Aiter
Int8
ScaledMMLinearKernel missing is_supported() method"
)
)
# Verify it's a classmethod by checking if it can be called with the class
# Verify it's a classmethod by checking if it can be called with the class
# and by checking the method type
# and by checking the method type
assert
inspect
.
ismethod
(
assert
inspect
.
ismethod
(
AiterScaledMMLinearKernel
.
is_supported
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
)
or
inspect
.
isfunction
(
AiterScaledMMLinearKernel
.
is_supported
),
(
)
or
inspect
.
isfunction
(
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
),
(
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
"Aiter
Int8
ScaledMMLinearKernel.is_supported() should be a classmethod"
)
)
# Verify it can be called as a classmethod
# Verify it can be called as a classmethod
# (will return False on CPU, which is expected)
# (will return False on CPU, which is expected)
result
,
reason
=
AiterScaledMMLinearKernel
.
is_supported
()
result
,
reason
=
Aiter
Int8
ScaledMMLinearKernel
.
is_supported
()
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
isinstance
(
result
,
bool
),
"is_supported() should return a bool"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
assert
reason
is
None
or
isinstance
(
reason
,
str
),
"reason should be str or None"
# On CPU, it should return False with a reason about requiring ROCm
# On CPU, it should return False with a reason about requiring ROCm
...
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
...
@@ -70,14 +72,14 @@ def test_aiter_kernel_implements_is_supported():
def
test_cpu_kernel_accepts_all_configs
():
def
test_cpu_kernel_accepts_all_configs
():
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
"""Test that CPU
Int8
ScaledMMLinearKernel accepts all config combinations."""
configs
=
[
configs
=
[
ScaledMMLinearLayerConfig
(
Int8
ScaledMMLinearLayerConfig
(
is_channelwise
=
False
,
is_channelwise
=
False
,
is_static_input_scheme
=
True
,
is_static_input_scheme
=
True
,
input_symmetric
=
True
,
input_symmetric
=
True
,
),
),
ScaledMMLinearLayerConfig
(
Int8
ScaledMMLinearLayerConfig
(
is_channelwise
=
True
,
is_channelwise
=
True
,
is_static_input_scheme
=
False
,
is_static_input_scheme
=
False
,
input_symmetric
=
False
,
input_symmetric
=
False
,
...
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
...
@@ -85,7 +87,7 @@ def test_cpu_kernel_accepts_all_configs():
]
]
for
config
in
configs
:
for
config
in
configs
:
can_impl
,
reason
=
CPUScaledMMLinearKernel
.
can_implement
(
config
)
can_impl
,
reason
=
CPU
Int8
ScaledMMLinearKernel
.
can_implement
(
config
)
assert
can_impl
,
(
assert
can_impl
,
(
f
"CPUScaledMMLinearKernel should accept config
{
config
}
:
{
reason
}
"
f
"CPU
Int8
ScaledMMLinearKernel should accept config
{
config
}
:
{
reason
}
"
)
)
tests/quantization/test_compressed_tensors.py
View file @
148117ea
...
@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [
...
@@ -41,7 +41,7 @@ ROCM_AITER_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
]
]
# TritonScaledMMLinearKernel only supports symmetric quantization.
# Triton
Int8
ScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
=
[
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
=
[
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
...
...
tests/utils.py
View file @
148117ea
...
@@ -42,6 +42,17 @@ from vllm.distributed import (
...
@@ -42,6 +42,17 @@ from vllm.distributed import (
)
)
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.entrypoints.cli.serve
import
ServeSubcommand
from
vllm.entrypoints.cli.serve
import
ServeSubcommand
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
W8A8BlockFp8LinearOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
QuantKey
,
)
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.tokenizers
import
get_tokenizer
from
vllm.tokenizers
import
get_tokenizer
...
@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
...
@@ -50,6 +61,8 @@ from vllm.utils.mem_constants import GB_bytes
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.network_utils
import
get_open_port
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
from
vllm.utils.torch_utils
import
cuda_device_count_stateless
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
from
amdsmi
import
(
from
amdsmi
import
(
amdsmi_get_gpu_vram_usage
,
amdsmi_get_gpu_vram_usage
,
...
@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
...
@@ -1332,3 +1345,117 @@ def flat_product(*iterables: Iterable[Any]):
for
element
in
itertools
.
product
(
*
iterables
):
for
element
in
itertools
.
product
(
*
iterables
):
normalized
=
(
e
if
isinstance
(
e
,
tuple
)
else
(
e
,)
for
e
in
element
)
normalized
=
(
e
if
isinstance
(
e
,
tuple
)
else
(
e
,)
for
e
in
element
)
yield
tuple
(
itertools
.
chain
(
*
normalized
))
yield
tuple
(
itertools
.
chain
(
*
normalized
))
class
TestFP8Layer
(
torch
.
nn
.
Module
):
"""
Test helper for FP8 linear operations. Creates random weights and scales
based on quantization configuration.
Args:
weight_shape: Shape of the weight tensor (out_features, in_features).
activation_quant_key: Activation quantization configuration.
weight_quant_key: Weight quantization configuration.
out_dtype: Output dtype. Defaults to current default dtype.
force_kernel: Optional kernel to force use of specific implementation.
"""
def
__init__
(
self
,
weight_shape
:
tuple
[
int
,
int
],
activation_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
out_dtype
:
torch
.
dtype
|
None
=
None
,
device
:
torch
.
device
|
None
=
None
,
force_kernel
:
FP8ScaledMMLinearKernel
|
None
=
None
,
):
super
().
__init__
()
per_tensor_weights
=
weight_quant_key
.
scale
.
group_shape
.
is_per_tensor
()
is_static_activation_scale
=
activation_quant_key
.
scale
.
static
weight_scale_shape
=
(
1
,)
if
per_tensor_weights
else
(
weight_shape
[
0
],
1
)
self
.
weight_scale
=
torch
.
rand
(
weight_scale_shape
,
dtype
=
torch
.
float32
,
device
=
device
)
self
.
input_scale
=
(
torch
.
rand
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
if
is_static_activation_scale
else
None
)
self
.
weight
=
torch
.
rand
(
weight_shape
,
device
=
device
).
to
(
dtype
=
FP8_DTYPE
).
t
()
self
.
input_scale_ub
=
None
out_dtype
=
torch
.
get_default_dtype
()
if
out_dtype
is
None
else
out_dtype
self
.
kernel
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
weight_quant_key
,
out_dtype
=
out_dtype
,
force_kernel
=
force_kernel
,
)
def
is_quant_fp8_enabled
(
self
)
->
bool
:
return
self
.
kernel
.
quant_fp8
.
enabled
()
def
forward
(
self
,
y
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
return
self
.
kernel
.
apply_weights
(
self
,
y
,
bias
)
# TODO: Drop TestBlockFP8Layer in favour of a unified TestFP8Layer
# after refactoring W8A8BlockFp8LinearOp.
# https://github.com/vllm-project/vllm/issues/31818
class
TestBlockFP8Layer
:
"""
Test helper for blockwise FP8 linear operations. Creates random weights
and scales for W8A8BlockFp8LinearOp.
This is a workaround until W8A8BlockFp8LinearOp implements the kernel
abstraction (ScaledMMLinearKernel) for blockwise quantization.
Args:
weight_shape: Shape of the weight tensor (out_features, in_features).
group_shape: Blockwise quantization group shape.
cutlass_block_fp8_supported: Whether CUTLASS blockwise FP8 is available.
use_aiter_and_is_supported: Whether to use aiter quantization ops.
transpose_weights: Whether to transpose weights after creation.
"""
def
__init__
(
self
,
weight_shape
:
tuple
[
int
,
int
],
group_shape
:
GroupShape
,
cutlass_block_fp8_supported
:
bool
=
False
,
use_aiter_and_is_supported
:
bool
=
False
,
transpose_weights
:
bool
=
False
,
):
weight_scale_shape
=
weight_shape
[
0
]
//
group_shape
[
1
]
self
.
weight_scale
=
torch
.
rand
(
(
weight_scale_shape
,
weight_scale_shape
),
dtype
=
torch
.
float32
)
self
.
weight
=
torch
.
rand
(
weight_shape
).
to
(
dtype
=
FP8_DTYPE
)
self
.
input_scale
=
None
if
transpose_weights
:
self
.
weight
=
self
.
weight
.
t
()
self
.
linear_op
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
group_shape
[
1
],
group_shape
[
1
]),
act_quant_group_shape
=
group_shape
,
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
use_aiter_and_is_supported
,
)
def
__call__
(
self
,
y
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
)
->
torch
.
Tensor
:
return
self
.
linear_op
.
apply
(
input
=
y
,
weight
=
self
.
weight
,
weight_scale
=
self
.
weight_scale
,
input_scale
=
self
.
input_scale
,
bias
=
bias
,
)
def
is_quant_fp8_enabled
(
self
)
->
bool
:
return
self
.
linear_op
.
input_quant_op
.
enabled
()
vllm/_aiter_ops.py
View file @
148117ea
...
@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl(
...
@@ -372,7 +372,7 @@ def _rocm_aiter_gemm_a8w8_impl(
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# a to be [M, K]
# b to be [N, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# Cutlass
Int8
ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
148117ea
...
@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
...
@@ -8,9 +8,13 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrate
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
W8A8BlockFp8LinearOp
,
W8A8BlockFp8LinearOp
,
create_fp8_input_scale
,
create_fp8_input_scale
,
...
@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -22,11 +26,14 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_weight_tensor_strategy
,
process_fp8_weight_tensor_strategy
,
validate_fp8_block_shape
,
validate_fp8_block_shape
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
maybe_create_device_identity
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
BlockQuantScaleParameter
,
BlockQuantScaleParameter
,
...
@@ -42,6 +49,18 @@ strategy_to_parameter_type = {
...
@@ -42,6 +49,18 @@ strategy_to_parameter_type = {
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
QuantizationStrategy
.
TENSOR
:
PerTensorScaleParameter
,
}
}
STATIC_QUANT
=
True
DYNAMIC_QUANT
=
False
activation_quant_key_mapping
=
{
STATIC_QUANT
:
kFp8StaticTensorSym
,
DYNAMIC_QUANT
:
kFp8DynamicTokenSym
,
}
weight_quant_key_mapping
=
{
QuantizationStrategy
.
CHANNEL
:
kFp8StaticTokenSym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
}
logger
=
init_logger
(
__name__
)
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
def
__init__
(
self
,
weight_quant
:
QuantizationArgs
,
is_static_input_scheme
:
bool
):
...
@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -49,22 +68,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
strategy
=
weight_quant
.
strategy
self
.
strategy
=
weight_quant
.
strategy
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
self
.
weight_block_size
=
self
.
weight_quant
.
block_structure
if
self
.
weight_block_size
is
not
None
:
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
else
:
self
.
act_q_group_shape
=
(
GroupShape
.
PER_TENSOR
if
is_static_input_scheme
else
GroupShape
.
PER_TOKEN
)
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
if
self
.
weight_block_size
is
not
None
:
if
self
.
weight_block_size
is
not
None
:
self
.
cutlass_block_fp8_supported
=
cutlass_block_fp8_supported
()
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
assert
not
self
.
is_static_input_scheme
assert
not
self
.
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
self
.
act_q_group_shape
,
act_quant_group_shape
=
self
.
act_q_group_shape
,
...
@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -72,9 +82,13 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
)
else
:
else
:
self
.
fp8_linear
=
Fp8LinearOp
(
activation_quant_key
=
activation_quant_key_mapping
[
is_static_input_scheme
]
act_quant_static
=
self
.
is_static_input_scheme
,
weight_quant_key
=
weight_quant_key_mapping
[
self
.
strategy
]
act_quant_group_shape
=
self
.
act_q_group_shape
,
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
weight_quant_key
,
out_dtype
=
self
.
out_dtype
,
module_name
=
self
.
__class__
.
__name__
,
)
)
@
classmethod
@
classmethod
...
@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -93,8 +107,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_loader
:
Callable
,
weight_loader
:
Callable
,
**
kwargs
,
**
kwargs
,
):
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
layer
.
weight_block_size
=
None
layer
.
weight_block_size
=
None
...
@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -143,7 +155,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
getattr
(
layer
,
"input_scale"
,
None
),
getattr
(
layer
,
"input_scale"
,
None
),
)
)
weight
=
weight
.
t
()
weight
=
weight
.
t
()
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
elif
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_channel_strategy
(
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
layer
.
weight
,
layer
.
weight_scale
,
getattr
(
layer
,
"input_scale"
,
None
)
...
@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -174,7 +185,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
else
:
layer
.
input_scale
=
None
layer
.
input_scale
=
None
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
maybe_post_process_fp8_weight_block
(
layer
)
maybe_post_process_fp8_weight_block
(
layer
)
...
@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -193,11 +203,4 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
bias
=
bias
,
bias
=
bias
,
)
)
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
148117ea
...
@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -11,8 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
ScaledMMLinearLayerConfig
,
init_int8_linear_kernel
,
choose_scaled_mm_linear_kernel
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BasevLLMParameter
,
...
@@ -25,8 +24,6 @@ logger = init_logger(__name__)
...
@@ -25,8 +24,6 @@ logger = init_logger(__name__)
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
class
CompressedTensorsW8A8Int8
(
CompressedTensorsScheme
):
_kernel_backends_being_used
:
set
[
str
]
=
set
()
def
__init__
(
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
):
):
...
@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -50,18 +47,13 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
):
):
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
s
caled_mm_linear_kernel_config
=
ScaledMMLinearLayerConfig
(
s
elf
.
kernel
=
init_int8_linear_kernel
(
is_channelwise
=
(
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
),
is_channelwise
=
(
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
),
is_static_input_scheme
=
self
.
is_static_input_scheme
,
is_static_input_scheme
=
self
.
is_static_input_scheme
,
input_symmetric
=
self
.
input_symmetric
,
input_symmetric
=
self
.
input_symmetric
,
module_name
=
self
.
__class__
.
__name__
,
)
)
kernel_type
=
choose_scaled_mm_linear_kernel
(
scaled_mm_linear_kernel_config
)
if
kernel_type
.
__name__
not
in
self
.
_kernel_backends_being_used
:
logger
.
info
(
"Using %s for CompressedTensorsW8A8Int8"
,
kernel_type
.
__name__
)
self
.
_kernel_backends_being_used
.
add
(
kernel_type
.
__name__
)
# WEIGHT
# WEIGHT
weight
=
ModelWeightParameter
(
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
data
=
torch
.
empty
(
...
@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -90,12 +82,12 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
# INPUT SCALE
input_zero_point
=
None
input_scale
=
None
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
input_scale
=
BasevLLMParameter
(
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
self
.
input_symmetric
:
if
not
self
.
input_symmetric
:
# Note: compressed-tensors stores the zp using the same dtype
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# as the weights
...
@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -103,16 +95,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_zero_point
=
BasevLLMParameter
(
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
self
.
kernel
=
kernel_type
(
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
c
=
scaled_mm_linear_kernel_config
,
if
not
hasattr
(
layer
,
"azp_adj"
):
w_q_param_name
=
"weight"
,
layer
.
register_parameter
(
"azp_adj"
,
None
)
w_s_param_name
=
"weight_scale"
,
i_s_param_name
=
"input_scale"
,
i_zp_param_name
=
"input_zero_point"
,
azp_adj_param_name
=
"azp_adj"
,
)
# Checkpoints are serialized in compressed-tensors format, which is
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
# different from the format the kernel may want. Handle repacking here.
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
148117ea
...
@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -18,17 +18,19 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_fp8_layer_for_marlin
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
is_layer_skipped
,
is_layer_skipped
,
kFp8DynamicTokenSym
,
kFp8StaticTokenSym
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
...
@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -91,10 +93,13 @@ class FBGEMMFp8Config(QuantizationConfig):
class
FBGEMMFp8LinearMethod
(
LinearMethodBase
):
class
FBGEMMFp8LinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
def
__init__
(
self
,
quant_config
:
FBGEMMFp8Config
):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
fp8_linear
=
Fp8LinearOp
(
act_quant_static
=
False
,
act_quant_group_shape
=
GroupShape
.
PER_TOKEN
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
out_dtype
=
torch
.
get_default_dtype
()
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
kFp8DynamicTokenSym
,
weight_quant_key
=
kFp8StaticTokenSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -106,7 +111,6 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
maybe_create_device_identity
()
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
del
input_size
,
output_size
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -184,12 +188,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias
=
bias
,
bias
=
bias
,
)
)
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
None
,
input_scale_ub
=
layer
.
input_scale_ub
,
bias
=
bias
,
)
vllm/model_executor/layers/quantization/fp8.py
View file @
148117ea
...
@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -48,6 +48,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm
import
(
init_fp8_linear_kernel
,
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
apply_fi_trtllm_fp8_per_tensor_moe
,
apply_fi_trtllm_fp8_per_tensor_moe
,
...
@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...
@@ -76,12 +79,13 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
is_layer_skipped
,
is_layer_skipped
,
kFp8DynamicTensorSym
,
kFp8DynamicTokenSym
,
kFp8StaticTensorSym
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
cutlass_fp8_supported
,
cutlass_fp8_supported
,
maybe_create_device_identity
,
normalize_e4m3fn_to_e4m3fnuz
,
normalize_e4m3fn_to_e4m3fnuz
,
)
)
from
vllm.model_executor.parameter
import
(
from
vllm.model_executor.parameter
import
(
...
@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -328,28 +332,30 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
weight_block_size
=
self
.
quant_config
.
weight_block_size
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
if
self
.
weight_block_size
:
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
else
:
# Use per-token quantization for better perf if dynamic and cutlass
if
not
self
.
act_q_static
and
cutlass_fp8_supported
():
self
.
act_q_group_shape
=
GroupShape
.
PER_TOKEN
else
:
self
.
act_q_group_shape
=
GroupShape
.
PER_TENSOR
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
not
self
.
act_q_static
assert
not
self
.
act_q_static
assert
self
.
weight_block_size
is
not
None
assert
self
.
weight_block_size
is
not
None
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
self
.
act_q_group_shape
,
act_quant_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
)
else
:
else
:
self
.
fp8_linear
=
Fp8LinearOp
(
# Use per-token quantization for better perf if dynamic and cutlass
act_quant_static
=
self
.
act_q_static
,
if
self
.
act_q_static
:
act_quant_group_shape
=
self
.
act_q_group_shape
,
activation_quant_key
=
kFp8StaticTensorSym
elif
cutlass_fp8_supported
():
activation_quant_key
=
kFp8DynamicTokenSym
else
:
activation_quant_key
=
kFp8DynamicTensorSym
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
kFp8StaticTensorSym
,
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
)
)
def
create_weights
(
def
create_weights
(
...
@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -362,8 +368,6 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
maybe_create_device_identity
()
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
...
@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -462,8 +466,6 @@ class Fp8LinearMethod(LinearMethodBase):
scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
scale
=
create_fp8_input_scale
(
output_partition_sizes
,
weight_loader
)
set_weight_attrs
(
scale
,
{
"scale_type"
:
"input_scale"
})
set_weight_attrs
(
scale
,
{
"scale_type"
:
"input_scale"
})
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
if
getattr
(
layer
,
"_already_called_process_weights_after_loading"
,
False
):
...
@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -602,14 +604,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias
=
bias
,
bias
=
bias
,
)
)
return
self
.
fp8_linear
.
apply
(
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
out_dtype
=
self
.
out_dtype
,
input_scale
=
layer
.
input_scale
,
bias
=
bias
,
)
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
class
Fp8MoEMethod
(
FusedMoEMethodBase
):
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py
View file @
148117ea
...
@@ -2,19 +2,58 @@
...
@@ -2,19 +2,58 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Generic
,
TypeVar
import
torch
import
torch
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
)
from
vllm.platforms
import
current_platform
@
dataclass
@
dataclass
class
ScaledMMLinearLayerConfig
:
class
ScaledMMLinearLayerConfig
:
is_channelwise
:
bool
pass
@
dataclass
class
Int8ScaledMMLinearLayerConfig
(
ScaledMMLinearLayerConfig
):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme
:
bool
is_static_input_scheme
:
bool
is_channelwise
:
bool
input_symmetric
:
bool
input_symmetric
:
bool
class
ScaledMMLinearKernel
(
ABC
):
@
dataclass
class
FP8ScaledMMLinearLayerConfig
(
ScaledMMLinearLayerConfig
):
weight_quant_key
:
QuantKey
activation_quant_key
:
QuantKey
out_dtype
:
torch
.
dtype
|
None
_FP8ParamsT
=
tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_scale_ub,
]
_Int8ParamsT
=
tuple
[
torch
.
Tensor
,
# weight
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_zp
torch
.
Tensor
|
None
,
# azp_adj
]
_ParamsT
=
TypeVar
(
"_ParamsT"
,
_Int8ParamsT
,
_FP8ParamsT
)
_ConfigT
=
TypeVar
(
"_ConfigT"
,
bound
=
ScaledMMLinearLayerConfig
)
class
ScaledMMLinearKernel
(
Generic
[
_ConfigT
,
_ParamsT
],
ABC
):
@
classmethod
@
classmethod
@
abstractmethod
@
abstractmethod
def
is_supported
(
def
is_supported
(
...
@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC):
...
@@ -24,26 +63,14 @@ class ScaledMMLinearKernel(ABC):
@
classmethod
@
classmethod
@
abstractmethod
@
abstractmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayer
Config
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
_
Config
T
)
->
tuple
[
bool
,
str
|
None
]:
raise
NotImplementedError
raise
NotImplementedError
def
__init__
(
def
__init__
(
self
,
c
:
_ConfigT
,
layer_param_names
:
Sequence
[
str
])
->
None
:
self
,
assert
self
.
can_implement
(
c
)[
0
]
c
:
ScaledMMLinearLayerConfig
,
assert
self
.
is_supported
()[
0
]
w_q_param_name
:
str
,
w_s_param_name
:
str
,
i_s_param_name
:
str
,
i_zp_param_name
:
str
,
azp_adj_param_name
:
str
,
)
->
None
:
assert
self
.
can_implement
(
c
)
assert
self
.
is_supported
()
self
.
config
=
c
self
.
config
=
c
self
.
w_q_name
=
w_q_param_name
self
.
layer_param_names
=
layer_param_names
self
.
w_s_name
=
w_s_param_name
self
.
i_s_name
=
i_s_param_name
self
.
i_zp_name
=
i_zp_param_name
self
.
azp_adj_name
=
azp_adj_param_name
@
abstractmethod
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
...
@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC):
...
@@ -58,19 +85,103 @@ class ScaledMMLinearKernel(ABC):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
def
_get_weight_params
(
# return a covariant type in the subclass
self
,
layer
:
torch
.
nn
.
Module
@
abstractmethod
)
->
tuple
[
def
_get_layer_params
(
self
,
layer
)
->
_ParamsT
:
torch
.
Tensor
,
# weight
raise
NotImplementedError
torch
.
Tensor
,
# weight_scale
torch
.
Tensor
|
None
,
# input_scale,
torch
.
Tensor
|
None
,
# input_zp
class
FP8ScaledMMLinearKernel
(
torch
.
Tensor
|
None
,
# azp_adj
ScaledMMLinearKernel
[
FP8ScaledMMLinearLayerConfig
,
_FP8ParamsT
],
ABC
]:
):
def
__init__
(
self
,
c
:
FP8ScaledMMLinearLayerConfig
,
layer_param_names
:
Sequence
[
str
]
)
->
None
:
act_scale_descriptor
=
c
.
activation_quant_key
.
scale
self
.
quant_fp8
=
QuantFP8
(
static
=
act_scale_descriptor
.
static
,
group_shape
=
act_scale_descriptor
.
group_shape
,
num_token_padding
=
self
.
get_output_padding
(),
)
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
super
().
__init__
(
c
,
layer_param_names
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
def
_get_layer_params
(
self
,
layer
)
->
_FP8ParamsT
:
w
,
w_s
,
x_s
,
x_s_ub
=
self
.
layer_param_names
return
(
getattr
(
layer
,
w
),
getattr
(
layer
,
w_s
),
getattr
(
layer
,
x_s
,
None
),
getattr
(
layer
,
x_s_ub
,
None
),
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
fp8_dtype
=
self
.
fp8_dtype
maybe_out_dtype
=
self
.
config
.
out_dtype
w
,
w_s
,
x_s
,
x_s_ub
=
self
.
_get_layer_params
(
layer
)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
output_shape
=
[
*
x
.
shape
[:
-
1
],
w
.
shape
[
1
]]
out_dtype
=
x
.
dtype
if
maybe_out_dtype
is
None
else
maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q
=
x_2d
if
x
.
dtype
!=
fp8_dtype
:
x_2d_q
,
x_s
=
self
.
quant_fp8
(
x_2d
,
x_s
,
x_s_ub
,
)
return
self
.
apply_scaled_mm
(
A
=
x_2d_q
,
B
=
w
,
out_dtype
=
out_dtype
,
As
=
x_s
,
Bs
=
w_s
,
bias
=
bias
,
output_shape
=
output_shape
,
)
@
abstractmethod
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
get_output_padding
(
self
)
->
int
|
None
:
return
None
class
Int8ScaledMMLinearKernel
(
ScaledMMLinearKernel
[
Int8ScaledMMLinearLayerConfig
,
_Int8ParamsT
],
ABC
):
def
_get_layer_params
(
self
,
layer
)
->
_Int8ParamsT
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
layer_param_names
return
(
return
(
getattr
(
layer
,
self
.
w_q_name
),
getattr
(
layer
,
w_q
),
getattr
(
layer
,
self
.
w_s_name
),
getattr
(
layer
,
w_s
),
getattr
(
layer
,
self
.
i_s_nam
e
),
getattr
(
layer
,
i_s
,
Non
e
),
getattr
(
layer
,
self
.
i_zp_nam
e
),
getattr
(
layer
,
i_zp
,
Non
e
),
getattr
(
layer
,
self
.
azp_adj
_nam
e
),
getattr
(
layer
,
azp_adj
,
Non
e
),
)
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
View file @
148117ea
...
@@ -2,76 +2,229 @@
...
@@ -2,76 +2,229 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
os
from
typing
import
TypeVar
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
,
Aiter
Int8
ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu
import
(
CPUScaledMMLinearKernel
,
CPU
Int8
ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
CutlassInt8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm
import
(
ROCmFP8ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
ScaledMMLinearKernel
,
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
,
ScaledMMLinearLayerConfig
,
)
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.triton
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.triton
import
(
TritonScaledMMLinearKernel
,
Triton
Int8
ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
from
vllm.platforms
import
PlatformEnum
,
current_platform
from
vllm.platforms
import
PlatformEnum
,
current_platform
logger
=
init_logger
(
__name__
)
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
Int8ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CPU
:
[
CPUInt8ScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassInt8ScaledMMLinearKernel
,
TritonInt8ScaledMMLinearKernel
,
],
PlatformEnum
.
ROCM
:
[
AiterInt8ScaledMMLinearKernel
,
TritonInt8ScaledMMLinearKernel
],
}
# in priority/performance order (when available)
# in priority/performance order (when available)
_POSSIBLE_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
ScaledMMLinearKernel
]]]
=
{
_POSSIBLE_FP8_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
FP8ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CPU
:
[
CPUScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
FlashInferFP8ScaledMMLinearKernel
,
PlatformEnum
.
ROCM
:
[
AiterScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
CutlassFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
],
PlatformEnum
.
ROCM
:
[
ROCmFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
],
PlatformEnum
.
CPU
:
[
PerTensorTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
],
}
}
_KernelT
=
TypeVar
(
"_KernelT"
,
bound
=
ScaledMMLinearKernel
)
_KernelConfigT
=
TypeVar
(
"_KernelConfigT"
,
bound
=
ScaledMMLinearLayerConfig
)
def
is_supported_and_can_implement_kernel
(
kernel
:
type
[
_KernelT
],
config
:
_KernelConfigT
,
compute_capability
:
int
|
None
)
->
tuple
[
bool
,
str
]:
# TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
).
split
(
","
):
return
False
,
f
"
{
kernel
.
__name__
}
is disabled by environment variable"
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
[
0
]
*
10
+
_cc
[
1
]
is_supported
,
failure_reason
=
kernel
.
is_supported
(
compute_capability
)
if
not
is_supported
:
return
False
,
f
"
{
kernel
.
__name__
}
{
failure_reason
}
."
can_implement
,
failure_reason
=
kernel
.
can_implement
(
config
)
if
not
can_implement
:
return
(
False
,
f
"
{
kernel
.
__name__
}
{
failure_reason
}
."
,
)
return
True
,
""
def
choose_scaled_mm_linear_kernel
(
def
choose_scaled_mm_linear_kernel
(
config
:
ScaledMMLinearLayerConfig
,
compute_capability
:
int
|
None
=
None
config
:
_KernelConfigT
,
)
->
type
[
ScaledMMLinearKernel
]:
possible_kernels
:
dict
[
PlatformEnum
,
list
[
type
[
_KernelT
]]],
compute_capability
:
int
|
None
=
None
,
force_kernel
:
type
[
_KernelT
]
|
None
=
None
,
)
->
type
[
_KernelT
]:
"""
"""
Choose a
n ScaledMMLinear
Kernel that can implement the given config for the
Choose a
_
Kernel
T
that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
given compute capability. Attempts to choose the best kernel in terms of
performance.
performance.
Args:
Args:
config (
ScaledMMLinearLayer
Config): Description of the linear layer
config (
_Kernel
Config
T
): Description of the linear layer
to be implemented.
to be implemented.
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
dictionary of platforms and their lists of possible kernels.
compute_capability (Optional[int], optional): The compute capability of
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
compute capability. Defaults to None.
force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override
the possible_kernels if it can be implemented. If None, it will only try the
possible kernels.
Raises:
Raises:
ValueError: If no kernel can implement the given config.
ValueError: If no kernel can implement the given config.
Returns:
Returns:
type[ScaledMMLinear
Kernel
]
: Chosen kernel.
_
Kernel
T
: Chosen kernel.
"""
"""
failure_reasons
=
[]
failure_reason_list
=
[]
for
kernel
in
_POSSIBLE_KERNELS
[
current_platform
.
_enum
]:
if
kernel
.
__name__
in
os
.
environ
.
get
(
"VLLM_DISABLED_KERNELS"
,
""
).
split
(
","
):
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
: disabled by env var"
)
continue
# If the current platform uses compute_capability,
if
force_kernel
is
not
None
:
# make sure the kernel supports the compute capability.
can_implement
,
failure_reason
=
is_supported_and_can_implement_kernel
(
is_supported
,
reason
=
kernel
.
is_supported
(
compute_capability
)
force_kernel
,
config
,
compute_capability
if
not
is_supported
:
)
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
:
{
reason
}
"
)
if
can_implement
:
continue
return
force_kernel
can_implement
,
reason
=
kernel
.
can_implement
(
config
)
logger
.
info_once
(
if
not
can_implement
:
"Tried to force %s, but the kernel couldn't be implemented"
,
failure_reasons
.
append
(
f
"
{
kernel
.
__name__
}
:
{
reason
}
"
)
force_kernel
.
__name__
,
continue
scope
=
"global"
,
)
return
kernel
for
kernel
in
possible_kernels
[
current_platform
.
_enum
]:
is_supported_and_can_implement
,
failure_reason
=
(
is_supported_and_can_implement_kernel
(
kernel
,
config
,
compute_capability
)
)
if
is_supported_and_can_implement
:
return
kernel
failure_reason_list
.
append
(
failure_reason
)
raise
ValueError
(
raise
ValueError
(
"Failed to find a kernel that can implement the "
"Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons:
\n
"
+
"
\n
"
.
join
(
failure_reasons
)
"ScaledMM linear layer. Reasons:
\n
"
+
"
\n
"
.
join
(
failure_reason_list
)
)
def
init_fp8_linear_kernel
(
activation_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
out_dtype
:
torch
.
dtype
,
force_kernel
:
type
[
FP8ScaledMMLinearKernel
]
|
None
=
None
,
module_name
:
str
|
None
=
None
,
)
->
FP8ScaledMMLinearKernel
:
scaled_mm_linear_kernel_config
=
FP8ScaledMMLinearLayerConfig
(
weight_quant_key
=
weight_quant_key
,
activation_quant_key
=
activation_quant_key
,
out_dtype
=
out_dtype
,
)
kernel_type
=
choose_scaled_mm_linear_kernel
(
scaled_mm_linear_kernel_config
,
_POSSIBLE_FP8_KERNELS
,
force_kernel
=
force_kernel
)
if
module_name
:
logger
.
info_once
(
"Selected %s for %s"
,
kernel_type
.
__name__
,
module_name
,
scope
=
"global"
,
)
return
kernel_type
(
scaled_mm_linear_kernel_config
,
layer_param_names
=
[
"weight"
,
"weight_scale"
,
"input_scale"
,
"input_scale_ub"
],
)
def
init_int8_linear_kernel
(
is_channelwise
:
bool
,
is_static_input_scheme
:
bool
,
input_symmetric
:
bool
,
module_name
:
str
,
)
->
Int8ScaledMMLinearKernel
:
config
=
Int8ScaledMMLinearLayerConfig
(
is_channelwise
=
is_channelwise
,
is_static_input_scheme
=
is_static_input_scheme
,
input_symmetric
=
input_symmetric
,
)
kernel_type
=
choose_scaled_mm_linear_kernel
(
config
,
_POSSIBLE_INT8_KERNELS
,
)
logger
.
info_once
(
"Selected %s for %s"
,
kernel_type
.
__name__
,
module_name
,
scope
=
"global"
,
)
return
kernel_type
(
config
,
layer_param_names
=
[
"weight"
,
"weight_scale"
,
"input_scale"
,
"input_zero_point"
,
"azp_adj"
,
],
)
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
View file @
148117ea
...
@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops
...
@@ -8,60 +8,41 @@ from vllm import _custom_ops as ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.cutlass
import
CutlassScaledMMLinearKernel
from
.cutlass
import
Cutlass
Int8
ScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
Int8
ScaledMMLinearLayerConfig
class
AiterScaledMMLinearKernel
(
CutlassScaledMMLinearKernel
):
class
Aiter
Int8
ScaledMMLinearKernel
(
Cutlass
Int8
ScaledMMLinearKernel
):
@
classmethod
@
classmethod
def
is_supported
(
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_rocm
():
if
not
current_platform
.
is_rocm
():
return
(
return
False
,
"Requires ROCm."
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"currently supported on non-ROCm platform."
,
)
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
if
compute_capability
is
not
None
and
compute_capability
<
90
:
if
compute_capability
is
not
None
and
compute_capability
<
90
:
return
False
,
f
"requires
capability 90, got
{
compute
_
capability
}
"
return
False
,
"requires compute
capability
90 and above.
"
try
:
try
:
import
aiter
# noqa: F401 # deliberately attempt to import aiter
import
aiter
# noqa: F401 # deliberately attempt to import aiter
except
Exception
:
except
Exception
:
return
(
return
False
,
"requires `aiter` to be installed."
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"installed on ROCm."
,
)
if
not
rocm_aiter_ops
.
is_linear_enabled
():
if
not
rocm_aiter_ops
.
is_linear_enabled
():
return
(
return
(
False
,
False
,
"AiterScaledMMLinearKernel is disabled. "
"requires setting `VLLM_ROCM_USE_AITER=1` "
+
"Enable by setting `VLLM_ROCM_USE_AITER=1` "
+
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
)
)
return
True
,
None
return
True
,
None
@
classmethod
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
if
not
c
.
input_symmetric
:
if
not
c
.
input_symmetric
:
return
(
return
False
,
"supports symmetric quantization only."
False
,
"AiterScaledMMLinearKernel only supports symmetric "
+
"quantization."
,
)
return
True
,
None
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
...
@@ -69,28 +50,28 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
`AiterScaledMMLinearKernel` implements a fused version of
`Aiter
Int8
ScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
broadcasting.
Currently only support per-tensor-per-tensor GEMM
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
w8a8 scaled gemm. `Aiter
Int8
ScaledMMLinearKernel` also does not support
AITER block scaled GEMM and mix-precision GEMM.
AITER block scaled GEMM and mix-precision GEMM.
"""
"""
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
# * static, i_s is scalar and x_s is i_s.
symmetric
=
azp_adj
is
None
symmetric
=
azp_adj
is
None
assert
symmetric
,
(
assert
symmetric
,
(
"AiterScaledMMLinearKernel only supports symmetric quantization."
"Aiter
Int8
ScaledMMLinearKernel only supports symmetric quantization."
)
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
i_s
,
i_zp
,
symmetric
=
symmetric
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
i_s
,
i_zp
,
symmetric
=
symmetric
)
assert
x_zp
is
None
,
(
assert
x_zp
is
None
,
(
"AiterScaledMMLinearKernel only supports symmetric quantization."
"Aiter
Int8
ScaledMMLinearKernel only supports symmetric quantization."
)
)
out_dtype
=
x
.
dtype
out_dtype
=
x
.
dtype
...
@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
...
@@ -117,12 +98,12 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
),
(
),
(
"Currently only support per-tensor-per-tensor GEMM "
"Currently only support per-tensor-per-tensor GEMM "
+
" and per-token-per-channel GEMM through AITER"
+
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
" w8a8 scaled gemm. `Aiter
Int8
ScaledMMLinearKernel` "
+
"does not support AITER block scaled GEMM."
+
"does not support AITER block scaled GEMM."
)
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# a to be [M, K]
# b to be [N, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# Cutlass
Int8
ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
rocm_aiter_ops
.
gemm_a8w8
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
return
rocm_aiter_ops
.
gemm_a8w8
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py
View file @
148117ea
...
@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
...
@@ -14,24 +14,28 @@ from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
CpuArchEnum
from
vllm.platforms.interface
import
CpuArchEnum
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
)
class
CPUScaledMMLinearKernel
(
ScaledMMLinearKernel
):
class
CPU
Int8
ScaledMMLinearKernel
(
Int8
ScaledMMLinearKernel
):
@
classmethod
@
classmethod
def
is_supported
(
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cpu
():
if
not
current_platform
.
is_cpu
():
return
False
,
"
R
equires CPU."
return
False
,
"
r
equires CPU."
return
True
,
None
return
True
,
None
@
classmethod
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
weight
=
getattr
(
layer
,
self
.
w_q_name
)
w_q_name
,
_
,
_
,
_
,
_
=
self
.
layer_param_names
weight
=
getattr
(
layer
,
w_q_name
)
dtype
=
weight
.
dtype
dtype
=
weight
.
dtype
N
,
K
=
weight
.
size
()
N
,
K
=
weight
.
size
()
if
(
if
(
...
@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -49,10 +53,11 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
def
process_weights_for_onednn
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_for_onednn
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# WEIGHT
# WEIGHT
# Transpose to [K, N] for convenience
# Transpose to [K, N] for convenience
weight
=
getattr
(
layer
,
self
.
w_q_name
)
w_q_name
,
w_s_name
,
i_s_name
,
i_zp_name
,
azp_adj_name
=
self
.
layer_param_names
weight
=
getattr
(
layer
,
w_q_name
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_q_name
,
w_q_name
,
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
)
)
...
@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -61,28 +66,27 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
)
# INPUT SCALE
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
if
self
.
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
input_scale
=
getattr
(
layer
,
i_s_name
)
if
self
.
config
.
input_symmetric
:
if
self
.
config
.
input_symmetric
:
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
i_s_name
,
i_s_name
,
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
)
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
else
:
else
:
input_zero_point
=
getattr
(
layer
,
self
.
i_zp_name
)
input_zero_point
=
getattr
(
layer
,
i_zp_name
)
# reconstruct the ranges
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
...
@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -92,20 +96,16 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
replace_parameter
(
replace_parameter
(
layer
,
self
.
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
layer
,
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
)
)
azp
=
(
azp
=
(
(
int8_traits
.
min
-
range_min
/
scale
).
round
().
to
(
dtype
=
torch
.
int32
)
(
int8_traits
.
min
-
range_min
/
scale
).
round
().
to
(
dtype
=
torch
.
int32
)
)
)
replace_parameter
(
replace_parameter
(
layer
,
self
.
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
layer
,
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
)
)
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
# term. Such as:
...
@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -113,38 +113,37 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
# s_a * GEMM_output - s_a * zp_a * adj + bias
if
not
(
self
.
config
.
input_symmetric
and
self
.
config
.
is_static_input_scheme
):
if
not
(
self
.
config
.
input_symmetric
and
self
.
config
.
is_static_input_scheme
):
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
float32
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
float32
)
azp_adj
=
azp_adj
*
weight_scale
.
squeeze
()
azp_adj
=
azp_adj
*
weight_scale
.
squeeze
()
setattr
(
setattr
(
layer
,
layer
,
self
.
azp_adj_name
,
azp_adj_name
,
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
)
)
else
:
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
self
.
dnnl_handler
=
ops
.
create_onednn_scaled_mm
(
self
.
dnnl_handler
=
ops
.
create_onednn_scaled_mm
(
weight
,
weight
,
getattr
(
layer
,
self
.
w_s_name
),
getattr
(
layer
,
w_s_name
),
torch
.
get_default_dtype
(),
torch
.
get_default_dtype
(),
getattr
(
layer
,
self
.
i_s_name
)
is
None
,
getattr
(
layer
,
i_s_name
)
is
None
,
not
self
.
config
.
input_symmetric
,
not
self
.
config
.
input_symmetric
,
32
,
32
,
)
)
# weight is prepacked and maintained by the dnnl_handler,
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
# release the original weight
setattr
(
layer
,
self
.
w_q_name
,
None
)
setattr
(
layer
,
w_q_name
,
None
)
del
weight
del
weight
def
process_weights_for_sgl
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_for_sgl
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_q_name
,
w_s_name
,
_
,
_
,
_
=
self
.
layer_param_names
# WEIGHT
# WEIGHT
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
weight
)
packed_weight
=
torch
.
ops
.
_C
.
convert_weight_packed
(
weight
)
replace_parameter
(
replace_parameter
(
layer
,
self
.
w_q_name
,
torch
.
nn
.
Parameter
(
packed_weight
,
requires_grad
=
False
)
layer
,
w_q_name
,
torch
.
nn
.
Parameter
(
packed_weight
,
requires_grad
=
False
)
)
)
if
layer
.
bias
is
not
None
:
if
layer
.
bias
is
not
None
:
...
@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -156,19 +155,15 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
# WEIGHT SCALE
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
# For per-tensor quant, convert to the per-channel case.
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
not
self
.
config
.
is_channelwise
:
if
not
self
.
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
)
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -187,7 +182,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * dynamic, i_s is None and x_s computed from x.
...
@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -209,7 +204,7 @@ class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
w_q
,
w_s
,
_
,
_
,
_
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
_
,
_
,
_
=
self
.
_get_
layer
_params
(
layer
)
return
torch
.
ops
.
_C
.
int8_scaled_mm_with_quant
(
return
torch
.
ops
.
_C
.
int8_scaled_mm_with_quant
(
x
,
x
,
w_q
,
w_q
,
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
View file @
148117ea
...
@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -11,35 +11,36 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.ScaledMMLinearKernel
import
ScaledMMLinearKernel
,
ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
)
class
CutlassScaledMMLinearKernel
(
ScaledMMLinearKernel
):
class
Cutlass
Int8
ScaledMMLinearKernel
(
Int8
ScaledMMLinearKernel
):
@
classmethod
@
classmethod
def
is_supported
(
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
return
False
,
"Requires CUDA."
return
False
,
"requires CUDA."
if
compute_capability
is
None
:
_cc
=
current_platform
.
get_device_capability
()
if
_cc
is
not
None
:
compute_capability
=
_cc
.
major
*
10
+
_cc
.
minor
if
compute_capability
is
not
None
and
compute_capability
<
75
:
return
False
,
f
"requires capability 75, got
{
compute_capability
}
"
return
True
,
None
return
True
,
None
@
classmethod
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
def
can_implement
(
cls
,
c
:
Int8
ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_q_name
,
w_s_name
,
i_s_name
,
i_zp_name
,
azp_adj_name
=
self
.
layer_param_names
config
=
self
.
config
# WEIGHT
# WEIGHT
# Cutlass kernels need transposed weight.
# Cutlass kernels need transposed weight.
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_q_name
,
w_q_name
,
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight
.
t
().
data
,
requires_grad
=
False
),
)
)
...
@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -48,28 +49,28 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
is_fused_module
=
len
(
layer
.
logical_widths
)
>
1
weight_scale
=
getattr
(
layer
,
self
.
w_s_name
)
weight_scale
=
getattr
(
layer
,
w_s_name
)
if
is_fused_module
and
not
self
.
config
.
is_channelwise
:
if
is_fused_module
and
not
config
.
is_channelwise
:
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
weight_scale
=
convert_to_channelwise
(
weight_scale
,
layer
.
logical_widths
)
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
w_s_name
,
w_s_name
,
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
weight_scale
.
data
,
requires_grad
=
False
),
)
)
# INPUT SCALE
# INPUT SCALE
if
self
.
config
.
is_static_input_scheme
:
if
config
.
is_static_input_scheme
:
input_scale
=
getattr
(
layer
,
self
.
i_s_name
)
input_scale
=
getattr
(
layer
,
i_s_name
)
if
self
.
config
.
input_symmetric
:
if
config
.
input_symmetric
:
replace_parameter
(
replace_parameter
(
layer
,
layer
,
self
.
i_s_name
,
i_s_name
,
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
input_scale
.
max
(),
requires_grad
=
False
),
)
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
setattr
(
layer
,
i_zp_name
,
None
)
else
:
else
:
input_zero_point
=
getattr
(
layer
,
self
.
i_zp_name
)
input_zero_point
=
getattr
(
layer
,
i_zp_name
)
# reconstruct the ranges
# reconstruct the ranges
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
int8_traits
=
torch
.
iinfo
(
torch
.
int8
)
...
@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -79,38 +80,32 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
scale
=
(
range_max
-
range_min
)
/
(
int8_traits
.
max
-
int8_traits
.
min
)
replace_parameter
(
replace_parameter
(
layer
,
self
.
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
layer
,
i_s_name
,
torch
.
nn
.
Parameter
(
scale
,
requires_grad
=
False
)
)
)
# AZP loaded as int8 but used as int32
# AZP loaded as int8 but used as int32
azp
=
(
int8_traits
.
min
-
range_min
/
scale
).
to
(
dtype
=
torch
.
int32
)
azp
=
(
int8_traits
.
min
-
range_min
/
scale
).
to
(
dtype
=
torch
.
int32
)
replace_parameter
(
replace_parameter
(
layer
,
self
.
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
layer
,
i_zp_name
,
torch
.
nn
.
Parameter
(
azp
,
requires_grad
=
False
)
)
)
else
:
setattr
(
layer
,
self
.
i_s_name
,
None
)
setattr
(
layer
,
self
.
i_zp_name
,
None
)
# azp_adj is the AZP adjustment term, used to account for weights.
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if
not
self
.
config
.
input_symmetric
:
if
not
config
.
input_symmetric
:
weight
=
getattr
(
layer
,
self
.
w_q_name
)
weight
=
getattr
(
layer
,
w_q_name
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
azp_adj
=
weight
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
if
self
.
config
.
is_static_input_scheme
:
if
config
.
is_static_input_scheme
:
# cutlass_w8a8 requires azp to be folded into azp_adj
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
# in the per-tensor case
azp_adj
=
getattr
(
layer
,
self
.
i_zp_name
)
*
azp_adj
azp_adj
=
getattr
(
layer
,
i_zp_name
)
*
azp_adj
setattr
(
setattr
(
layer
,
layer
,
self
.
azp_adj_name
,
azp_adj_name
,
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
torch
.
nn
.
Parameter
(
azp_adj
,
requires_grad
=
False
),
)
)
else
:
setattr
(
layer
,
self
.
azp_adj_name
,
None
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
...
@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -118,7 +113,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
weight
_params
(
layer
)
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_
layer
_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * dynamic, i_s is None and x_s computed from x.
...
@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -145,3 +140,34 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
return
ops
.
cutlass_scaled_mm
(
return
ops
.
cutlass_scaled_mm
(
x_q
,
w_q
,
scale_a
=
x_s
,
scale_b
=
w_s
,
out_dtype
=
x
.
dtype
,
bias
=
bias
x_q
,
w_q
,
scale_a
=
x_s
,
scale_b
=
w_s
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
)
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-mm linear kernel that dispatches to `ops.cutlass_scaled_mm`."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        Returns `(True, None)` when `current_platform.is_cuda()` is true,
        otherwise `(False, reason)`. `compute_capability` is accepted but
        not consulted here.
        """
        if current_platform.is_cuda():
            return True, None
        return False, "requires CUDA."

    @classmethod
    def can_implement(
        cls, c: FP8ScaledMMLinearLayerConfig
    ) -> tuple[bool, str | None]:
        """Accept every FP8 layer config; always returns `(True, None)`."""
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the scaled matmul and view the result as `output_shape`.

        `A` and `B` are passed to `ops.cutlass_scaled_mm` with `As` / `Bs`
        as `scale_a` / `scale_b`, along with `out_dtype` and `bias`.
        """
        # Fused GEMM_DQ
        return ops.cutlass_scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        ).view(*output_shape)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment