Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
903593d3
Commit
903593d3
authored
Oct 22, 2024
by
zhuwenwen
Browse files
解决pa v1 tc 部分 size bug
parent
2009d4a1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
58 deletions
+48
-58
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+48
-58
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
903593d3
...
@@ -302,6 +302,7 @@ __device__ void paged_attention_kernel_TC(
...
@@ -302,6 +302,7 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
}
}
// if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]);
// Perform reduction across the threads in the same warp to get the
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
// The 0-th thread of each thread group already has its max qk value.
...
@@ -401,33 +402,30 @@ __device__ void paged_attention_kernel_TC(
...
@@ -401,33 +402,30 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
__syncthreads
();
__syncthreads
();
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
// Perform reduction across warps.
// Perform reduction across warps.
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
q_boundary
;
reuse_kv_idx
++
)
{
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
if
constexpr
(
NUM_THREADS
>
64
){
#pragma unroll
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared_mem
);
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
#pragma unroll
int
mid
=
i
/
2
;
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
// Upper warps write to shared memory.
int
mid
=
i
/
2
;
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
// Upper warps write to shared memory.
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
#pragma unroll
out_smem
[(
warp_idx
-
mid
)
*
64
+
lane
]
=*
(
floatV_t
*
)(
accs
[
reuse_kv_idx
]);
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
][
i
];
}
}
}
__syncthreads
();
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
// Lower warps update the output.
floatV_t
tmp
=
out_smem
[
thread_idx
];
if
(
warp_idx
<
mid
)
{
#pragma unroll
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
#pragma unroll
accs
[
reuse_kv_idx
][
i
]
+=
tmp
[
i
];
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
}
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
}
}
__syncthreads
();
}
}
__syncthreads
();
}
}
// Write the final output.
// Write the final output.
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
...
@@ -487,37 +485,35 @@ __device__ void paged_attention_kernel_TC(
...
@@ -487,37 +485,35 @@ __device__ void paged_attention_kernel_TC(
}
}
}
}
}
}
}
}
__syncthreads
();
if
constexpr
(
NUM_THREADS
>
64
){
// Perform reduction across warps.
__syncthreads
();
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
GROUPS
;
reuse_kv_idx
++
)
{
using
floatV_t
=
__attribute__
(
(
__vector_size__
(
NUM_ROWS_PER_THREAD
*
sizeof
(
float
))
))
float
;
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Perform reduction across warps.
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
for
(
int
reuse_kv_idx
=
0
;
reuse_kv_idx
<
GROUPS
;
reuse_kv_idx
++
)
{
int
mid
=
i
/
2
;
// Upper warps write to
shared
mem
ory.
floatV_t
*
out_smem
=
reinterpret_cast
<
floatV_t
*>
(
shared
_
mem
);
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
#pragma unroll
float
*
dst
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
(
warp_idx
-
mid
)
*
HEAD_SIZE
];
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
#pragma unroll
int
mid
=
i
/
2
;
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
// Upper warps write to shared memory.
const
int
row
_idx
=
lane
+
i
*
WARP_SIZE
;
if
(
warp
_idx
>
=
mid
&&
warp_idx
<
i
)
{
dst
[
row_idx
]
=
accs
[
reuse_kv_idx
]
[
i
]
;
out_smem
[(
warp_idx
-
mid
)
*
64
+
lane
]
=*
(
floatV_t
*
)(
accs
[
reuse_kv_idx
]
)
;
}
}
}
__syncthreads
();
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[(
reuse_kv_idx
*
(
NUM_WARPS
/
2
)
*
HEAD_SIZE
)
+
warp_idx
*
HEAD_SIZE
];
float
V_t
tmp
=
out_smem
[
thread_idx
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
+
i
*
WARP_SIZE
;
accs
[
reuse_kv_idx
][
i
]
+=
tmp
[
i
]
;
accs
[
reuse_kv_idx
][
i
]
+=
src
[
row_idx
];
}
}
}
__syncthreads
();
}
}
__syncthreads
();
}
}
// Write the final output.
}
}
if
(
warp_idx
==
0
)
{
if
(
warp_idx
==
0
)
{
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
for
(
int
g
=
0
;
g
<
reuse_group
;
g
++
){
int
reusekvid
=
g
*
4
+
rows
;
int
reusekvid
=
g
*
4
+
rows
;
...
@@ -842,16 +838,10 @@ void paged_attention_v1_launcher_opt_tc(
...
@@ -842,16 +838,10 @@ void paged_attention_v1_launcher_opt_tc(
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
padded_max_seq_len
*
2
;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
if
(
REUSE_KV_TIMES
==
1
)
outputs_size
=
0
;
if
(
NUM_WARPS
==
64
)
outputs_size
=
0
;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
int
shared_mem_size
=
::
max
(
logits_size
,
outputs_size
);
if
(
num_heads
==
num_kv_heads
)
shared_mem_size
=
::
max
(
12
*
1024
,
shared_mem_size
);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
grid
((
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
,
1
,
num_seqs
);
dim3
block
(
NUM_THREADS
);
dim3
block
(
NUM_THREADS
);
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
if
(
PA_PRINT_PARAM
)
printf
(
"reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d
\n
"
,
...
@@ -1054,7 +1044,7 @@ void paged_attention_v2_launcher_opt_tc(
...
@@ -1054,7 +1044,7 @@ void paged_attention_v2_launcher_opt_tc(
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
NUM_THREADS_SWITCH
(
num_thread
,
[
&
]
{
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
logits_size
=
REUSE_KV_TIMES
*
PARTITION_SIZE
*
2
;
int
outputs_size
=
REUSE_KV_TIMES
*
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
int
outputs_size
=
(
NUM_WARPS
/
2
)
*
head_size
*
sizeof
(
float
);
dim3
grid
;
dim3
grid
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
x
=
(
num_heads
/
num_kv_heads
+
REUSE_KV_TIMES
-
1
)
/
REUSE_KV_TIMES
*
num_kv_heads
;
grid
.
y
=
max_num_partitions
;
grid
.
y
=
max_num_partitions
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment