import math
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
33
            key_padding_mask: a bool tensor of shape (B, S)
Tri Dao's avatar
Tri Dao committed
34
35
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        if cu_seqlens is None:
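            # Standard padded input of shape (B, S, 3, H, D): cumulative sequence
            # lengths (cu_seqlens) are built here before calling the varlen kernel.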
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
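                # No padding: every sequence has the full length `seqlen`, so the
                # batch can be flattened to (B*S, 3, H, D) and cu_seqlens is simply
                # [0, seqlen, 2*seqlen, ..., B*seqlen].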
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_unpadded_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
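                # Padded input: flatten the head dims, drop padded tokens with
                # unpad_input, run the kernel on the packed tokens, then scatter
                # the result back to (B, S, H, D) with pad_input.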
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_unpadded_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
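            # The caller already unpadded the input: qkv is (nnz, 3, H, D) and
            # cu_seqlens / max_s describe the packed sequences directly.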
            assert max_s is not None
            output = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None


class FlashMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, key_padding_mask=None, need_weights=False):
        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        key_padding_mask: bool tensor of shape (batch, seqlen)
        """
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
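

# Minimal usage sketch, assuming a CUDA device and the flash_attn package are
# installed; batch/sequence sizes and hyperparameters below are illustrative only.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, seqlen, embed_dim, num_heads = 2, 128, 512, 8
    device, dtype = 'cuda', torch.float16

    mha = FlashMHA(embed_dim, num_heads, attention_dropout=0.1, causal=True,
                   device=device, dtype=dtype)
    x = torch.randn(batch, seqlen, embed_dim, device=device, dtype=dtype)
    # Mark the last 16 positions of the second sequence as padding (False = pad).
    key_padding_mask = torch.ones(batch, seqlen, dtype=torch.bool, device=device)
    key_padding_mask[1, -16:] = False
    out, _ = mha(x, key_padding_mask=key_padding_mask)
    print(out.shape)  # torch.Size([2, 128, 512])

    # Calling FlashAttention directly on an already-unpadded (packed) qkv tensor:
    # two sequences of lengths 100 and 128 concatenated along the token dimension.
    attn = FlashAttention()
    cu_seqlens = torch.tensor([0, 100, 228], dtype=torch.int32, device=device)
    qkv = torch.randn(228, 3, num_heads, embed_dim // num_heads,
                      device=device, dtype=dtype)
    packed_out, _ = attn(qkv, cu_seqlens=cu_seqlens, max_s=128)
    print(packed_out.shape)  # torch.Size([228, 8, 64])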