# Copyright 2025 StepFun Inc. All Rights Reserved.
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# ==============================================================================
import torch

from flash_attn import flash_attn_func as fa_func

# Previous implementation backed by the Optimus custom op, kept for reference:
# def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
#                     return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
#     softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
#     return torch.ops.Optimus._fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]


def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
                    return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
    """Thin wrapper around flash_attn's flash_attn_func; defaults softmax_scale
    to 1/sqrt(head_dim). tp_group_rank/tp_group_size are accepted only for
    interface compatibility with the previous Optimus kernel and are ignored."""
    softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
    return fa_func(q, k, v, dropout_p=dropout_p, softmax_scale=softmax_scale, causal=causal,
                   return_attn_probs=return_attn_probs)


class FlashSelfAttention(torch.nn.Module):
    def __init__(
            self,
            attention_dropout=0.0,
    ):
        super().__init__()
        self.dropout_p = attention_dropout

    def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
        # Only fixed-length (batch, seqlen, num_heads, head_dim) inputs are handled;
        # packed variable-length sequences (cu_seqlens/max_seq_len) are not supported.
        if cu_seqlens is None:
            output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
        else:
            raise ValueError('cu_seqlens is not supported!')

        return output
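

# Minimal usage sketch (illustrative only, not part of the original module): it assumes
# the standard flash_attn input layout of (batch, seqlen, num_heads, head_dim) tensors
# in fp16/bf16 on a CUDA device, and simply round-trips a random batch through the module.
if __name__ == "__main__":
    if torch.cuda.is_available():
        q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
        attn = FlashSelfAttention(attention_dropout=0.0)
        out = attn(q, k, v)
        print(out.shape)  # expected: torch.Size([2, 128, 8, 64])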