Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87bdb89f
Commit
87bdb89f
authored
Mar 15, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'origin/0.7.2-zhangshao' into v0.7.2-pa
parents
e6fd8fda
4f8d38c8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
19 deletions
+28
-19
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+16
-12
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+12
-7
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
87bdb89f
...
@@ -124,11 +124,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
...
@@ -124,11 +124,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{
{
if
constexpr
(
is_half
){
if
constexpr
(
is_half
){
asm
volatile
(
"v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
else
{
else
{
asm
volatile
(
"v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
}
}
...
@@ -147,7 +147,7 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t&
...
@@ -147,7 +147,7 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t&
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
>
// Zero means no partitioning.
bool
IS_BLOCK_SPARSE
,
int
REUSE_KV_TIMES
>
// Zero means no partitioning.
__global__
void
paged_attention_kernel_TC
_with_mask
(
__global__
void
paged_attention_kernel_TC
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads,head_size]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads,head_size]
...
@@ -175,7 +175,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -175,7 +175,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
...
@@ -225,10 +225,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -225,10 +225,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
if
(
thread_idx
<
16
){
if
(
thread_idx
<
16
){
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
#pragma unroll
if
constexpr
(
is_half
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
scalar_t
*
t
=
reinterpret_cast
<
scalar_t
*>
(
&
temp
);
temp
.
data
[
0
][
k
]
=
((
float
)
temp
.
data
[
0
][
k
])
*
scale
;
#pragma unroll
temp
.
data
[
1
][
k
]
=
((
float
)
temp
.
data
[
1
][
k
])
*
scale
;
for
(
int
k
=
0
;
k
<
8
;
k
++
){
from_float
(
t
[
k
],
to_float
(
t
[
k
])
*
scale
);
}
}
}
q_vecs
[
i
][
thread_idx
]
=
temp
;
q_vecs
[
i
][
thread_idx
]
=
temp
;
}
}
...
@@ -265,12 +267,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -265,12 +267,14 @@ __global__ void paged_attention_kernel_TC_with_mask(
int
reuse_kv_idx
=
rows
+
i
*
4
;
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
else
{
if
constexpr
(
!
is_half
)
qk_vec
[
i
]
*=
scale
;
}
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
if
(
alibi_slope
[
i
]
!=
0
){
if
(
alibi_slope
[
i
]
!=
0
){
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
qk_vec
[
i
]
+=
alibi
;
qk_vec
[
i
]
+=
alibi
;
}
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
if
(
mask
){
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
...
@@ -764,16 +768,16 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
...
@@ -764,16 +768,16 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
(vllm::paged_attention_kernel_TC
_with_mask<
\
(vllm::paged_attention_kernel_TC
<
\
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks,
\
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step,PARTITION_SIZE);\
blocksparse_head_sliding_step,PARTITION_SIZE);
\
if (max_num_partitions<=64&&max_num_partitions>1){ \
if (max_num_partitions<=64&&max_num_partitions>1){ \
hipLaunchKernelGGL( \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
...
...
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
87bdb89f
...
@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
...
@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{
{
if
constexpr
(
is_half
){
if
constexpr
(
is_half
){
asm
volatile
(
"v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_f16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
else
{
else
{
asm
volatile
(
"v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
asm
volatile
(
"
\n
s_nop 1
\n
v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0"
:
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
"=v"
(
reg_c
)
:
"v"
(
reg_a
),
"v"
(
reg_b
),
"0"
(
reg_c
));
}
}
}
}
...
@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
seq_len
=
__builtin_amdgcn_readfirstlane
(
seq_lens
[
seq_idx
]);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
const
bool
USE_PARTITIONING
=
PARTITION_SIZE
<
num_seq_blocks
*
BLOCK_SIZE
&&
PARTITION_SIZE
>
0
;
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
...
@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
for
(
int
i
=
0
;
i
<
q_boundary
;
i
++
){
if
(
thread_idx
<
16
){
if
(
thread_idx
<
16
){
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
half4x2
temp
=
*
reinterpret_cast
<
const
half4x2
*>
(
q_ptr
+
i
*
HEAD_SIZE
+
thread_idx
*
8
);
#pragma unroll
if
constexpr
(
is_half
){
for
(
int
k
=
0
;
k
<
4
;
k
++
){
scalar_t
*
t
=
reinterpret_cast
<
scalar_t
*>
(
&
temp
);
temp
.
data
[
0
][
k
]
=
((
float
)
temp
.
data
[
0
][
k
])
*
scale
;
#pragma unroll
temp
.
data
[
1
][
k
]
=
((
float
)
temp
.
data
[
1
][
k
])
*
scale
;
for
(
int
k
=
0
;
k
<
8
;
k
++
){
from_float
(
t
[
k
],
to_float
(
t
[
k
])
*
scale
);
}
}
}
q_vecs
[
i
][
thread_idx
]
=
temp
;
q_vecs
[
i
][
thread_idx
]
=
temp
;
}
}
...
@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int
reuse_kv_idx
=
rows
+
i
*
4
;
int
reuse_kv_idx
=
rows
+
i
*
4
;
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
<
REUSE_KV_TIMES
){
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
if
(
reuse_kv_idx
>=
q_boundary
)
qk_vec
[
i
]
=
0
;
else
{
if
constexpr
(
!
is_half
)
qk_vec
[
i
]
*=
scale
;
}
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rowid
;
if
(
alibi_slope
[
i
]
!=
0
){
if
(
alibi_slope
[
i
]
!=
0
){
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
float
alibi
=
alibi_slope
[
i
]
*
(
token_idx
-
seq_len
+
1
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment