Commit c004bf6e authored by zhuwenwen's avatar zhuwenwen
Browse files

update benchmark_paged_attention.py and ops.h of convert_vertical_slash_indexes

parent 98f67566
......@@ -117,98 +117,24 @@ def main(
for _ in range(num_iters):
if version == "v1":
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
elif version == "v2":
if not args.custom_paged_attn:
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
ops.paged_attention_v2(
output,
exp_sums,
......@@ -322,12 +248,6 @@ if __name__ == "__main__":
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)")
parser.add_argument(
"--gc-paged-attn", action="store_true", help="Use gc paged attention"
)
parser.add_argument(
"--tc-paged-attn", action="store_true", help="Use tc paged attention"
)
parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
)
......
......@@ -59,7 +59,7 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment