from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn


def attention(attention_type="flash_attn2", *args, **kwargs):
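    """Dispatch to the requested attention backend.

    Args:
        attention_type: One of "torch_sdpa", "flash_attn2", "flash_attn3",
            "sage_attn2", or "radial_attn".
        *args, **kwargs: Forwarded unchanged to the selected backend.

    Raises:
        NotImplementedError: If ``attention_type`` is not recognized.
    """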
    if attention_type == "torch_sdpa":
        return torch_sdpa(*args, **kwargs)
    elif attention_type == "flash_attn2":
        return flash_attn2(*args, **kwargs)
    elif attention_type == "flash_attn3":
        return flash_attn3(*args, **kwargs)
    elif attention_type == "sage_attn2":
        return sage_attn2(*args, **kwargs)
    elif attention_type == "radial_attn":
        return radial_attn(*args, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
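

# Usage sketch (hedged): the q/k/v argument names, shapes, dtype, and device
# below are assumptions for illustration only -- the real call signatures are
# defined by the individual backends in lightx2v.attentions.common.
#
#   import torch
#   from lightx2v.attentions import attention
#
#   q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
#   out = attention(attention_type="torch_sdpa", q=q, k=k, v=v)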