# streaming_attention.py
import torch
import torch.nn as nn

from einops import rearrange

from rotary import RotaryEmbedding, RotaryEmbedding2D
from stream_attn_interface import stream_attn_func
from bert_padding import unpad_input, pad_input
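# NOTE: rotary, stream_attn_interface, and bert_padding are assumed to be local
# modules that ship alongside this file rather than installed packages.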


class StreamingAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
    """
    def __init__(self, softmax_temp=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                         many query each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                                          dtype=torch.int32, device=qkv.device)
                output = stream_attn_func(qkv, cu_seqlens,
                                          self.dropout_p if self.training else 0.0,
                                          max_s, softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = stream_attn_func(x_unpad, cu_seqlens,
                                                self.dropout_p if self.training else 0.0,
                                                max_s, softmax_scale=self.softmax_temp, causal=causal)
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = stream_attn_func(qkv, cu_seqlens,
                                      self.dropout_p if self.training else 0.0,
                                      max_s, softmax_scale=self.softmax_temp, causal=causal)

        return output, None
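
# A minimal usage sketch for StreamingAttention on an already-unpadded (packed) batch.
# The concrete sizes below are illustrative assumptions, not taken from this repo:
#
#     attn = StreamingAttention(attention_dropout=0.1)
#     # two sequences of lengths 3 and 5 packed into total_nnz = 8 tokens,
#     # with 4 heads of head_dim 64
#     qkv = torch.randn(8, 3, 4, 64, device='cuda', dtype=torch.float16)
#     cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device='cuda')
#     out, _ = attn(qkv, cu_seqlens=cu_seqlens, max_s=5, causal=True)
#     # out: (total_nnz, num_heads, head_dim) == (8, 4, 64)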


class StreamingMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        assert use_rotary_emb in [None, '1d', '2d']
        self.use_rotary_emb = use_rotary_emb
        if self.use_rotary_emb == '1d':
            self.rotary_emb = RotaryEmbedding(self.head_dim)
        elif self.use_rotary_emb == '2d':
            self.rotary_emb = RotaryEmbedding2D(self.head_dim)

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = StreamingAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
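        # x_ignored_ and x_ignored_1_ are unused placeholders (presumably to match a
        # (query, key, value)-style calling convention); this module performs
        # self-attention on x only.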
        qkv = self.Wqkv(x)
        if self.use_rotary_emb:
            query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
                                          h=self.num_heads).unbind(dim=2)
            query, key = self.rotary_emb(query, key, seq_dimension=-3)
            qkv = torch.stack([query, key, value], dim=2)
        else:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
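

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the original
    # module). It assumes a CUDA device, fp16 weights/inputs, and that the compiled
    # stream_attn_func extension is importable.
    torch.manual_seed(0)
    batch_size, seqlen, embed_dim, num_heads = 2, 128, 512, 8
    mha = StreamingMHA(embed_dim, num_heads, causal=True,
                       device='cuda', dtype=torch.float16)
    x = torch.randn(batch_size, seqlen, embed_dim, device='cuda', dtype=torch.float16)
    # The second and third positional arguments are ignored (self-attention only).
    out, _ = mha(x, x, x)
    print(out.shape)  # expected: torch.Size([2, 128, 512])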