Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e06b5899
Commit
e06b5899
authored
Aug 22, 2024
by
zhuwenwen
Browse files
Update refactoring operation of pa
parent
421310ba
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
24 deletions
+24
-24
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+12
-12
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+12
-12
No files found.
csrc/attention/attention_kernels.cu
View file @
e06b5899
...
...
@@ -87,7 +87,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
_opt
(
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -499,7 +499,7 @@ __device__ void paged_attention_kernel_opt(
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
_opt
(
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
...
...
@@ -516,7 +516,7 @@ __global__ void paged_attention_v1_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
...
...
@@ -531,7 +531,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
_opt
(
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -552,7 +552,7 @@ __global__ void paged_attention_v2_kernel_opt(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
...
...
@@ -564,7 +564,7 @@ __global__ void paged_attention_v2_kernel_opt(
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
_opt
(
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -672,11 +672,11 @@ __global__ void paged_attention_v2_reduce_kernel_opt(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \
vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
...
...
@@ -805,7 +805,7 @@ void paged_attention_v1_launcher(
break; \
}
void
paged_attention_v1
_opt
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
...
...
@@ -829,7 +829,7 @@ void paged_attention_v1_opt(
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
vllm::paged_attention_v2_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
PARTITION_SIZE> \
<<<grid, block, shared_mem_size, stream>>>( \
...
...
@@ -839,7 +839,7 @@ void paged_attention_v1_opt(
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel
_opt
<T, HEAD_SIZE, NUM_THREADS, \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
...
...
@@ -970,7 +970,7 @@ void paged_attention_v2_launcher(
break; \
}
void
paged_attention_v2
_opt
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
...
csrc/attention/attention_kernels_opt.cu
View file @
e06b5899
...
...
@@ -73,7 +73,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
bool
odd_nheads
=
false
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
__device__
void
paged_attention_kernel
_opt
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -593,7 +593,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
REUSE_KV_TIMES
=
1
,
bool
IS_BLOCK_SPARSE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v1_kernel
_opt
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
...
...
@@ -612,7 +612,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
...
...
@@ -629,7 +629,7 @@ template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int
REUSE_KV_TIMES
,
int
PARTITION_SIZE
,
bool
odd_nheads
=
false
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_kernel
_opt
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -652,7 +652,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
const
float
k_scale
,
const
float
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
paged_attention_kernel
_opt
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
REUSE_KV_TIMES
,
odd_nheads
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_heads
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
...
...
@@ -664,7 +664,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel
_opt
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -772,11 +772,11 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
((void*)vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
...
...
@@ -899,7 +899,7 @@ void paged_attention_v1_launcher(
break; \
}
void
paged_attention_v1
(
void
paged_attention_v1
_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
...
...
@@ -923,7 +923,7 @@ void paged_attention_v1(
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel
_opt
<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
...
...
@@ -933,7 +933,7 @@ void paged_attention_v1(
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel
_opt
<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
...
...
@@ -1046,7 +1046,7 @@ void paged_attention_v2_launcher(
break; \
}
void
paged_attention_v2
(
void
paged_attention_v2
_opt
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment