Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ea79ca42
Commit
ea79ca42
authored
Jun 04, 2025
by
zhangshao
Browse files
解决cudagraph模式下,小seq大batch PA变慢的bug
parent
82e8ca03
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
98 additions
and
47 deletions
+98
-47
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+50
-24
csrc/attention/attention_with_mask_kernels_opt_tc.cu
csrc/attention/attention_with_mask_kernels_opt_tc.cu
+48
-23
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
ea79ca42
...
@@ -19,6 +19,23 @@ typedef __hip_bfloat16 __nv_bfloat16;
...
@@ -19,6 +19,23 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Compile-time selection of an element type from a boolean flag:
//   AccType<true>::type  is uint16_t
//   AccType<false>::type is float
// Elsewhere in this file the result is used as ACC_TYPE, the element type of
// the `logits` shared-memory pointer and the operand of sizeof() when the
// launcher sizes that buffer.
template <bool kSixteenBit>
struct AccType {
  using type = std::conditional_t<kSixteenBit, uint16_t, float>;
};

// Shorthand for AccType<is_half>::type.
template <bool is_half>
using __acc_type = typename AccType<is_half>::type;
std
::
string
get_device_name
()
std
::
string
get_device_name
()
{
{
hipDeviceProp_t
props
{};
hipDeviceProp_t
props
{};
...
@@ -230,6 +247,7 @@ __global__ void paged_attention_kernel_TC(
...
@@ -230,6 +247,7 @@ __global__ void paged_attention_kernel_TC(
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
using
ACC_TYPE
=
__acc_type
<
is_half
>
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
...
@@ -292,7 +310,7 @@ __global__ void paged_attention_kernel_TC(
...
@@ -292,7 +310,7 @@ __global__ void paged_attention_kernel_TC(
}
}
__syncthreads
();
__syncthreads
();
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
ACC_TYPE
*
logits
=
reinterpret_cast
<
ACC_TYPE
*>
(
shared_mem
);
// __shared__ float red_smem[2 * NUM_WARPS];
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
...
@@ -350,11 +368,9 @@ __global__ void paged_attention_kernel_TC(
...
@@ -350,11 +368,9 @@ __global__ void paged_attention_kernel_TC(
qk_vec
[
i
]
+=
alibi
;
qk_vec
[
i
]
+=
alibi
;
}
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
if
(
mask
)
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
0.
f
;
}
else
{
else
{
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
qk_vec
[
i
];
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
qk_vec
[
i
]
)
;
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
}
}
}
}
...
@@ -387,15 +403,15 @@ __global__ void paged_attention_kernel_TC(
...
@@ -387,15 +403,15 @@ __global__ void paged_attention_kernel_TC(
}
}
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max_tmp
);
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
-
qk_max_tmp
);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
)
;
exp_sum
+=
val
;
exp_sum
+=
val
;
}
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
// Compute softmax.
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
*
inv_sum
;
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
*
inv_sum
)
;
}
}
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
...
@@ -423,10 +439,13 @@ __global__ void paged_attention_kernel_TC(
...
@@ -423,10 +439,13 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
4
*
q_boundary
){
if
(
rowid
<
4
*
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -526,10 +545,13 @@ __global__ void paged_attention_kernel_TC(
...
@@ -526,10 +545,13 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
if
(
rowid
<
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -638,10 +660,13 @@ __global__ void paged_attention_kernel_TC(
...
@@ -638,10 +660,13 @@ __global__ void paged_attention_kernel_TC(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
if
(
rowid
<
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -904,7 +929,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
...
@@ -904,7 +929,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
{
{
reusekv
=
1
;
reusekv
=
1
;
num_thread
=
256
;
num_thread
=
256
;
PARTITION_SIZE
=
512
;
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
max_num_partitions
=
DIVIDE_ROUND_UP
(
max_seq_len
,
PARTITION_SIZE
);
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
max_seq_len
==
8192
&&
num_blocks
==
1024
){
//ali test
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
if
(
batchsize
==
1
&&
qheads
==
16
&&
kvheads
==
16
){
num_thread
=
128
;
return
;}
...
@@ -1037,10 +1061,12 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1037,10 +1061,12 @@ void paged_attention_v2_launcher_opt_tc(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
constexpr
bool
is_half
=
std
::
is_same
<
T
,
uint16_t
>::
value
;
using
ACC_TYPE
=
__acc_type
<
is_half
>
;
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
=
512
;
if
(
!
is_half
)
PARTITION_SIZE
=
256
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
if
(
PA_PARTITION_SIZE
!=
0
){
if
(
PA_PARTITION_SIZE
!=
0
){
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
...
@@ -1055,7 +1081,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1055,7 +1081,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
4
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
ACC_TYPE
)
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
dim3
grid
;
...
...
csrc/attention/attention_with_mask_kernels_opt_tc.cu
View file @
ea79ca42
...
@@ -19,6 +19,21 @@ typedef __hip_bfloat16 __nv_bfloat16;
...
@@ -19,6 +19,21 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
// Compile-time selection of an element type from a boolean flag:
//   AccType<true>::type  is uint16_t
//   AccType<false>::type is float
// Elsewhere in this file the result is used as ACC_TYPE, the element type of
// the `logits` shared-memory pointer and the operand of sizeof() when the
// launcher sizes that buffer.
template <bool kSixteenBit>
struct AccType {
  using type = std::conditional_t<kSixteenBit, uint16_t, float>;
};

// Shorthand for AccType<is_half>::type.
template <bool is_half>
using __acc_type = typename AccType<is_half>::type;
std
::
string
get_device_name
();
std
::
string
get_device_name
();
static
const
std
::
string
device_name
=
get_device_name
();
static
const
std
::
string
device_name
=
get_device_name
();
...
@@ -214,6 +229,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -214,6 +229,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
if
(
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
return
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
scalar_t
,
uint16_t
>::
value
;
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
constexpr
bool
is_fp8
=
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kFp8E4M3
);
using
ACC_TYPE
=
__acc_type
<
is_half
>
;
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
static_assert
(
HEAD_SIZE
<=
4
*
NUM_THREADS
,
"HEAD_SIZE<=4*NUM_THREADS"
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
const
int
partition_size
=
USE_PARTITIONING
?
PARTITION_SIZE
:
num_seq_blocks
*
BLOCK_SIZE
;
...
@@ -276,7 +292,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -276,7 +292,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
__syncthreads
();
__syncthreads
();
extern
__shared__
char
shared_mem
[];
extern
__shared__
char
shared_mem
[];
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
ACC_TYPE
*
logits
=
reinterpret_cast
<
ACC_TYPE
*>
(
shared_mem
);
// __shared__ float red_smem[2 * NUM_WARPS];
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_max
[
REUSE_KV_TIMES
][
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
__shared__
float
s_logit
[
NUM_WARPS
];
...
@@ -341,11 +357,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -341,11 +357,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
}
}
const
bool
mask
=
(
token_idx
>=
seq_len
);
const
bool
mask
=
(
token_idx
>=
seq_len
);
if
(
mask
){
if
(
mask
)
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
0.
f
);
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
0.
f
;
}
else
{
else
{
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
=
qk_vec
[
i
];
from_float
(
logits
[
partition_size
*
reuse_kv_idx
+
token_idx
-
start_token_idx
]
,
qk_vec
[
i
]
)
;
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
qk_max
[
i
]
=
fmaxf
(
qk_max
[
i
],
qk_vec
[
i
]);
}
}
}
}
...
@@ -378,15 +392,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -378,15 +392,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
}
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
qk_max_tmp
=
__shfl
(
qk_max_tmp
,
0
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
-
qk_max_tmp
);
float
val
=
__expf
(
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
-
qk_max_tmp
);
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
val
;
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
val
)
;
exp_sum
+=
val
;
exp_sum
+=
val
;
}
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
exp_sum
=
block_sum
<
NUM_WARPS
>
(
s_logit
,
exp_sum
);
// Compute softmax.
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
=
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
*
inv_sum
;
from_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
,
to_float
(
logits
[(
reuse_kv_idx
*
partition_size
)
+
i
]
)
*
inv_sum
)
;
}
}
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
if
(
USE_PARTITIONING
&&
thread_idx
==
0
){
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
max_out
[
reuse_kv_idx
]
=
qk_max_tmp
;
...
@@ -414,10 +428,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -414,10 +428,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
4
*
q_boundary
){
if
(
rowid
<
4
*
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
/
4
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -517,10 +534,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -517,10 +534,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
if
(
rowid
<
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -629,10 +649,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
...
@@ -629,10 +649,13 @@ __global__ void paged_attention_kernel_TC_with_mask(
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
rows
*
4
;
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
half4_t
logits_vec
=
{
0
,
0
,
0
,
0
};
if
(
rowid
<
q_boundary
){
if
(
rowid
<
q_boundary
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
if
constexpr
(
is_half
)
logits_vec
=*
reinterpret_cast
<
half4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
else
{
for
(
int
i
=
0
;
i
<
4
;
i
++
){
auto
f_logits
=
*
reinterpret_cast
<
float4_t
*>
(
logits
+
rowid
*
partition_size
+
token_idx
-
start_token_idx
);
from_float
(
p
[
i
],
f_logits
[
i
]);
scalar_t
*
p
=
reinterpret_cast
<
scalar_t
*>
(
&
logits_vec
);
for
(
int
i
=
0
;
i
<
4
;
i
++
){
from_float
(
p
[
i
],
f_logits
[
i
]);
}
}
}
}
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
...
@@ -943,10 +966,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
...
@@ -943,10 +966,12 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
dim3
reduce_grid
(
num_heads
,
num_seqs
);
dim3
reduce_grid
(
num_heads
,
num_seqs
);
constexpr
bool
is_half
=
std
::
is_same
<
T
,
uint16_t
>::
value
;
using
ACC_TYPE
=
__acc_type
<
is_half
>
;
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
if
constexpr
(
BLOCK_SIZE
==
16
&&
IS_BLOCK_SPARSE
==
false
&&
sizeof
(
T
)
==
2
){
constexpr
int
HEAD_SIZE
=
128
;
constexpr
int
HEAD_SIZE
=
128
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
;
int
reusekv
,
num_thread
,
max_num_partitions
,
PARTITION_SIZE
=
512
;
if
(
!
is_half
)
PARTITION_SIZE
=
256
;
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
get_numberthread_and_reuse_kv_v2
(
num_thread
,
reusekv
,
PARTITION_SIZE
,
max_num_partitions
,
num_seqs
,
max_seq_len
,
num_heads
,
num_kv_heads
,
num_blocks
);
if
(
PA_PARTITION_SIZE
!=
0
){
if
(
PA_PARTITION_SIZE
!=
0
){
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
PARTITION_SIZE
=
PA_PARTITION_SIZE
;
...
@@ -961,7 +986,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
...
@@ -961,7 +986,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
4
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
sizeof
(
ACC_TYPE
)
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
if
(
max_num_partitions
==
1
)
PARTITION_SIZE
=
0
;
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
dim3
grid
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment