flash_attention.py 5.39 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
import math
import torch
import torch.nn as nn

from einops import rearrange

Tri Dao's avatar
Tri Dao committed
7
from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
Tri Dao's avatar
Tri Dao committed
8
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
Tri Dao's avatar
Tri Dao committed
9
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
Tri Dao's avatar
Tri Dao committed
10
11


Tri Dao's avatar
Tri Dao committed
12
class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Thin module wrapper around ``flash_attn_unpadded_qkvpacked_func`` that
    accepts either a dense batch, a dense batch plus a padding mask, or an
    already-unpadded token stream.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        device, dtype: Accepted for signature compatibility only; this module
                       creates no parameters, so neither value is used.
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                (B, S, 3, H, D) when cu_seqlens is None (with or without
                key_padding_mask); if unpadded, i.e. cu_seqlens is given:
                (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S). Only consulted
                when cu_seqlens is None; it is handed to ``unpad_input``.
                NOTE(review): presumably True marks tokens to keep -- confirm
                against ``unpad_input``.
            causal: forwarded unchanged to the attention kernel.
            cu_seqlens: cumulative sequence lengths for unpadded input. When
                None it is built here (dense batch) or obtained from
                ``unpad_input`` (masked batch).
            max_s: maximum sequence length; required when cu_seqlens is given,
                otherwise it is overwritten.
            need_weights: must be False; attention weights are never returned.

        Returns
        -------
            (output, None): ``output`` is reshaped back to ``b s ...`` for
            dense input and to ``b s h d`` for masked input; for caller-unpadded
            input it is the kernel result as-is. The second element is always
            None.
        """
        assert not need_weights, "FlashAttention cannot return attention weights"
        assert qkv.dtype == torch.float16, "FlashAttention requires float16 input"
        assert qkv.is_cuda, "FlashAttention requires a CUDA tensor"

        # Dropout is applied only in training mode.
        dropout_p = self.dropout_p if self.training else 0.0

        if cu_seqlens is not None:
            # Caller already unpadded the input; key_padding_mask is not used.
            assert max_s is not None
            output = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            return output, None

        batch_size = qkv.shape[0]
        seqlen = qkv.shape[1]

        if key_padding_mask is None:
            # Every sequence has the full length, so the batch can be flattened
            # and the cumulative offsets are just multiples of seqlen.
            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
            max_s = seqlen
            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                      dtype=torch.int32, device=qkv.device)
            output = flash_attn_unpadded_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )
            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            return output, None

        # Masked batch: flatten the trailing dims so unpad_input sees
        # (b, s, features), drop the padded positions, run the kernel, then
        # scatter the result back into a padded (b, s, h, d) tensor.
        nheads = qkv.shape[-2]
        x = rearrange(qkv, 'b s three h d -> b s (three h d)')
        x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_seqlens, max_s, dropout_p,
            softmax_scale=self.softmax_scale, causal=causal
        )
        output_padded = pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                  indices, batch_size, seqlen)
        output = rearrange(output_padded, 'b s (h d) -> b s h d', h=nheads)
        return output, None


Tri Dao's avatar
Tri Dao committed
75
class FlashMHA(nn.Module):
    """Multi-head self-attention built on ``FlashAttention``.

    Applies a fused QKV linear projection, optionally rotates query/key with a
    rotary embedding, runs ``FlashAttention`` and projects the result back to
    ``embed_dim``.

    Arguments
    ---------
        embed_dim: model width; must be divisible by num_heads.
        num_heads: number of attention heads; embed_dim // num_heads must be
                   one of 16, 32, 64 or 128.
        bias: whether the QKV and output projections carry a bias.
        batch_first: must be True (asserted).
        attention_dropout: dropout rate handed to ``FlashAttention``.
        causal: stored and passed to the inner attention on every forward.
        use_rotary_emb: None, '1d' or '2d'; selects ``RotaryEmbedding`` or
                        ``RotaryEmbedding2D`` over head_dim.
        device, dtype: forwarded to the linear layers and ``FlashAttention``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"

        assert use_rotary_emb in [None, '1d', '2d']
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, key_padding_mask=None, need_weights=False):
        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        key_padding_mask: bool tensor of shape (batch, seqlen)
        need_weights: passed through to the inner attention, which asserts it
            is False.

        Returns (output, attn_weights): output is the out_proj of the attention
        context flattened to (batch, seqlen, h * d); attn_weights is whatever
        the inner attention returns as its second value.
        """
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        if self.use_rotary_emb:
            query, key, value = qkv.unbind(dim=2)
            # query/key are (b, s, h, d), so the sequence axis is dim -3.
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            # `value` still has the projection's dtype; bring the rotated q/k
            # back to that same dtype so the stacked tensor is uniform and has
            # the dtype the projection produced.
            qkv = torch.stack([query.to(value.dtype), key.to(value.dtype), value], dim=2)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights