"vscode:/vscode.git/clone" did not exist on "c413c41cda0f9359e7a12bb674c0f87bf41798c5"
Commit 4360cfc6 authored by Tri Dao's avatar Tri Dao
Browse files

[Triton] Fix benchmark_causal.py

parent 5d079fdd
...@@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b ...@@ -93,8 +93,8 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
repeats=repeats, desc='PyTorch Attention') repeats=repeats, desc='PyTorch Attention')
benchmark_all(flash_attn_qkvpacked_func, qkv, causal=causal, repeats=repeats, desc='FlashAttention Triton') benchmark_all(flash_attn_qkvpacked_func, qkv, None, causal, repeats=repeats, desc='FlashAttention Triton')
pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal=causal, backward=True) pytorch_profiler(flash_attn_qkvpacked_func, qkv, None, causal, backward=True)
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
requires_grad=True) for _ in range(3)] requires_grad=True) for _ in range(3)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment