Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0c70376b
"vscode:/vscode.git/clone" did not exist on "c0f5fae601cf2649dec3cb06ad80008ced7a46ea"
Commit
0c70376b
authored
Sep 20, 2024
by
zhuwenwen
Browse files
add pa tc
parent
fe1ec8c5
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
614 additions
and
583 deletions
+614
-583
csrc/attention/attention_kernels_opt.cu
csrc/attention/attention_kernels_opt.cu
+565
-562
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+44
-18
requirements-rocm.txt
requirements-rocm.txt
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+4
-2
No files found.
csrc/attention/attention_kernels_opt.cu
View file @
0c70376b
This diff is collapsed.
Click to expand it.
csrc/attention/static_switch.h
View file @
0c70376b
...
...
@@ -9,25 +9,17 @@
} \
}()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
}
else
{
\
}else
if (NUM_THREAD == 128) {
\
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
...
...
@@ -45,12 +37,12 @@
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
...
...
@@ -74,14 +66,48 @@
} \
}()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}
else { \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
requirements-rocm.txt
View file @
0c70376b
...
...
@@ -14,4 +14,4 @@ torch == 2.3.0
triton == 2.1.0
flash_attn == 2.6.1
xformers == 0.0.25
lmslim == 0.1.0
\ No newline at end of file
lmslim == 0.1.1
\ No newline at end of file
vllm/attention/ops/paged_attn.py
View file @
0c70376b
...
...
@@ -124,8 +124,10 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
(
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
))
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1
=
(
max_seq_len
<
8192
and
(
max_seq_len
<
1000
or
num_seqs
*
num_heads
>
(
1024
if
num_kv_heads
<
num_heads
else
512
)))
if
use_v1
:
# Run PagedAttention V1.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment