Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca00b1bf
Unverified
Commit
ca00b1bf
authored
Nov 13, 2025
by
Pleaplusone
Committed by
GitHub
Nov 12, 2025
Browse files
[ROCm][BugFix] Remove the usage of `device_info` from aiter (#28383)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
d44fbbab
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
6 deletions
+5
-6
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+5
-6
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
ca00b1bf
...
@@ -31,15 +31,14 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
...
@@ -31,15 +31,14 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
import
aiter
import
aiter
from
aiter.ops.triton.utils.device_info
import
get_num_sms
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
def
block_size
(
x
,
head_dim
):
def
block_size
(
x
,
head_dim
):
return
min
(
65536
//
x
.
element_size
(),
triton
.
next_power_of_2
(
head_dim
))
return
min
(
65536
//
x
.
element_size
(),
triton
.
next_power_of_2
(
head_dim
))
def
num_programs
(
head_dim
):
def
num_programs
(
total_tokens
):
return
min
(
head_dim
,
get_num_sms
())
return
min
(
total_tokens
,
current_platform
.
get_cu_count
())
@
triton
.
jit
@
triton
.
jit
def
cp_mha_gather_cache_kernel
(
def
cp_mha_gather_cache_kernel
(
...
@@ -58,11 +57,11 @@ if current_platform.is_rocm():
...
@@ -58,11 +57,11 @@ if current_platform.is_rocm():
x
,
x
,
max_block_num
,
max_block_num
,
num_tokens
,
num_tokens
,
num_programs
,
DEQUANT
:
tl
.
constexpr
,
DEQUANT
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
CACHE_FORMAT
:
tl
.
constexpr
,
CACHE_FORMAT
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
NUM_PRGMS
:
tl
.
constexpr
,
):
):
bid
=
tl
.
program_id
(
0
)
bid
=
tl
.
program_id
(
0
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
...
@@ -70,7 +69,7 @@ if current_platform.is_rocm():
...
@@ -70,7 +69,7 @@ if current_platform.is_rocm():
k_scale
=
tl
.
load
(
k_scale_ptr
)
k_scale
=
tl
.
load
(
k_scale_ptr
)
v_scale
=
tl
.
load
(
v_scale_ptr
)
v_scale
=
tl
.
load
(
v_scale_ptr
)
for
token_id
in
tl
.
range
(
bid
,
num_tokens
,
NUM_PRGMS
):
for
token_id
in
tl
.
range
(
bid
,
num_tokens
,
num_programs
):
key_ptr_offset
=
key_ptr
+
token_id
*
head_size
*
num_heads
key_ptr_offset
=
key_ptr
+
token_id
*
head_size
*
num_heads
value_ptr_offset
=
value_ptr
+
token_id
*
head_size
*
num_heads
value_ptr_offset
=
value_ptr
+
token_id
*
head_size
*
num_heads
batch_idx
=
tl
.
load
(
token_to_batch_ptr
+
token_id
)
batch_idx
=
tl
.
load
(
token_to_batch_ptr
+
token_id
)
...
@@ -162,11 +161,11 @@ if current_platform.is_rocm():
...
@@ -162,11 +161,11 @@ if current_platform.is_rocm():
x
,
x
,
block_tables
.
size
(
1
),
block_tables
.
size
(
1
),
total_tokens
,
total_tokens
,
NUM_PRGMS
,
DEQUANT
=
dequant
,
DEQUANT
=
dequant
,
PAGE_SIZE
=
page_size
,
PAGE_SIZE
=
page_size
,
CACHE_FORMAT
=
kv_cache_layout
,
CACHE_FORMAT
=
kv_cache_layout
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
NUM_PRGMS
=
NUM_PRGMS
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment