"magic_pdf/pre_proc/ocr_dict_merge.py" did not exist on "a0be4652e6fffcd99e33b5ed17b52d00af8c71f9"
__init__.py 743 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2

Dongz's avatar
Dongz committed
6
7

def attention(attention_type="flash_attn2", *args, **kwargs):
    """Run attention through the backend selected by ``attention_type``.

    Args:
        attention_type: Name of the backend to use. Recognised values are
            ``"torch_sdpa"``, ``"flash_attn2"``, ``"flash_attn3"`` and
            ``"sage_attn2"``. Defaults to ``"flash_attn2"``.
        *args: Positional arguments forwarded unchanged to the backend.
        **kwargs: Keyword arguments forwarded unchanged to the backend.

    Returns:
        Whatever the selected backend function returns.

    Raises:
        NotImplementedError: If ``attention_type`` matches none of the
            recognised backend names.

    Note:
        ``attention_type`` is the first positional parameter, so the first
        positional argument is always taken as the backend name. Callers
        passing positional backend arguments must name the backend
        explicitly first, e.g. ``attention("torch_sdpa", *backend_args)``.
    """
    # Pick the backend first (same comparison order as before), then invoke
    # it from a single call site so argument forwarding lives in one place.
    if attention_type == "torch_sdpa":
        backend = torch_sdpa
    elif attention_type == "flash_attn2":
        backend = flash_attn2
    elif attention_type == "flash_attn3":
        backend = flash_attn3
    elif attention_type == "sage_attn2":
        backend = sage_attn2
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
    return backend(*args, **kwargs)