Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7c0942b
Unverified
Commit
b7c0942b
authored
Aug 09, 2025
by
Charlie Fu
Committed by
GitHub
Aug 08, 2025
Browse files
[ROCm][Misc] Rename the context_len to seq_len in ROCm custom paged attention kernel (#22097)
Signed-off-by: charlifu <charlifu@amd.com>
parent
9a0c5ded
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing 3 changed files with 91 additions and 96 deletions
+91
-96
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+87
-92
csrc/rocm/ops.h
csrc/rocm/ops.h
+2
-2
csrc/rocm/torch_bindings.cpp
csrc/rocm/torch_bindings.cpp
+2
-2
No files found.
csrc/rocm/attention.cu
View file @
b7c0942b
...
...
@@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
auto
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_start_token_idx
=
partition_idx
*
T_PAR_SIZE
;
// partition_size;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context
_len
)
{
if
(
partition_start_token_idx
>=
seq
_len
)
{
return
;
}
...
...
@@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
// across 4 rows x 4 tokens per lane
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
last_
ctx
_block
=
num_
context
_blocks
-
1
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
last_
seq
_block
=
num_
seq
_blocks
-
1
;
const
int
*
block_table_seq
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
...
...
@@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kblock_idx
=
(
kglobal_token_idx
<
context
_len
)
const
int
kblock_idx
=
(
kglobal_token_idx
<
seq
_len
)
?
kglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
kphysical_block_number
[
token_depth
]
=
block_table_seq
[
kblock_idx
];
}
...
...
@@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// tokens
const
int
vglobal_token_idx
=
partition_start_token_idx
+
vlocal_token_idx
;
const
int
vblock_idx
=
(
vglobal_token_idx
<
context
_len
)
const
int
vblock_idx
=
(
vglobal_token_idx
<
seq
_len
)
?
vglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]
=
block_table_seq
[
vblock_idx
];
}
...
...
@@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
if
constexpr
(
ALIBI_ENABLED
)
{
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
const
int
alibi_offset
=
local_token_idx
-
context
_len
+
1
;
const
int
alibi_offset
=
local_token_idx
-
seq
_len
+
1
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
d_out
[
token_depth
][
i
]
+=
alibi_slope
*
(
alibi_offset
+
i
);
}
...
...
@@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context_len
)
?
d_out
[
token_depth
][
i
]
:
-
FLT_MAX
;
const
float
tmp
=
(
local_token_idx
+
i
<
seq_len
)
?
d_out
[
token_depth
][
i
]
:
-
FLT_MAX
;
qk_max
=
fmaxf
(
qk_max
,
tmp
);
}
}
...
...
@@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context
_len
)
const
float
tmp
=
(
local_token_idx
+
i
<
seq
_len
)
?
__expf
(
d_out
[
token_depth
][
i
]
-
qk_max
)
:
0.0
f
;
d_out
[
token_depth
][
i
]
=
tmp
;
...
...
@@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
auto
partition_size
=
blockDim
.
x
;
const
auto
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
partition_start_token_idx
=
partition_idx
*
partition_size
;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context
_len
)
{
if
(
partition_start_token_idx
>=
seq
_len
)
{
return
;
}
// every 4 lanes fetch 4 different qheads
...
...
@@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
warp_start_token_idx
=
partition_start_token_idx
+
warpid
*
WARP_SIZE
;
if
(
warp_start_token_idx
>=
context
_len
)
{
// warp out of context
if
(
warp_start_token_idx
>=
seq
_len
)
{
// warp out of context
#pragma unroll
for
(
int
h
=
0
;
h
<
GQA_RATIO4
;
h
++
)
{
shared_qk_max
[
warpid
][
h
]
=
-
FLT_MAX
;
...
...
@@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
}
else
{
// warp within context
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
last_
ctx
_block
=
num_
context
_blocks
-
1
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
last_
seq
_block
=
num_
seq
_blocks
-
1
;
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// token id within partition
...
...
@@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
global_token_idx
=
partition_start_token_idx
+
local_token_idx
;
// fetch block number for k
const
int
block_idx
=
(
global_token_idx
<
context
_len
)
const
int
block_idx
=
(
global_token_idx
<
seq
_len
)
?
global_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
// fetch k physical block number
// int32 physical_block_number leads to overflow when multiplied with
...
...
@@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for
(
int
b
=
0
;
b
<
VBLOCKS
;
b
++
)
{
const
int
vblock_idx
=
warp_start_block_idx
+
b
;
const
int
vblock_idx_ctx
=
(
vblock_idx
<=
last_
ctx
_block
)
?
vblock_idx
:
last_
ctx
_block
;
(
vblock_idx
<=
last_
seq
_block
)
?
vblock_idx
:
last_
seq
_block
;
vphysical_blocks
[
b
]
=
block_table
[
vblock_idx_ctx
];
}
...
...
@@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
lane4_token_idx
=
4
*
(
global_token_idx
>>
2
);
if
constexpr
(
ALIBI_ENABLED
)
{
const
int
alibi_offset
=
lane4_token_idx
-
context
_len
+
1
;
const
int
alibi_offset
=
lane4_token_idx
-
seq
_len
+
1
;
for
(
int
h
=
0
;
h
<
QHLOOP
;
h
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
d_out
[
h
][
i
]
+=
alibi_slope
[
h
]
*
(
alibi_offset
+
i
);
...
...
@@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for
(
int
h
=
0
;
h
<
QHLOOP
;
h
++
)
{
qk_max
[
h
]
=
-
FLT_MAX
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
qk_max
[
h
]
=
(
lane4_token_idx
+
i
<
context
_len
)
qk_max
[
h
]
=
(
lane4_token_idx
+
i
<
seq
_len
)
?
fmaxf
(
qk_max
[
h
],
d_out
[
h
][
i
])
:
qk_max
[
h
];
}
...
...
@@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for
(
int
h
=
0
;
h
<
QHLOOP
;
h
++
)
{
exp_sum
[
h
]
=
0.0
f
;
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
d_out
[
h
][
i
]
=
(
lane4_token_idx
+
i
<
context
_len
)
d_out
[
h
][
i
]
=
(
lane4_token_idx
+
i
<
seq
_len
)
?
__expf
(
d_out
[
h
][
i
]
-
qk_max
[
h
])
:
0.0
f
;
exp_sum
[
h
]
+=
d_out
[
h
][
i
];
...
...
@@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
}
if
(
warp_start_token_idx
>=
context
_len
)
{
// warp out of context
if
(
warp_start_token_idx
>=
seq
_len
)
{
// warp out of context
for
(
int
qh
=
0
;
qh
<
QHLOOP
;
qh
++
)
{
for
(
int
vh
=
0
;
vh
<
VHELOOP
;
vh
++
)
{
vout_shared
[
qh
][
vh
][
laneid
][
warpid
]
=
{
0
};
...
...
@@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
const
auto
num_heads
=
gridDim
.
x
;
...
...
@@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return
;
}
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
const
auto
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
...
...
@@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
// length of a seq
const
int
seq_len
=
seq
_lens
[
seq_idx
];
// length of a seq
const
int
partition_start_token_idx
=
partition_idx
*
T_PAR_SIZE
;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context
_len
)
{
if
(
partition_start_token_idx
>=
seq
_len
)
{
return
;
}
...
...
@@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
last_
ctx
_block
=
num_
context
_blocks
-
1
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
last_
seq
_block
=
num_
seq
_blocks
-
1
;
const
int
*
block_table_seq
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
...
...
@@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kblock_idx
=
(
kglobal_token_idx
<
context
_len
)
const
int
kblock_idx
=
(
kglobal_token_idx
<
seq
_len
)
?
kglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
kphysical_block_number
[
token_depth
]
=
block_table_seq
[
kblock_idx
];
}
...
...
@@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
vblock_depth
*
BLOCK_SIZE
;
const
int
vglobal_token_idx
=
partition_start_token_idx
+
vlocal_token_idx
;
const
int
vblock_idx
=
(
vglobal_token_idx
<
context
_len
)
const
int
vblock_idx
=
(
vglobal_token_idx
<
seq
_len
)
?
vglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]
=
block_table_seq
[
vblock_idx
];
}
...
...
@@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
context_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
seq_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
qk_max
=
fmaxf
(
qk_max
,
tmp
);
}
}
...
...
@@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
context
_len
)
const
float
tmp
=
(
local_token_idx
+
2
*
i
<
seq
_len
)
?
__expf
(
dout
[
token_depth
][
i
]
-
qk_max
)
:
0.0
f
;
dout
[
token_depth
][
i
]
=
tmp
;
...
...
@@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
const
auto
num_heads
=
gridDim
.
x
;
...
...
@@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return
;
}
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
...
...
@@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context
_lens
[
seq_idx
];
// length of a seq
const
int
seq_len
=
seq
_lens
[
seq_idx
];
// length of a seq
const
int
partition_start_token_idx
=
partition_idx
*
T_PAR_SIZE
;
// exit if partition is out of context for seq
if
(
partition_start_token_idx
>=
context
_len
)
{
if
(
partition_start_token_idx
>=
seq
_len
)
{
return
;
}
...
...
@@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
const
int
num_
context
_blocks
=
DIVIDE_ROUND_UP
(
context
_len
,
BLOCK_SIZE
);
const
int
last_
ctx
_block
=
num_
context
_blocks
-
1
;
const
int
num_
seq
_blocks
=
DIVIDE_ROUND_UP
(
seq
_len
,
BLOCK_SIZE
);
const
int
last_
seq
_block
=
num_
seq
_blocks
-
1
;
const
int
*
block_table_seq
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
...
...
@@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
klocal_token_idx
=
TOKENS_PER_WARP
*
warpid
+
token_depth
*
16
+
lane16id
;
const
int
kglobal_token_idx
=
partition_start_token_idx
+
klocal_token_idx
;
const
int
kblock_idx
=
(
kglobal_token_idx
<
context
_len
)
const
int
kblock_idx
=
(
kglobal_token_idx
<
seq
_len
)
?
kglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
kphysical_block_number
[
token_depth
]
=
block_table_seq
[
kblock_idx
];
}
...
...
@@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
rowid
*
VTOKENS_PER_LANE
+
vblock_depth
*
BLOCK_SIZE
;
const
int
vglobal_token_idx
=
partition_start_token_idx
+
vlocal_token_idx
;
const
int
vblock_idx
=
(
vglobal_token_idx
<
context
_len
)
const
int
vblock_idx
=
(
vglobal_token_idx
<
seq
_len
)
?
vglobal_token_idx
/
BLOCK_SIZE
:
last_
ctx
_block
;
:
last_
seq
_block
;
vphysical_block_number
[
vtoken_depth
][
vblock_depth
]
=
block_table_seq
[
vblock_idx
];
}
...
...
@@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context
_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
(
local_token_idx
+
i
<
seq
_len
)
?
dout
[
token_depth
][
i
]
:
-
FLT_MAX
;
qk_max
=
fmaxf
(
qk_max
,
tmp
);
}
}
...
...
@@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for
(
int
token_depth
=
0
;
token_depth
<
TLOOP
;
token_depth
++
)
{
const
int
local_token_idx
=
qkout_token_idx
+
token_depth
*
16
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
const
float
tmp
=
(
local_token_idx
+
i
<
context
_len
)
const
float
tmp
=
(
local_token_idx
+
i
<
seq
_len
)
?
__expf
(
dout
[
token_depth
][
i
]
-
qk_max
)
:
0.0
f
;
dout
[
token_depth
][
i
]
=
tmp
;
...
...
@@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
const
auto
num_heads
=
gridDim
.
x
;
...
...
@@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return
;
}
const
int
context_len
=
context
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context
_len
,
PARTITION_SIZE
);
const
int
seq_len
=
seq
_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq
_len
,
PARTITION_SIZE
);
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
...
...
@@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
num_kv_heads
,
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
...
...
@@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context
_lens
,
// [num_seqs]
const
int
*
__restrict__
seq
_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
,
const
float
*
__restrict__
fp8_out_scale_ptr
)
{
UNREACHABLE_CODE
...
...
@@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr,
context
_lens_ptr, query_start_loc_ptr, \
block_tables_ptr,
seq
_lens_ptr, query_start_loc_ptr,
\
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
...
...
@@ -3057,18 +3055,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr,
context
_lens_ptr, query_start_loc_ptr, \
block_tables_ptr,
seq
_lens_ptr, query_start_loc_ptr,
\
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
typename
OUTT
,
int
PARTITION_SIZE_OLD
,
...
...
@@ -3077,8 +3074,8 @@ void paged_attention_custom_launcher(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_
context
_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_
seq
_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
fp8_out_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
...
...
@@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher(
KVT
*
key_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
key_cache
.
data_ptr
());
KVT
*
value_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
// NOTE: fp8_out_scale is optional.
...
...
@@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher(
:
nullptr
;
OUTT
*
out_ptr
=
reinterpret_cast
<
OUTT
*>
(
out
.
data_ptr
());
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
BLOCK_SIZE
);
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
BLOCK_SIZE
);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr
int
PARTITION_SIZE
=
256
;
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_context_len
,
PARTITION_SIZE
);
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
const
int
gqa_ratio
=
num_heads
/
num_kv_heads
;
assert
(
num_heads
%
num_kv_heads
==
0
);
assert
(
head_size
==
HEAD_SIZE
);
...
...
@@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_
context
_len
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_
seq
_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
...
...
@@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi(
KVT
*
key_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
key_cache
.
data_ptr
());
KVT
*
value_cache_ptr
=
reinterpret_cast
<
KVT
*>
(
value_cache
.
data_ptr
());
int
*
block_tables_ptr
=
block_tables
.
data_ptr
<
int
>
();
int
*
context
_lens_ptr
=
context
_lens
.
data_ptr
<
int
>
();
int
*
seq
_lens_ptr
=
seq
_lens
.
data_ptr
<
int
>
();
const
float
*
k_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
k_scale
.
data_ptr
());
const
float
*
v_scale_ptr
=
reinterpret_cast
<
const
float
*>
(
v_scale
.
data_ptr
());
...
...
@@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi(
const
auto
fp8_out_scale_ptr
=
nullptr
;
OUTT
*
out_ptr
=
reinterpret_cast
<
OUTT
*>
(
out
.
data_ptr
());
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_
context
_len
,
BLOCK_SIZE
);
const
int
max_ctx_blocks
=
DIVIDE_ROUND_UP
(
max_
seq
_len
,
BLOCK_SIZE
);
constexpr
int
PARTITION_SIZE
=
256
;
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_context_len
,
PARTITION_SIZE
);
const
int
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
const
int
gqa_ratio
=
num_heads
/
num_kv_heads
;
assert
(
num_heads
%
num_kv_heads
==
0
);
assert
(
head_size
==
HEAD_SIZE
);
...
...
@@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables,
context
_lens, query_start_loc, \
max_
context
_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
num_kv_heads, scale, block_tables,
seq
_lens, query_start_loc,
\
max_
seq
_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
\
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables,
context
_lens, query_start_loc, \
max_
context
_len, alibi_slopes, k_scale, v_scale); \
num_kv_heads, scale, block_tables,
seq
_lens, query_start_loc,
\
max_
seq
_len, alibi_slopes, k_scale, v_scale);
\
}
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
...
...
@@ -3502,9 +3497,9 @@ void paged_attention(
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context
_lens
,
// [num_seqs]
torch
::
Tensor
&
seq
_lens
,
// [num_seqs]
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_
context
_len
,
int64_t
block_size
,
int64_t
max_
seq
_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
...
...
csrc/rocm/ops.h
View file @
b7c0942b
...
...
@@ -15,8 +15,8 @@ void paged_attention(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context
_lens
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq
_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int64_t
block_size
,
int64_t
max_
context
_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int64_t
max_
seq
_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
,
const
std
::
optional
<
torch
::
Tensor
>&
fp8_out_scale
);
csrc/rocm/torch_bindings.cpp
View file @
b7c0942b
...
...
@@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor
context
_lens,"
" Tensor
seq
_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_
context
_len,"
" int max_
seq
_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment