# flash_attention.py
import math
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


class FlashAttention(nn.Module):
    """Scaled dot-product attention with softmax, computed by the
    FlashAttention CUDA kernel (``flash_attn_unpadded_qkvpacked_func``).

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        # `device`/`dtype` are accepted for constructor-signature parity with
        # other modules; this module holds no parameters of its own.
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) if cu_seqlens is None and key_padding_mask is None,
                (nnz, 3, H, D) if cu_seqlens is not None (already unpadded).
                Must be float16 and on a CUDA device.
            attn_mask: Unsupported; must be None.
            key_padding_mask: Object exposing a ``bool_matrix`` attribute of
                shape (B, S) marking valid (non-padding) positions.
            causal: Whether to apply a causal (lower-triangular) mask.
            cu_seqlens: Cumulative sequence lengths (int32) for pre-unpadded
                input; requires ``max_s`` to be given as well.
            max_s: Maximum sequence length in the batch (only with cu_seqlens).
            need_weights: Unsupported; must be False.

        Returns
        -------
            (output, None): attention output with the same layout as the
            query part of ``qkv``; None stands in for attention weights.
        """
        # The kernel only supports fp16 CUDA tensors and never materializes
        # the attention matrix, hence the hard restrictions below.
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                # No padding: treat the batch as one packed sequence set with
                # uniform lengths, so cu_seqlens is just an arange of strides.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                # Dropout is only active in training mode.
                output = flash_attn_unpadded_qkvpacked_func(
                    qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                # Padded input: drop padding tokens before the kernel and
                # scatter the results back afterwards.
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_unpadded_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            # Caller already unpadded the input; max_s must accompany cu_seqlens.
            assert max_s is not None
            output = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
                softmax_scale=self.softmax_scale, causal=causal
            )

        # None in place of attention weights (need_weights is unsupported).
        return output, None


Tri Dao's avatar
Tri Dao committed
80
class FlashMHA(nn.Module):
    """Multi-head self-attention backed by :class:`FlashAttention`.

    Projects the input to packed QKV with a single linear layer, optionally
    applies rotary position embeddings to Q and K, runs the FlashAttention
    kernel, and projects the result back to ``embed_dim``.

    Arguments
    ---------
        embed_dim: Total model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        bias: Whether the QKV and output projections have bias terms.
        batch_first: Must be True (only (B, S, E) layout is supported).
        attention_dropout: Dropout rate passed to the inner attention.
        causal: Whether attention is causal.
        use_rotary_emb: One of None, '1d', or '2d' — selects no rotary
            embedding, RotaryEmbedding, or RotaryEmbedding2D respectively.
        device, dtype: Forwarded to the projection layers.
    """

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # The FlashAttention kernel only implements these head dimensions.
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        assert use_rotary_emb in [None, '1d', '2d']
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        # Single fused projection producing Q, K, and V.
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
        """Run multi-head self-attention over ``x``.

        Arguments
        ---------
            x: Input of shape (B, S, embed_dim).
            x_ignored_, x_ignored_1_: Ignored; present so the signature is
                call-compatible with (query, key, value)-style MHA APIs.
            attn_mask: Forwarded restriction — must be None (see FlashAttention).
            key_padding_mask: Optional padding mask forwarded to the inner
                attention.
            need_weights: Must be False; attention weights are never returned.

        Returns
        -------
            (out, attn_weights): projected output of shape (B, S, embed_dim)
            and None for the weights.
        """
        qkv = self.Wqkv(x)
        if self.use_rotary_emb:
            # Rotary embeddings act on Q and K only; V passes through.
            query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                          h=self.num_heads).unbind(dim=2)
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            qkv = torch.stack([query, key, value], dim=2)
        else:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        # Merge heads back into the model dimension before the output projection.
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights