__init__.py 1.96 KB
Newer Older
zhangshao's avatar
zhangshao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
__version__ = "2.8.3"

import torch
if torch.cuda.is_available():
    from flash_attn.flash_attn_interface import (
        flash_attn_func,
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_func,
        hg_flash_attn_varlen_func,
        vllm_flash_attn_varlen_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
        vllm_flash_attn_with_kvcache,
        sparse_attn_varlen_func,
        sparse_attn_func,
        flash_attn_func_blasst,
        sparse_attn_with_sla,
        spas_fa2_attn_meansim_cuda,
        spas_fa2_attn_meansim_topk_cuda,
        spas_fa2_attn_meansim_varlen_cuda,
        spas_fa2_attn_meansim_topk_varlen_cuda,
        # Attnmask functions - FlashAttention with explicit attention mask
        flash_attn_with_mask_func,
        flash_attn_varlen_with_mask_func,
        # unified attn functions 
        varlen_fwd_unified,
zhangshao's avatar
update  
zhangshao committed
29
        fwd_sparse_mean_pool_fast,
zhangshao's avatar
zhangshao committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
    )
    # triton fa interface
    from flash_attn.flash_attn_triton_interface import flash_attn_func as triton_flash_attn_func
    from flash_attn.flash_attn_triton_interface import flash_attn_kvpacked_func as triton_flash_attn_kvpacked_func
    from flash_attn.flash_attn_triton_interface import flash_attn_qkvpacked_func as triton_flash_attn_qkvpacked_func
    from flash_attn.flash_attn_triton_interface import flash_attn_varlen_func as triton_flash_attn_varlen_func
    from flash_attn.flash_attn_triton_interface import flash_attn_varlen_kvpacked_func as triton_flash_attn_varlen_kvpacked_func
    from flash_attn.flash_attn_triton_interface import flash_attn_varlen_qkvpacked_func as triton_flash_attn_varlen_qkvpacked_func

    try:
        from .version import version, git_hash, git_branch, dtk, abi, torch_version, hcu_version  # noqa: F401
        __version__, __hcu_version__, __git_hash__, __git_branch__ = version, hcu_version, git_hash, git_branch
    except ImportError:
        pass