Commit 0c70376b authored by zhuwenwen's avatar zhuwenwen
Browse files

add pa tc

parent fe1ec8c5
This diff is collapsed.
...@@ -9,25 +9,17 @@ ...@@ -9,25 +9,17 @@
} \ } \
}() }()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \ #define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \ [&] { \
if (NUM_THREAD == 256) { \ if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \ constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else { \ }else if (NUM_THREAD == 128) { \
constexpr static int NUM_THREADS = 128; \ constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \ } \
}() }()
...@@ -45,12 +37,12 @@ ...@@ -45,12 +37,12 @@
} else if (HEADDIM == 112) { \ } else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \ constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \ } else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \ constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \ } else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \ constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
...@@ -74,10 +66,16 @@ ...@@ -74,10 +66,16 @@
} \ } \
}() }()
#define REUSEKV_SWITCH_V1(num_blocks , ...) \ #define REUSEKV_SWITCH_V2( ...) \
[&] { \ [&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \ if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 2; \ constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else { \ } else { \
constexpr static int REUSE_KV_TIMES = 1; \ constexpr static int REUSE_KV_TIMES = 1; \
...@@ -85,3 +83,31 @@ ...@@ -85,3 +83,31 @@
} \ } \
}() }()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
...@@ -14,4 +14,4 @@ torch == 2.3.0 ...@@ -14,4 +14,4 @@ torch == 2.3.0
triton == 2.1.0 triton == 2.1.0
flash_attn == 2.6.1 flash_attn == 2.6.1
xformers == 0.0.25 xformers == 0.0.25
lmslim == 0.1.0 lmslim == 0.1.1
\ No newline at end of file \ No newline at end of file
...@@ -124,8 +124,10 @@ class PagedAttention: ...@@ -124,8 +124,10 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192 # use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)) # and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = (max_seq_len < 8192
and (max_seq_len<1000 or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment