wrap.py 204 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_heads_attn


def parallelize_hunyuan(hunyuan_model):
    """Switch a Hunyuan model's transformer to partial-heads parallel attention.

    The model is patched in place: its ``transformer_infer.parallel_attention``
    attribute is rebound to ``partial_heads_attn``. Nothing is returned.

    Args:
        hunyuan_model: Object exposing a ``transformer_infer`` attribute whose
            ``parallel_attention`` hook will be overwritten.
    """
    infer = hunyuan_model.transformer_infer
    setattr(infer, "parallel_attention", partial_heads_attn)