Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1c2aa04c
Commit
1c2aa04c
authored
Oct 17, 2024
by
zhuwenwen
Browse files
update attention_kernels_opt_tc.cu
parent
83bb8f5a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
8 deletions
+8
-8
csrc/attention/attention_kernels_opt_tc.cu
csrc/attention/attention_kernels_opt_tc.cu
+8
-8
No files found.
csrc/attention/attention_kernels_opt_tc.cu
View file @
1c2aa04c
...
...
@@ -613,7 +613,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel_opt
(
__global__
__launch_bounds__
(
256
,
1
)
void
paged_attention_v2_reduce_kernel_opt
_tc
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
...
...
@@ -793,7 +793,7 @@ void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize
// TODO(woosuk): Tune NUM_THREADS.
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
void
paged_attention_v1_launcher_opt
(
void
paged_attention_v1_launcher_opt
_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
max_seq_len
,
...
...
@@ -862,7 +862,7 @@ void paged_attention_v1_launcher_opt(
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v1_launcher_opt
_tc
<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
...
...
@@ -915,7 +915,7 @@ void paged_attention_v1(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v1_opt
(
void
paged_attention_v1_opt
_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
...
...
@@ -959,7 +959,7 @@ void paged_attention_v1_opt(
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
(vllm::paged_attention_v2_reduce_kernel_opt
_tc
<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
...
...
@@ -1000,7 +1000,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize
template
<
typename
T
,
typename
CACHE_T
,
int
BLOCK_SIZE
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
512
>
void
paged_attention_v2_launcher_opt
(
void
paged_attention_v2_launcher_opt
_tc
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
...
...
@@ -1069,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v2_launcher_opt
_tc
<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
...
...
@@ -1127,7 +1127,7 @@ void paged_attention_v2(
const
int64_t
blocksparse_vert_stride
,
const
int64_t
blocksparse_block_size
,
const
int64_t
blocksparse_head_sliding_step
);
void
paged_attention_v2_opt
(
void
paged_attention_v2_opt
_tc
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment