# flash_blocksparse_attention.py
import math
import torch
import torch.nn as nn

from einops import rearrange

import hydra

from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_temp: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.1)
    """
    def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
                 max_seq_length=2048, device=None, dtype=None):
        super().__init__()
        self.sparsity_config = hydra.utils.instantiate(sparsity_config)
        self.softmax_temp = softmax_temp
        self.dropout_p = attention_dropout

        # initialize sparse layout and register as buffer
        # round max_seq_length up to a multiple of 256 so the block layout divides evenly
        max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
        layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("layout", layout)
        # pre-converted blockmask, used by forward when convert_mask=False
        blockmask_converted = convert_blockmask(self.layout, causal=False)
        self.register_buffer("blockmask_converted", blockmask_converted)
        # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False, convert_mask=True):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
            attn_mask: An implementation of BaseMask that encodes where each
                       query can attend to
            key_padding_mask: An implementation of BaseMask that encodes how
                              many queries each sequence in the batch consists of
        """
        assert not need_weights
        assert attn_mask is None
        assert qkv.dtype == torch.float16
        assert qkv.is_cuda

        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0] and seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
            if key_padding_mask is None:
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                key_padding_mask_bool = key_padding_mask.bool_matrix
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_blocksparse_attn_func(
                    x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            seqlen = max_s
            # Convert mask to take a subset
            seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
            assert seqlen_rounded // 16 <= self.layout.shape[0] and seqlen_rounded // 256 <= self.layout.shape[1]
            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
            if convert_mask:
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                    max_s, softmax_scale=self.softmax_temp, causal=causal
                )
            else:
                output = flash_blocksparse_attn_func(
                    qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
                    max_s, softmax_scale=self.softmax_temp, causal=causal,
                    convert_mask=False,
                )

        return output, None


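# Illustrative sketch: the helper below mirrors the bookkeeping that
# FlashBlocksparseAttention.forward does on the no-padding path, i.e. building
# cumulative sequence offsets for equal-length sequences and cropping the
# precomputed layout to the rounded sequence length. The dense all-ones layout is
# an assumption standing in for whatever sparsity_config.make_layout() would
# return; no attention kernel is launched, so this runs on CPU.
def _demo_varlen_bookkeeping(batch_size=2, seqlen=300, max_seq_length=2048):
    # Round both lengths up to a multiple of 256, matching __init__ and forward above.
    max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256
    seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
    # The slicing in forward implies rows index 16-token blocks and columns 256-token blocks.
    layout = torch.ones(max_seq_length // 16, max_seq_length // 256, dtype=torch.int32)
    blockmask = layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
    # cu_seqlens marks where each sequence starts in the flattened (b s) dimension,
    # exactly as built in forward when key_padding_mask is None.
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
    return blockmask, cu_seqlens

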
class FlashBlocksparseMHA(nn.Module):

    def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
                 attention_dropout=0.0, causal=False, max_seq_length=2048,
                 device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"

        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
        self.inner_attn = FlashBlocksparseAttention(
            sparsity_config, attention_dropout=attention_dropout,
            max_seq_length=max_seq_length, **factory_kwargs
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

    # The two extra positional arguments are accepted (and ignored) so the call matches
    # a (query, key, value)-style interface such as nn.MultiheadAttention's.
    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
                need_weights=False):
        qkv = self.Wqkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
                                                need_weights=need_weights, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
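

if __name__ == "__main__":
    # Illustrative sketch of the (batch, seqlen, 3, nheads, headdim) qkv packing that
    # FlashBlocksparseAttention.forward expects, and of the unpad/pad round trip it
    # performs when a key_padding_mask is supplied. Only the padding utilities are
    # exercised (they run on CPU); the blocksparse kernel itself needs a CUDA GPU,
    # fp16 inputs, and a hydra-instantiable sparsity_config. The shapes below are
    # arbitrary values chosen for illustration.
    torch.manual_seed(0)
    batch_size, seqlen, nheads, headdim = 2, 6, 4, 16
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim)
    # Boolean mask: True marks real tokens, False marks padding (second sequence is shorter).
    key_padding_mask_bool = torch.ones(batch_size, seqlen, dtype=torch.bool)
    key_padding_mask_bool[1, 4:] = False
    # Flatten per-token features, drop padded positions, then restore them, mirroring
    # the key_padding_mask branch of forward (4-value unpad_input, as used above).
    x = rearrange(qkv, 'b s three h d -> b s (three h d)')
    x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
    x_repad = pad_input(x_unpad, indices, batch_size, seqlen)
    restored = rearrange(x_repad, 'b s (three h d) -> b s three h d', three=3, h=nheads)
    assert torch.equal(restored[key_padding_mask_bool], qkv[key_padding_mask_bool])
    print("cu_seqlens:", cu_seqlens.tolist(), "max_s:", max_s)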