Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4965ec42
Unverified
Commit
4965ec42
authored
Mar 29, 2025
by
TJian
Committed by
GitHub
Mar 29, 2025
Browse files
[FEAT] [ROCm] Add AITER int8 scaled gemm kernel (#15433)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
73aa7041
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
202 additions
and
5 deletions
+202
-5
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+72
-4
vllm/envs.py
vllm/envs.py
+8
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
...xecutor/layers/quantization/kernels/scaled_mm/__init__.py
+3
-1
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
...l_executor/layers/quantization/kernels/scaled_mm/aiter.py
+119
-0
No files found.
tests/quantization/test_compressed_tensors.py
View file @
4965ec42
...
@@ -20,6 +20,23 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...
@@ -20,6 +20,23 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported
)
sparse_cutlass_supported
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
# AITER only supports per-channel-per-channel INT8 gemm
# and per-tensor-per-tensor INT8 GEMM.
# It does not support mix precision MM and mix quantization scheme.
ROCM_AITER_SUPPORTED_INT8_MODEL
=
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
,
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
]
# TritonScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
=
[
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor"
,
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
,
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
,
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
,
]
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
def
use_v0_only
(
monkeypatch
):
...
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
...
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
)
)
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
def
test_compressed_tensors_w8a8_static_setup
(
vllm_runner
,
model_args
):
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
model_path
,
strategy
,
quant_type
,
shape_0
,
is_symmetric
=
model_args
if
current_platform
.
is_rocm
(
)
and
model_path
not
in
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
:
pytest
.
skip
(
f
"Skip model
{
model_path
}
as it is not support on ROCm."
)
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
with
vllm_runner
(
model_path
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
def
check_model
(
model
):
...
@@ -123,6 +145,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
...
@@ -123,6 +145,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
)
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"use_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
])
def
test_compressed_tensors_w8a8_logprobs
(
def
test_compressed_tensors_w8a8_logprobs
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
...
@@ -130,7 +154,21 @@ def test_compressed_tensors_w8a8_logprobs(
...
@@ -130,7 +154,21 @@ def test_compressed_tensors_w8a8_logprobs(
model_path
,
model_path
,
max_tokens
,
max_tokens
,
num_logprobs
,
num_logprobs
,
use_aiter
,
monkeypatch
,
):
):
if
current_platform
.
is_rocm
(
)
and
model_path
not
in
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
:
pytest
.
skip
(
f
"Skip model
{
model_path
}
as it is not support on ROCm."
)
if
use_aiter
:
if
model_path
not
in
ROCM_AITER_SUPPORTED_INT8_MODEL
:
pytest
.
skip
(
f
"Skip model
{
model_path
}
as it is not support by aiter."
)
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
dtype
=
"bfloat16"
dtype
=
"bfloat16"
# skip language translation prompt for the static per tensor asym model
# skip language translation prompt for the static per tensor asym model
...
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
...
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
name_1
=
"vllm"
,
name_1
=
"vllm"
,
)
)
if
current_platform
.
is_rocm
():
torch
.
cuda
.
synchronize
()
def
test_compressed_tensors_no_enforce_eager
(
vllm_runner
):
def
test_compressed_tensors_no_enforce_eager
(
vllm_runner
):
model_path
=
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
model_path
=
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
...
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
...
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
),
),
],
],
)
)
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
):
@
pytest
.
mark
.
parametrize
(
"use_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
])
def
test_compressed_tensors_w8a8_dynamic_per_token
(
vllm_runner
,
model_args
,
use_aiter
,
monkeypatch
,
):
model_path
,
strategy
=
model_args
model_path
,
strategy
=
model_args
if
current_platform
.
is_rocm
(
)
and
model_path
not
in
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
:
pytest
.
skip
(
f
"Skip model
{
model_path
}
as it is not support on ROCm."
)
if
use_aiter
:
if
model_path
not
in
ROCM_AITER_SUPPORTED_INT8_MODEL
:
pytest
.
skip
(
f
"Skip model
{
model_path
}
as it is not support by aiter."
)
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch
.
setenv
(
"VLLM_ROCM_USE_AITER"
,
"1"
)
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
with
vllm_runner
(
model_path
,
dtype
=
torch
.
float16
)
as
llm
:
def
check_model
(
model
):
def
check_model
(
model
):
...
@@ -207,6 +267,8 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
...
@@ -207,6 +267,8 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
),
(
"nm-testing/tinyllama-oneshot-w8a16-per-channel"
,
"channel"
,
None
,
4
),
],
],
)
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"The tests are skipped on non-CUDA platform."
)
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
def
test_compressed_tensors_wNa16
(
vllm_runner
,
wNa16_args
):
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
model
,
strategy
,
group
,
pack_factor
=
wNa16_args
with
vllm_runner
(
model
)
as
llm
:
with
vllm_runner
(
model
)
as
llm
:
...
@@ -231,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
...
@@ -231,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
def
test_compressed_tensors_w4a16_marlin24
(
vllm_runner
):
model_path
=
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
model_path
=
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
)
as
llm
:
...
@@ -271,7 +335,7 @@ def test_compressed_tensors_fp8(vllm_runner):
...
@@ -271,7 +335,7 @@ def test_compressed_tensors_fp8(vllm_runner):
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
if
isinstance
(
qkv_proj
.
scheme
,
CompressedTensorsW8A8Fp8
):
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
assert
qkv_proj
.
weight
.
dtype
is
current_platform
.
fp8_dtype
()
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
qkv_proj
.
weight_scale
.
dtype
is
torch
.
float32
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
...
@@ -281,6 +345,8 @@ def test_compressed_tensors_fp8(vllm_runner):
...
@@ -281,6 +345,8 @@ def test_compressed_tensors_fp8(vllm_runner):
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
def
test_compressed_tensors_kv_cache
(
vllm_runner
):
def
test_compressed_tensors_kv_cache
(
vllm_runner
):
model_path
=
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
model_path
=
"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with
vllm_runner
(
model_path
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
with
vllm_runner
(
model_path
,
kv_cache_dtype
=
"fp8"
)
as
llm
:
...
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
...
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
@@ -356,7 +423,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
...
@@ -356,7 +423,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
90
),
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
vllm/envs.py
View file @
4965ec42
...
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
...
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
True
VLLM_USE_V1
:
bool
=
True
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER_LINEAR
:
bool
=
True
VLLM_ROCM_USE_AITER_MOE
:
bool
=
True
VLLM_ROCM_USE_AITER_MOE
:
bool
=
True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
:
bool
=
False
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
:
bool
=
False
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
...
@@ -524,6 +525,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -524,6 +525,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER"
,
"False"
).
lower
()
in
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
"VLLM_ROCM_USE_AITER_LINEAR"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_LINEAR"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use aiter moe ops.
# Whether to use aiter moe ops.
# By default is enabled.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE"
:
"VLLM_ROCM_USE_AITER_MOE"
:
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
View file @
4965ec42
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
import
os
import
os
from
typing
import
Dict
,
List
,
Optional
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Type
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter
import
(
AiterScaledMMLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass
import
(
CutlassScaledMMLinearKernel
)
CutlassScaledMMLinearKernel
)
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel
import
(
# noqa: E501
...
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
...
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
_POSSIBLE_KERNELS
:
Dict
[
PlatformEnum
,
List
[
Type
[
ScaledMMLinearKernel
]]]
=
{
_POSSIBLE_KERNELS
:
Dict
[
PlatformEnum
,
List
[
Type
[
ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CPU
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CPU
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
CUDA
:
[
CutlassScaledMMLinearKernel
],
PlatformEnum
.
ROCM
:
[
TritonScaledMMLinearKernel
],
PlatformEnum
.
ROCM
:
[
AiterScaledMMLinearKernel
,
TritonScaledMMLinearKernel
],
PlatformEnum
.
TPU
:
[
XLAScaledMMLinearKernel
],
PlatformEnum
.
TPU
:
[
XLAScaledMMLinearKernel
],
}
}
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
0 → 100644
View file @
4965ec42
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
.cutlass
import
CutlassScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
ScaledMMLinearLayerConfig
class
AiterScaledMMLinearKernel
(
CutlassScaledMMLinearKernel
):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
90
@
classmethod
def
can_implement
(
cls
,
c
:
ScaledMMLinearLayerConfig
)
->
Tuple
[
bool
,
Optional
[
str
]]:
if
not
current_platform
.
is_rocm
():
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"currently supported on non-ROCm platform."
)
try
:
import
aiter
# noqa: F401 # deliberately attempt to import aiter
except
Exception
:
return
(
False
,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+
"installed on ROCm."
)
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if
not
(
envs
.
VLLM_ROCM_USE_AITER_LINEAR
\
and
envs
.
VLLM_ROCM_USE_AITER
):
return
(
False
,
"AiterScaledMMLinearKernel is disabled. "
+
"Enable by setting `VLLM_ROCM_USE_AITER=1` "
+
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
)
if
not
c
.
input_symmetric
:
return
(
False
,
"AiterScaledMMLinearKernel only supports symmetric "
+
"quantization."
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
super
().
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
w_q
,
w_s
,
i_s
,
i_zp
,
azp_adj
=
self
.
_get_weight_params
(
layer
)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric
=
azp_adj
is
None
assert
symmetric
,
(
"AiterScaledMMLinearKernel only supports"
" symmetric quantization."
)
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
i_s
,
i_zp
,
symmetric
=
symmetric
)
assert
x_zp
is
None
,
(
"AiterScaledMMLinearKernel only supports"
" symmetric quantization."
)
out_dtype
=
x
.
dtype
assert
(
w_q
.
shape
[
0
]
%
16
==
0
and
w_q
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
assert
bias
is
None
or
bias
.
shape
[
0
]
==
w_q
.
shape
[
1
]
and
bias
.
dtype
==
out_dtype
m
=
x_q
.
shape
[
0
]
# a
n
=
w_q
.
shape
[
1
]
# b
per_tensor_scale_a
=
(
x_s
.
numel
()
==
1
)
per_tensor_scale_b
=
(
w_s
.
numel
()
==
1
)
per_token_scale_a
=
(
x_s
.
numel
()
==
m
)
per_channel_scale_b
=
(
w_s
.
numel
()
==
n
)
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert
((
per_tensor_scale_a
and
per_tensor_scale_b
)
or
(
per_token_scale_a
and
per_channel_scale_b
)),
(
"Currently only support per-tensor-per-tensor GEMM "
+
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
+
"does not support AITER block scaled GEMM."
)
from
aiter
import
gemm_a8w8_CK
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
gemm_a8w8_CK
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
).
to
(
out_dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment