attn.py 1.21 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.distributed as dist
from lightx2v.attentions import attention


def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    """Head-parallel (Ulysses-style) attention across distributed ranks.

    Each rank runs attention over its contiguous slice of the heads, then all
    ranks all-gather the partial outputs and concatenate them back along the
    head dimension.

    Args:
        attention_type: Backend identifier forwarded to ``attention``.
        q, k, v: Tensors whose second-to-last dim is the head axis; they are
            sliced as ``[:, start:end, :]``, so a (tokens, heads, head_dim)
            layout is assumed — TODO confirm against callers.
        cu_seqlens_qkv: Cumulative sequence lengths, shared by q/k/v.
        max_seqlen_qkv: Maximum sequence length, shared by q/k/v.

    Returns:
        The attention output with the full head set, reassembled in rank
        order along dim 1.

    NOTE(review): ``dist.all_gather`` requires identical tensor shapes on
    every rank, but the last rank takes any remainder heads — so num_heads
    must in practice be divisible by world_size; confirm upstream guarantee.
    """
    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Integer chunk size. Was int(num_heads / dist.get_world_size()):
    # floor division avoids float rounding and reuses the cached world_size.
    num_chunk_heads = num_heads // world_size

    # Contiguous head slice owned by this rank; the last rank also absorbs
    # any remainder heads when num_heads % world_size != 0.
    start = num_chunk_heads * cur_rank
    end = num_heads if cur_rank == world_size - 1 else start + num_chunk_heads

    q = q[:, start:end, :]
    k = k[:, start:end, :]
    v = v[:, start:end, :]

    # q/k/v share the same sequence layout, so one cu_seqlens/max_seqlen
    # pair serves both the query and key/value sides.
    output = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    # Collect every rank's partial output and stitch the heads back together.
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)

    return torch.cat(gathered_outputs, dim=1)