Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ee14644b
Unverified
Commit
ee14644b
authored
Dec 09, 2025
by
vllmellm
Committed by
GitHub
Dec 09, 2025
Browse files
[ROCm] Aiter Quant Kernels (#25552)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
1166c31c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
123 additions
and
2 deletions
+123
-2
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+87
-0
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+31
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+5
-2
No files found.
vllm/_aiter_ops.py
View file @
ee14644b
...
...
@@ -9,6 +9,8 @@ import vllm.envs as envs
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
direct_register_custom_op
,
is_torch_equal_or_newer
# FP8 dtype reported by the current platform, resolved once when this module
# is imported.
_FP8_DTYPE = current_platform.fp8_dtype()
def
is_aiter_found
()
->
bool
:
from
importlib.util
import
find_spec
...
...
@@ -467,6 +469,59 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
return
torch
.
empty_like
(
x
),
torch
.
empty_like
(
residual
)
def _rocm_aiter_per_tensor_quant_impl(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run AITER's ``per_tensor_quant_hip`` on ``x``.

    Args:
        x: Tensor to quantize.
        quant_dtype: Target dtype handed to the AITER kernel.
        scale: Optional scale tensor; forwarded unchanged (``None`` included).

    Returns:
        The value returned by ``per_tensor_quant_hip``, annotated as a pair
        of tensors.
    """
    # Function-scope import: AITER is only imported when the op actually runs.
    from aiter.ops.quant import per_tensor_quant_hip

    # The kernel's positional order is (x, scale, dtype), which differs from
    # this wrapper's (x, dtype, scale) signature.
    result = per_tensor_quant_hip(x, scale, quant_dtype)
    return result
def
_rocm_aiter_per_tensor_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
torch
.
empty_like
(
x
,
dtype
=
quant_dtype
),
torch
.
empty
(
1
,
dtype
=
torch
.
float32
,
device
=
x
.
device
)
def
_rocm_aiter_per_token_quant_impl
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
aiter.ops.quant
import
dynamic_per_token_scaled_quant
assert
quant_dtype
in
[
torch
.
int8
,
_FP8_DTYPE
]
out_shape
=
x
.
shape
out
=
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
)
if
scale
is
None
:
scale
=
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
dynamic_per_token_scaled_quant
(
out
,
x
,
scale
,
scale_ub
=
None
,
shuffle_scale
=
False
,
num_rows
=
None
,
num_rows_factor
=
1
,
)
return
out
,
scale
def
_rocm_aiter_per_token_quant_fake
(
x
:
torch
.
Tensor
,
quant_dtype
:
torch
.
dtype
,
scale
:
torch
.
Tensor
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
out_shape
=
x
.
shape
return
(
torch
.
empty
(
x
.
shape
,
dtype
=
_FP8_DTYPE
,
device
=
x
.
device
),
torch
.
empty
((
*
out_shape
[:
-
1
],
1
),
dtype
=
torch
.
float32
,
device
=
x
.
device
),
)
# Global flag to ensure ops are registered only once; it starts False and is
# set to True after the direct_register_custom_op calls have run.
_OPS_REGISTERED = False
...
...
@@ -665,6 +720,22 @@ class rocm_aiter_ops:
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_tensor_quant"
,
op_func
=
_rocm_aiter_per_tensor_quant_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_per_tensor_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_per_token_quant"
,
op_func
=
_rocm_aiter_per_token_quant_impl
,
mutates_args
=
[
"scale"
],
fake_impl
=
_rocm_aiter_per_token_quant_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
_OPS_REGISTERED
=
True
@
staticmethod
...
...
@@ -859,6 +930,22 @@ class rocm_aiter_ops:
kv_scale
=
kv_scale
,
)
@staticmethod
def per_tensor_quant(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call the ``torch.ops.vllm.rocm_aiter_per_tensor_quant`` custom op.

    Args:
        x: Tensor to quantize.
        quant_dtype: Target dtype forwarded to the op.
        scale: Optional scale tensor forwarded to the op.

    Returns:
        The op's result, annotated as a pair of tensors.
    """
    # Resolve the op first, then invoke it with the arguments in order.
    quant_op = torch.ops.vllm.rocm_aiter_per_tensor_quant
    return quant_op(x, quant_dtype, scale)
@staticmethod
def per_token_quant(
    x: torch.Tensor,
    quant_dtype: torch.dtype,
    scale: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call the ``torch.ops.vllm.rocm_aiter_per_token_quant`` custom op.

    Args:
        x: Tensor to quantize.
        quant_dtype: Target dtype forwarded to the op.
        scale: Optional scale tensor forwarded to the op.

    Returns:
        The op's result, annotated as a pair of tensors.
    """
    # Resolve the op first, then invoke it with the arguments in order.
    quant_op = torch.ops.vllm.rocm_aiter_per_token_quant
    return quant_op(x, quant_dtype, scale)
@
staticmethod
def
triton_fp4_gemm_dynamic_qaunt
(
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
ee14644b
...
...
@@ -5,6 +5,7 @@ import torch
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.platforms
import
current_platform
...
...
@@ -45,10 +46,13 @@ class QuantFP8(CustomOp):
super
().
__init__
()
self
.
static
=
static
self
.
group_shape
=
group_shape
self
.
use_per_token_if_dynamic
=
group_shape
==
GroupShape
.
PER_TOKEN
self
.
num_token_padding
=
num_token_padding
self
.
column_major_scales
=
column_major_scales
self
.
use_ue8m0
=
use_ue8m0
self
.
use_aiter
=
rocm_aiter_ops
.
is_linear_fp8_enaled
()
self
.
is_group_quant
=
group_shape
.
is_per_group
()
if
self
.
is_group_quant
:
assert
not
static
,
"Group quantization only supports dynamic mode"
...
...
@@ -92,6 +96,33 @@ class QuantFP8(CustomOp):
use_per_token_if_dynamic
=
self
.
use_per_token_if_dynamic
,
)
def
forward_hip
(
self
,
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
|
None
=
None
,
scale_ub
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
use_aiter_quant
=
(
not
self
.
is_group_quant
and
self
.
use_aiter
and
scale_ub
is
None
and
x
.
is_contiguous
()
)
use_aiter_per_tensor_quant
=
(
use_aiter_quant
and
self
.
group_shape
==
GroupShape
.
PER_TENSOR
)
use_aiter_per_token_quant
=
(
use_aiter_quant
and
self
.
group_shape
==
GroupShape
.
PER_TOKEN
)
if
use_aiter_per_tensor_quant
:
return
rocm_aiter_ops
.
per_tensor_quant
(
x
,
_FP8_DTYPE
,
scale
)
if
use_aiter_per_token_quant
:
return
rocm_aiter_ops
.
per_token_quant
(
x
,
_FP8_DTYPE
,
scale
)
# Fallback to CUDA implementation
return
self
.
forward_cuda
(
x
,
scale
,
scale_ub
)
def
forward_native
(
self
,
x
:
torch
.
Tensor
,
...
...
vllm/platforms/rocm.py
View file @
ee14644b
...
...
@@ -381,6 +381,8 @@ class RocmPlatform(Platform):
compilation_config
=
vllm_config
.
compilation_config
parallel_config
=
vllm_config
.
parallel_config
is_eager_execution
=
compilation_config
==
CUDAGraphMode
.
NONE
use_aiter_rms_norm
=
rocm_aiter_ops
.
is_rmsnorm_enabled
()
use_aiter_fp8_linear
=
rocm_aiter_ops
.
is_linear_fp8_enaled
()
if
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
# decode context parallel does not support full cudagraphs
...
...
@@ -400,8 +402,6 @@ class RocmPlatform(Platform):
)
compilation_config
.
cudagraph_mode
=
CUDAGraphMode
.
PIECEWISE
use_aiter_rms_norm
=
rocm_aiter_ops
.
is_rmsnorm_enabled
()
if
cache_config
and
cache_config
.
block_size
is
None
:
cache_config
.
block_size
=
16
...
...
@@ -415,6 +415,9 @@ class RocmPlatform(Platform):
):
compilation_config
.
custom_ops
.
append
(
"+rms_norm"
)
if
use_aiter_fp8_linear
and
"-quant_fp8"
not
in
compilation_config
.
custom_ops
:
compilation_config
.
custom_ops
.
append
(
"+quant_fp8"
)
@
classmethod
def
verify_model_arch
(
cls
,
model_arch
:
str
)
->
None
:
if
model_arch
in
_ROCM_UNSUPPORTED_MODELS
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment