# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import flash_attn_cuda


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.
    0 means the block is skipped.
    Nonzero means the block is not skipped.
    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit used to encode whether
            it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
            the last nonzero in its row.
    """
    assert not causal
    # TD [2022-05-13]: The indexing and sorting is very tricky
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
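    # Stable descending sort down each column moves the row indices of the nonzero blocks
    # to the front; argsort then recovers, for each (row, col), the position that row
    # occupies within its column after sorting.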
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
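    # Column index of the last / first nonzero block in each row, plus the position that
    # entry occupies within its column in the sorted layout above.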
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
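    # Encode each kept block as (row index * 4), setting bit 0 if it is the first nonzero
    # of its row and bit 1 if it is the last; skipped blocks are marked with -1.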
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
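# Illustrative example (added here for clarity; not part of the original code): for the
# 2 x 3 block mask
#     [[1, 0, 1],
#      [1, 1, 0]]
# convert_blockmask(blockmask, causal=False) returns the (3, 2) int32 tensor
#     [[1,  5],
#      [6, -1],
#      [2, -1]]
# Column 0 keeps rows 0 and 1: row 0 encodes as 0 * 4 + 1 (first nonzero of row 0) and
# row 1 as 1 * 4 + 1 (first nonzero of row 1); column 1 keeps only row 1 (1 * 4 + 2, last
# nonzero of row 1); column 2 keeps only row 0 (0 * 4 + 2); -1 pads the skipped slots.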


def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
                                     causal, return_softmax):
    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
                                                             max_s, softmax_scale, causal,
                                                             return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
                                      dropout_p, max_s, softmax_scale, causal):
    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
                                                     blockmask, dropout_p, softmax_scale, max_s,
                                                     causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
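    # bwd_block also returns dp and softmax_d; only dqkv is needed by the autograd wrappers.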
    return dqkv


class FlashBlocksparseAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
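        # Default to 1/sqrt(head_dim) scaling of the attention scores.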
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
            return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
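        # Rewind the CUDA RNG to the state saved in forward so the kernel regenerates the
        # same dropout mask; the current RNG state is restored after the backward kernel runs.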
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None here; temporarily pass `context` in its place just to get it running
        dqkv = _flash_blocksparse_attn_backward(
            dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
            return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_blocksparse_attn_backward(
            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None


def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
                                 causal=False, return_attn_probs=False, convert_mask=True):
    """dropout_p should be set to 0.0 during evaluation
    """
    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
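

# Usage sketch (illustrative only, not part of the original code). The packed qkv layout and
# the concrete sizes below are assumptions based on this interface; the block granularity that
# `blockmask` must match is fixed by the CUDA kernel, not by this wrapper.
#
#     qkv = torch.randn(256, 3, 16, 64, dtype=torch.float16, device='cuda')       # (total, 3, nheads, head_dim)
#     cu_seqlens = torch.tensor([0, 128, 256], dtype=torch.int32, device='cuda')  # two length-128 sequences
#     blockmask = (torch.rand(n_block_rows, n_block_cols, device='cuda') > 0.5)   # 0-1 block mask
#     out = flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p=0.0,
#                                       max_s=128, causal=False)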