"vscode:/vscode.git/clone" did not exist on "c0f5fae601cf2649dec3cb06ad80008ced7a46ea"
Commit 0c70376b authored by zhuwenwen
Browse files

add pa tc

parent fe1ec8c5
This diff is collapsed.
......@@ -9,25 +9,17 @@
} \
}()
// OPT_SWITCH(COND, callable)
//
// Lifts the runtime condition COND into a compile-time constant named `opt`
// that the trailing callable can read: `opt` is 1 when COND is true and 2
// otherwise. The macro expands to an immediately-invoked `[&]` lambda and
// evaluates to whatever the callable returns.
#define OPT_SWITCH(COND, ...)           \
  [&] {                                 \
    if (!(COND)) {                      \
      constexpr static int opt = 2;     \
      return __VA_ARGS__();             \
    }                                   \
    constexpr static int opt = 1;       \
    return __VA_ARGS__();               \
  }()
// NUM_THREADS_SWITCH(NUM_THREAD, callable)
//
// Lifts the runtime thread count NUM_THREAD into a compile-time constant
// named `NUM_THREADS` that the trailing callable can read:
//   NUM_THREAD == 256 -> NUM_THREADS = 256
//   NUM_THREAD == 128 -> NUM_THREADS = 128
//   anything else     -> NUM_THREADS = 64
// The macro expands to an immediately-invoked `[&]` lambda and evaluates to
// whatever the callable returns.
//
// The argument is parenthesized so that an expression such as `a ? b : c`
// compares as a whole. A stray `} else {` that preceded the 128 branch has
// been dropped; it left the if/else chain unbalanced.
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if ((NUM_THREAD) == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
} else if ((NUM_THREAD) == 128) { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
......@@ -45,12 +37,12 @@
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
......@@ -74,14 +66,48 @@
} \
}()
// REUSEKV_SWITCH_V2(callable)
//
// Picks a compile-time constant `REUSE_KV_TIMES` from the ratio
// `num_heads / num_kv_heads`; both names are read from the scope in which
// the macro is expanded. The comparisons are strict:
//   ratio > 8 -> 16
//   ratio > 4 -> 8
//   ratio > 2 -> 4
//   otherwise -> 1   (a ratio of exactly 2 therefore selects 1)
// The macro expands to an immediately-invoked `[&]` lambda and evaluates to
// whatever the callable returns.
//
// NOTE(review): num_kv_heads is used as a divisor without a zero check --
// confirm that callers guarantee it is non-zero.
//
// A leftover `#define REUSEKV_SWITCH_V1(num_blocks , ...) \` line that sat
// directly above this definition has been dropped: its line continuation
// swallowed this `#define`, so REUSEKV_SWITCH_V2 was never defined.
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
// REUSEKV_SWITCH_V1(callable)
//
// Picks a compile-time constant `REUSE_KV_TIMES` for the callable. The names
// num_heads, num_kv_heads, num_blocks and padded_max_seq_len are read from
// the scope in which the macro is expanded.
//
// Only when `num_heads > num_kv_heads && num_blocks >= 1200` is a value other
// than 1 considered, using the ratio `num_heads / num_kv_heads`:
//   ratio > 4  and padded_max_seq_len < 3900  -> 8
//   ratio > 2  and padded_max_seq_len < 7800  -> 4
//   ratio == 2 and padded_max_seq_len < 15600 -> 2
// Every other case selects 1.
//
// The macro expands to an immediately-invoked `[&]` lambda and evaluates to
// whatever the callable returns.
//
// NOTE(review): the previous text had two consecutive `else` lines and never
// closed the outer `if`, so the lambda body was unbalanced. This version
// assumes the intent is "fall back to REUSE_KV_TIMES = 1 whenever no
// specialised case matches" -- confirm against the kernel launch code.
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads > num_kv_heads && num_blocks >= 1200){ \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} \
} \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
}()
// USEVMAC_SWITCH_V1(num_blocks, callable)
//
// Exposes a compile-time constant `use_vmac` to the trailing callable.
// `use_vmac` is false only when REUSE_KV_TIMES equals 1 and either
// num_blocks exceeds 2500 or padded_max_seq_len exceeds 2048; in every other
// case it is true. REUSE_KV_TIMES and padded_max_seq_len are read from the
// scope in which the macro is expanded. The macro expands to an
// immediately-invoked `[&]` lambda and evaluates to whatever the callable
// returns.
#define USEVMAC_SWITCH_V1(num_blocks , ...)                                          \
  [&] {                                                                              \
    if (!(REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048))) {     \
      constexpr static int use_vmac = true;                                          \
      return __VA_ARGS__();                                                          \
    }                                                                                \
    constexpr static int use_vmac = false;                                           \
    return __VA_ARGS__();                                                            \
  }()
\ No newline at end of file
......@@ -14,4 +14,4 @@ torch == 2.3.0
triton == 2.1.0
flash_attn == 2.6.1
xformers == 0.0.25
lmslim == 0.1.0
\ No newline at end of file
lmslim == 0.1.1
\ No newline at end of file
......@@ -124,8 +124,10 @@ class PagedAttention:
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = (max_seq_len < 8192
and (max_seq_len<1000 or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
if use_v1:
# Run PagedAttention V1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment