# flash_attn_interface.py
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
import torch
import torch.nn as nn

import flash_attn_cuda


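# Thin wrapper around the fused CUDA forward kernel. Returns the attention output, the
# softmax log-sum-exp (saved so the backward pass can recompute the softmax instead of
# storing the full attention matrix), and, when return_softmax is set, the dropout-masked
# softmax S_dmask for testing.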
def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
                                                       False, causal, return_softmax, None)
    # if context.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    S_dmask = rest[0] if return_softmax else None
    return context, softmax_lse, S_dmask


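# Thin wrapper around the fused CUDA backward kernel. Only the gradient with respect to the
# packed qkv tensor is returned; dp and softmax_d are additional kernel outputs not needed here.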
def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                         softmax_scale, causal):
    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
                                               softmax_scale, max_s, False, causal, None)
    # if dqkv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dqkv


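# Autograd wrapper around the fused kernel. The dropout mask itself is never stored; the CUDA
# RNG state is saved in forward and restored in backward so the kernel can regenerate the
# exact same mask.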
class FlashAttnFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        # S_dmask is None here (forward ran with return_softmax=False), so temporarily
        # pass the output tensor in its place just to get the backward call running.
        dqkv = _flash_attn_backward(
            dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None


# We duplicate code to return both the output and the softmax for testing
# Returning both makes backward a bit slower, so we want to keep using the other version for speed.
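# Unlike FlashAttnFun above, this variant runs the forward with return_softmax=True and
# returns (output, S_dmask, softmax_lse); the backward then reuses the actual S_dmask.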
class FlashAttnFunWithS(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
        # Save rng_state because the backward pass will regenerate the dropout mask
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_attn_forward(
            qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        dqkv = _flash_attn_backward(
            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
            ctx.max_s, ctx.softmax_scale, ctx.causal
        )
        if rng_state is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None


def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation.

    Arguments:
        qkv: (total_tokens, 3, nheads, headdim) packed query/key/value tensor.
        cu_seqlens: (batch_size + 1,), torch.int32 cumulative sequence lengths, used to index into qkv.
        dropout_p: dropout probability.
        max_s: maximum sequence length in the batch.
        softmax_scale: scaling of QK^T before the softmax; defaults to 1 / sqrt(headdim).
        causal: whether to apply a causal attention mask.
        return_attn_probs: whether to also return S_dmask and softmax_lse (for testing only).
    """
    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
    return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
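

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original interface).
# It assumes a CUDA device with the flash_attn_cuda extension built, fp16 inputs,
# and the packed layout used above: qkv of shape (total_tokens, 3, nheads, headdim)
# with int32 cumulative sequence lengths. All sizes below are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seqlen, nheads, headdim = 2, 128, 8, 64
    total = batch_size * seqlen
    qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16,
                      requires_grad=True)
    # Cumulative sequence lengths [0, 128, 256] for two equal-length sequences.
    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, seqlen,
                              device="cuda", dtype=torch.int32)
    out = flash_attn_func(qkv, cu_seqlens, dropout_p=0.1, max_s=seqlen, causal=True)
    out.sum().backward()  # gradients flow back into qkv via FlashAttnFun.backward
    # Pass return_attn_probs=True to also get S_dmask and softmax_lse (testing only).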