attn.py 1.18 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.distributed as dist
from lightx2v.attentions import attention


def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    """Run attention on this rank's share of heads, then all-gather the results.

    Each rank receives the full ``q``/``k``/``v``. It keeps a contiguous block
    of ``num_heads // world_size`` heads, selected along dim -2. It runs
    ``attention`` on that block, all-gathers the per-rank outputs, and
    concatenates them along dim 1 in rank order.

    Args:
        attention_type: Backend selector forwarded unchanged to ``attention``.
        q, k, v: Query/key/value tensors with heads on dim -2 (the head count
            is read from ``q.shape[-2]``). Presumably laid out as
            (total_tokens, num_heads, head_dim) — TODO confirm with callers.
        cu_seqlens_qkv: Cumulative sequence lengths. Passed as both
            ``cu_seqlens_q`` and ``cu_seqlens_kv`` (q and kv share lengths).
        max_seqlen_qkv: Maximum sequence length. Passed as both
            ``max_seqlen_q`` and ``max_seqlen_kv``.

    Returns:
        The per-rank ``attention`` outputs concatenated along dim 1 in rank
        order. NOTE(review): this assumes ``attention`` flattens heads into
        dim 1, e.g. (seq, local_heads * head_dim). Confirm.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the world size.
            ``dist.all_gather`` needs every rank to contribute a tensor of the
            same shape, so an uneven head split cannot be gathered.
    """
    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Every rank evaluates the same condition. Assuming q has the same shape
    # on every rank, all ranks raise together instead of one rank hanging
    # inside the collective below.
    if num_heads % world_size != 0:
        raise ValueError(
            f"partial_heads_attn: num_heads ({num_heads}) must be divisible by "
            f"world_size ({world_size}); dist.all_gather requires equally "
            f"sized outputs from every rank."
        )

    num_chunk_heads = num_heads // world_size
    head_start = num_chunk_heads * cur_rank

    # Select this rank's heads as a view along dim -2. This is the same
    # dimension num_heads was read from, and it equals q[:, a:b, :] for 3-D
    # inputs.
    q = q.narrow(-2, head_start, num_chunk_heads)
    k = k.narrow(-2, head_start, num_chunk_heads)
    v = v.narrow(-2, head_start, num_chunk_heads)

    output = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    # Every rank owns the same number of heads, so every rank's output has the
    # same shape. That makes empty_like(output) a valid receive buffer for
    # each peer.
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)

    return torch.cat(gathered_outputs, dim=1)