Commit c004bf6e authored by zhuwenwen's avatar zhuwenwen
Browse files

update benchmark_paged_attention.py and ops.h of convert_vertical_slash_indexes

parent 98f67566
...@@ -117,42 +117,6 @@ def main( ...@@ -117,42 +117,6 @@ def main(
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1_opt(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v1( ops.paged_attention_v1(
output, output,
query, query,
...@@ -171,44 +135,6 @@ def main( ...@@ -171,44 +135,6 @@ def main(
) )
elif version == "v2": elif version == "v2":
if not args.custom_paged_attn: if not args.custom_paged_attn:
if args.gc_paged_attn:
if args.tc_paged_attn:
ops.paged_attention_v1_opt_tc(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.paged_attention_v2_opt(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
)
ops.paged_attention_v2( ops.paged_attention_v2(
output, output,
exp_sums, exp_sums,
...@@ -322,12 +248,6 @@ if __name__ == "__main__": ...@@ -322,12 +248,6 @@ if __name__ == "__main__":
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (hcu) supports fp8 (=fp8_e4m3)") "ROCm (hcu) supports fp8 (=fp8_e4m3)")
parser.add_argument(
"--gc-paged-attn", action="store_true", help="Use gc paged attention"
)
parser.add_argument(
"--tc-paged-attn", action="store_true", help="Use tc paged attention"
)
parser.add_argument( parser.add_argument(
"--custom-paged-attn", action="store_true", help="Use custom paged attention" "--custom-paged-attn", action="store_true", help="Use custom paged attention"
) )
......
...@@ -59,7 +59,7 @@ void merge_attn_states(torch::Tensor& output, ...@@ -59,7 +59,7 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& prefix_lse, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse); const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes( void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment