import math
import torch
import torch.nn as nn

from einops import rearrange

from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                 Shape (B, S, 3, H, D) for padded inputs (key_padding_mask optional),
                 or (nnz, 3, H, D) for already-unpadded inputs (cu_seqlens required).
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to. Not supported; must be None.
            key_padding_mask: An implementation of BaseMask that encodes how many
                              queries each sequence in the batch consists of.
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

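        # Three input layouts are handled below:
        #   1. Padded (B, S, 3, H, D) with no key_padding_mask: flatten batch and sequence
        #      dims and build cu_seqlens for equal-length sequences.
        #   2. Padded (B, S, 3, H, D) with a key_padding_mask: unpad to (nnz, 3, H, D),
        #      run the kernel, then re-pad the output.
        #   3. Already unpadded (nnz, 3, H, D): the caller must supply cu_seqlens and max_s.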
        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                        device=qkv.device)
                output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
                                        max_s, softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_func(x_unpad, cu_seqlens,
                                                self.dropout_p if self.training else 0.0,
                                                max_s, softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_func(qkv, cu_seqlens,
                                      self.dropout_p if self.training else 0.0,
                                      max_s, softmax_scale=self.softmax_temp, causal=causal)

        return output, None

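
# A minimal usage sketch (hypothetical helper, assuming a CUDA device with fp16 support
# and a built flash_attn extension); it exercises the fully padded, equal-length path of
# FlashAttention above.
def _flash_attention_example(batch_size=2, seqlen=128, nheads=8, head_dim=64):
    attn = FlashAttention(attention_dropout=0.1)
    # Packed qkv of shape (B, S, 3, H, D); forward() asserts fp16 and CUDA.
    qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim,
                      device='cuda', dtype=torch.float16)
    output, _ = attn(qkv, causal=True)
    return output  # (B, S, H, D)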

class FlashMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        assert use_rotary_emb in [None, '1d', '2d']
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
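        # x_ignored_ and x_ignored_1_ are accepted but never used: this module performs
        # self-attention on x only.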
        qkv = self.Wqkv(x)
        if self.use_rotary_emb:
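            # Split the fused projection into q, k, v, rotate q and k along the sequence
            # dimension, then re-stack so the packed layout matches the non-rotary path.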
            query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                          h=self.num_heads).unbind(dim=2)
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            qkv = torch.stack([query, key, value], dim=2)
        else:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
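

# A quick smoke-test sketch, assuming a CUDA device with fp16 support and a built
# flash_attn extension; the hyperparameters below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    mha = FlashMHA(embed_dim=512, num_heads=8, attention_dropout=0.1, causal=True,
                   device='cuda', dtype=torch.float16)
    x = torch.randn(2, 128, 512, device='cuda', dtype=torch.float16)
    # forward() requires two extra positional arguments that the implementation ignores.
    out, _ = mha(x, x, x)
    print(out.shape)  # expected: torch.Size([2, 128, 512])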