Commit f164171b authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_TC_PAGED_ATTN to convert tc pa

parent 2a9c497e
...@@ -190,6 +190,7 @@ set(VLLM_EXT_SRC ...@@ -190,6 +190,7 @@ set(VLLM_EXT_SRC
"csrc/opt/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu" "csrc/opt/activation_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu" "csrc/attention/attention_kernels_opt.cu"
"csrc/attention/attention_kernels_opt_tc.cu"
"csrc/opt/layernorm_kernels_opt.cu" "csrc/opt/layernorm_kernels_opt.cu"
# "csrc/quantization/gptq/q_gemm.cu" # "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
This diff is collapsed.
This diff is collapsed.
...@@ -9,16 +9,24 @@ ...@@ -9,16 +9,24 @@
} \ } \
}() }()
#define OPT_SWITCH(COND, ...) \
[&] { \
if (COND) { \
constexpr static int opt = 1; \
return __VA_ARGS__(); \
} else { \
constexpr static int opt = 2; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \ #define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \ [&] { \
if (NUM_THREAD == 256) { \ if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \ constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (NUM_THREAD == 128) { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \ } else { \
constexpr static int NUM_THREADS = 64; \ constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
}() }()
...@@ -40,6 +48,9 @@ ...@@ -40,6 +48,9 @@
} else if (HEADDIM == 128) { \ } else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \ constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else if (HEADDIM == 192) { \
constexpr static int HEAD_SIZE = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \ } else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \ constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
...@@ -49,33 +60,28 @@ ...@@ -49,33 +60,28 @@
} \ } \
}() }()
#define REUSEKV_SWITCH(reusekv,...) \ #define REUSEKV_SWITCH(num_blocks , ...) \
[&] { \ [&] { \
if (reusekv==16){ \ if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \ constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (reusekv==2){ \ } else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \ constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else { \ } else { \
constexpr static int REUSE_KV_TIMES = 1; \ constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
}() }()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \ #define REUSEKV_SWITCH_V1(num_blocks , ...) \
[&] { \ [&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \ if (num_heads > num_kv_heads && num_blocks >= 1200){ \
constexpr static int use_vmac = false; \ constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} else { \ } else { \
constexpr static int use_vmac = true; \ constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} \ } \
}() }()
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \
[&] { \
if (NUM_THREAD == 256) { \
constexpr static int NUM_THREADS = 256; \
return __VA_ARGS__(); \
}else if (NUM_THREAD == 128) { \
constexpr static int NUM_THREADS = 128; \
return __VA_ARGS__(); \
} else { \
constexpr static int NUM_THREADS = 64; \
return __VA_ARGS__(); \
} \
}()
#define HEADSIZE_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM == 64) { \
constexpr static int HEAD_SIZE = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM == 80) { \
constexpr static int HEAD_SIZE = 80; \
return __VA_ARGS__(); \
} else if (HEADDIM == 96) { \
constexpr static int HEAD_SIZE = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM == 256) { \
constexpr static int HEAD_SIZE = 256; \
return __VA_ARGS__(); \
} \
else { \
TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\
} \
}()
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (reusekv==2){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define USEVMAC_SWITCH_V1(num_blocks , ...) \
[&] { \
if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \
constexpr static int use_vmac = false; \
return __VA_ARGS__(); \
} else { \
constexpr static int use_vmac = true; \
return __VA_ARGS__(); \
} \
}()
\ No newline at end of file
...@@ -47,6 +47,27 @@ void paged_attention_v2_opt( ...@@ -47,6 +47,27 @@ void paged_attention_v2_opt(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
......
...@@ -75,6 +75,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -75,6 +75,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt); ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt_tc)
ops.def(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt_tc).
ops.def(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank,"
" int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
......
...@@ -211,6 +211,65 @@ def paged_attention_v2_opt( ...@@ -211,6 +211,65 @@ def paged_attention_v2_opt(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt_tc(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt_tc(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt_tc(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm( def paged_attention_rocm(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
......
...@@ -124,10 +124,12 @@ class PagedAttention: ...@@ -124,10 +124,12 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
# use_v1 = (max_seq_len <= 8192 if envs.VLLM_USE_TC_PAGED_ATTN:
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = (max_seq_len < 8192 use_v1 = (max_seq_len < 8192
and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512))) and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
else:
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
...@@ -137,6 +139,29 @@ class PagedAttention: ...@@ -137,6 +139,29 @@ class PagedAttention:
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v1_opt( ops.paged_attention_v1_opt(
output, output,
query, query,
...@@ -201,7 +226,33 @@ class PagedAttention: ...@@ -201,7 +226,33 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP and max_seq_len<8192: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN and max_seq_len < 8192:
ops.paged_attention_v2_opt_tc(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v2_opt( ops.paged_attention_v2_opt(
output, output,
exp_sums, exp_sums,
......
...@@ -13,6 +13,7 @@ if TYPE_CHECKING: ...@@ -13,6 +13,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
...@@ -203,6 +204,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -203,6 +204,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized tc paged attn kernels
"VLLM_USE_TC_PAGED_ATTN":
lambda: (os.environ.get("VLLM_USE_TC_PAGED_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm print pa parameters # flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM": "VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment