from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.triton.fused_attention import attention as attention

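# Optional fused causal (upper-triangular) masked-softmax kernel, adapted from Megatron-LM.
# If the CUDA extension is not available, fall back to None and skip the Megatron benchmark below.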
try:
    from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
except ImportError:
    scaled_upper_triang_masked_softmax = None


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
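    # q is now (batch*heads, seqlen, d) and k is (batch*heads, d, seqlen), so a single batched
    # matmul produces the full (batch*heads, seqlen, seqlen) score matrix.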
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def attention_triton(q, k, v):
    """
    No dropout and only support causal=True.
    Triton implementation seems to require q, k, v being contiguous?
    Arguments:
        q, k, v: (batch_size, nheads, seqlen, head_dim)
    Output:
        output: (batch_size, nheads, seqlen, head_dim)
    """
    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    return attention(q, k, v, softmax_scale)


def attention_megatron(qkv):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
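    # scale=1.0 because softmax_scale was already applied via baddbmm's alpha above; the fused
    # kernel applies the upper-triangular (causal) mask and softmax in a single pass.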
    attention = scaled_upper_triang_masked_softmax(scores, None, scale=1.0)
    output = torch.einsum('bhts,bshd->bthd', attention, v)
    return output.to(dtype=qkv.dtype)


torch.manual_seed(0)
repeats = 30
batch_size = 2
seqlen = 4096
nheads = 12
headdim = 128
dropout_p = 0.0
causal = True
dtype = torch.bfloat16
device = 'cuda'

qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)
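# The unpadded (variable-length) interface takes all sequences concatenated along the token
# dimension, i.e. qkv reshaped to (total_tokens, 3, nheads, headdim), plus cumulative
# sequence-length offsets [0, seqlen, 2*seqlen, ...] marking where each sequence starts.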
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                          device=qkv.device)

benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
              cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
              repeats=repeats, desc='PyTorch Attention')

q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
                       requires_grad=True) for _ in range(3)]
benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')

if scaled_upper_triang_masked_softmax is not None:
    benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
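
# Optional: pytorch_profiler (imported above) can produce a kernel-level trace for any of the
# implementations. A minimal sketch, assuming it follows the same fn-plus-inputs calling
# convention as benchmark_all:
# pytorch_profiler(attention_pytorch, qkv, dropout_p, causal=causal)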