Commit 53910677 authored by zhuwenwen's avatar zhuwenwen
Browse files

update pa tc and gc benchmark

parent 65d64273
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
create_kv_caches_with_random, seed_everything) create_kv_caches_with_random, seed_everything)
import vllm.envs as envs
NUM_BLOCKS = 1024 NUM_BLOCKS = 1024
PARTITION_SIZE = 512 PARTITION_SIZE = 512
...@@ -102,22 +103,40 @@ def main( ...@@ -102,22 +103,40 @@ def main(
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v1_opt( if envs.VLLM_USE_TC_PAGED_ATTN:
output, ops.paged_attention_v1_opt_tc(
query, output,
key_cache, query,
value_cache, key_cache,
num_kv_heads, value_cache,
scale, num_kv_heads,
block_tables, scale,
seq_lens, block_tables,
block_size, seq_lens,
max_seq_len, block_size,
alibi_slopes, max_seq_len,
kv_cache_dtype, alibi_slopes,
k_scale, kv_cache_dtype,
v_scale, k_scale,
) v_scale,
)
else:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else: else:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
...@@ -137,27 +156,48 @@ def main( ...@@ -137,27 +156,48 @@ def main(
) )
elif version == "v2": elif version == "v2":
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
ops.paged_attention_v2( if envs.VLLM_USE_TC_PAGED_ATTN:
output, ops.paged_attention_v2_opt_tc(
exp_sums, output,
max_logits, exp_sums,
tmp_output, max_logits,
query, tmp_output,
key_cache, query,
value_cache, key_cache,
num_kv_heads, value_cache,
scale, num_kv_heads,
block_tables, scale,
seq_lens, block_tables,
block_size, seq_lens,
max_seq_len, block_size,
alibi_slopes, max_seq_len,
kv_cache_dtype, alibi_slopes,
k_scale, kv_cache_dtype,
v_scale, k_scale,
) v_scale,
)
else:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else: else:
ops.paged_attention_v2_opt( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
max_logits, max_logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment