Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
675bceed
Commit
675bceed
authored
Aug 13, 2024
by
zhuwenwen
Browse files
restore support for block_size 8 and 32
parent
d231153f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
22 deletions
+12
-22
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+12
-22
No files found.
csrc/attention/attention_kernels.cu
View file @
675bceed
...
@@ -885,25 +885,20 @@ void paged_attention_v1_launcher(
...
@@ -885,25 +885,20 @@ void paged_attention_v1_launcher(
// 1, 2, 4, 64, 128, 256.
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
break; \
}
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void
paged_attention_v1
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
query
,
// [num_seqs, num_heads, head_size]
...
@@ -1037,25 +1032,20 @@ void paged_attention_v2_launcher(
...
@@ -1037,25 +1032,20 @@ void paged_attention_v2_launcher(
// 1, 2, 4, 64, 128, 256.
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
break; \
}
}
// // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// // 1, 2, 4, 64, 128, 256.
// #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
// switch (block_size) { \
// case 16: \
// CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
// break; \
// TORCH_CHECK(false, "Unsupported block size: ", block_size); \
// break; \
// }
void
paged_attention_v2
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
out
,
// [num_seqs, num_heads, head_size]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
torch
::
Tensor
&
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment