flash.py 719 Bytes
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright © 2023-2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

"""
Fused Attention
===============

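This module gathers the Triton implementation of the Flash Attention v2 algorithm by Tri Dao (https://tridao.me/publications/flash2/flash2.pdf).
The pieces themselves are defined in the ``flash_attn`` submodules and imported here: ``attn_fwd``, ``bwd_preprocess``, ``bwd_kernel_dk_dv``, ``bwd_kernel_dq`` and ``debug_fill_dropout_rng``.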

Extra Credits:
- Original flash attention paper (https://arxiv.org/abs/2205.14135)
- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
- Adam P. Goucher for simplified vector math

"""

import triton
import triton.language as tl

from flash_attn.fwd_kernel import attn_fwd
from flash_attn.bwd_preprocess import bwd_preprocess
from flash_attn.bwd_split_kernel import bwd_kernel_dk_dv, bwd_kernel_dq
from flash_attn.dropout_rng import debug_fill_dropout_rng