__init__.py 752 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2

def attention(attention_type="flash_attn2", *args, **kwargs):
    """Forward an attention call to the backend selected by name.

    Args:
        attention_type: Name of the backend to use. Recognised values are
            ``"torch_sdpa"``, ``"flash_attn2"``, ``"flash_attn3"`` and
            ``"sage_attn2"``; defaults to ``"flash_attn2"``.
        *args: Positional arguments passed through unchanged to the backend.
        **kwargs: Keyword arguments passed through unchanged to the backend.

    Returns:
        Whatever the selected backend function returns.

    Raises:
        NotImplementedError: If ``attention_type`` matches none of the
            recognised backend names.

    NOTE(review): ``attention_type`` is the first positional parameter, so a
    purely positional call such as ``attention(q, k, v)`` binds ``q`` to
    ``attention_type``. Callers must pass the backend name first, or pass
    every argument by keyword.
    """
    # Resolve the backend first so that every branch shares one call site.
    # Each name is only looked up in its matching branch. An unknown
    # ``attention_type`` raises without touching any backend.
    if attention_type == "torch_sdpa":
        backend = torch_sdpa
    elif attention_type == "flash_attn2":
        backend = flash_attn2
    elif attention_type == "flash_attn3":
        backend = flash_attn3
    elif attention_type == "sage_attn2":
        backend = sage_attn2
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")

    return backend(*args, **kwargs)