Commit 1a9a61d7 authored by zhuwenwen
Browse files

add fa-pa benchmark

parent 937a3ec1
...@@ -104,6 +104,10 @@ def main( ...@@ -104,6 +104,10 @@ def main(
) )
max_logits = torch.empty_like(exp_sums) max_logits = torch.empty_like(exp_sums)
if version == "v12":
sliding_window = ((-1, -1))
logits_soft_cap = 0.0
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize() torch.cuda.synchronize()
if profile: if profile:
...@@ -247,6 +251,24 @@ def main( ...@@ -247,6 +251,24 @@ def main(
k_scale, k_scale,
v_scale, v_scale,
) )
elif version == "v12":
from flash_attn import vllm_flash_attn_with_kvcache
vllm_flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=seq_lens,
block_table=block_tables,
softmax_scale=scale,
causal=True,
window_size=sliding_window,
softcap=logits_soft_cap,
alibi_slopes=alibi_slopes,
return_softmax_lse=False,
k_scale=k_scale,
v_scale=v_scale,
kv_cache_dtype=kv_cache_dtype,
).squeeze(1)
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -276,7 +298,7 @@ if __name__ == "__main__": ...@@ -276,7 +298,7 @@ if __name__ == "__main__":
) )
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--version", type=str, choices=["v1", "v2", "v12"], default="v12")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
...@@ -287,7 +309,7 @@ if __name__ == "__main__": ...@@ -287,7 +309,7 @@ if __name__ == "__main__":
choices=[64, 80, 96, 112, 120, 128, 192, 256], choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128, default=128,
) )
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32, 64], default=64)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
parser.add_argument( parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment