import torch
import torch.distributed as dist


def all2all_seq2head(input):
    """
    Convert the input tensor from [seq_len/N, heads, hidden_dims] layout to
    [seq_len, heads/N, hidden_dims] layout.

    Args:
        input (torch.Tensor): input tensor of shape [seq_len/N, heads, hidden_dims]

    Returns:
        torch.Tensor: output tensor of shape [seq_len, heads/N, hidden_dims]
    """
    # Ensure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"

    # Number of ranks participating in the all-to-all
    world_size = dist.get_world_size()

    # Unpack the input tensor's shape
    shard_seq_len, heads, hidden_dims = input.shape
    seq_len = shard_seq_len * world_size  # total sequence length
    shard_heads = heads // world_size  # number of heads handled by each rank

    # Reshape so the rank dimension leads, since all_to_all_single splits along dim 0
    input_t = (
        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # split heads into per-rank blocks
        .transpose(0, 1)  # -> [world_size, shard_seq_len, shard_heads, hidden_dims]
        .contiguous()  # ensure contiguous memory for the collective
    )

    # Allocate the receive buffer with the same shape as the send buffer
    output = torch.empty_like(input_t)

    # All-to-all: each rank sends one chunk to, and receives one chunk from, every rank
    dist.all_to_all_single(output, input_t)

    # Collapse the gathered sequence shards into [seq_len, heads/N, hidden_dims]
    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()

    return output
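
# Worked example (illustrative, not from the original file): with world_size=2,
# heads=8 and a 16-token shard per rank, each rank starts with [16, 8, D],
# stages it as [2, 16, 4, D], and after the all-to-all holds both ranks'
# 16-token shards for its own 4-head block, i.e. [32, 4, D].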


def all2all_head2seq(input):
    """
    Convert the input tensor from [seq_len, heads/N, hidden_dims] layout to
    [seq_len/N, heads, hidden_dims] layout.

    Args:
        input (torch.Tensor): input tensor of shape [seq_len, heads/N, hidden_dims]

    Returns:
        torch.Tensor: output tensor of shape [seq_len/N, heads, hidden_dims]
    """
    # Ensure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"

    # Number of ranks participating in the all-to-all
    world_size = dist.get_world_size()

    # Unpack the input tensor's shape
    seq_len, shard_heads, hidden_dims = input.shape
    heads = shard_heads * world_size  # total number of heads
    shard_seq_len = seq_len // world_size  # sequence length handled by each rank

    # Reshape so the rank dimension leads, since all_to_all_single splits along dim 0
    input_t = (
        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # split the sequence into per-rank shards
        .transpose(1, 2)  # -> [world_size, shard_heads, shard_seq_len, hidden_dims]
        .contiguous()  # ensure contiguous memory for the collective
    )

    # Allocate the receive buffer with the same shape as the send buffer
    output = torch.empty_like(input_t)

    # All-to-all: each rank sends one chunk to, and receives one chunk from, every rank
    dist.all_to_all_single(output, input_t)

    # Collapse the gathered head blocks into [heads, shard_seq_len, hidden_dims]
    output = output.reshape(heads, shard_seq_len, hidden_dims)

    # Move the sequence dimension back in front: [shard_seq_len, heads, hidden_dims]
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)

    return output
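

# ---------------------------------------------------------------------------
# Minimal round-trip sketch, not part of the original module. Launching via
# `torchrun --nproc_per_node=2 all2all.py` and the NCCL backend are
# assumptions; all_to_all_single generally needs NCCL (GPU) or MPI. heads must
# be divisible by world_size, and seq_len by world_size.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    dist.init_process_group(backend="nccl")  # assumption: one GPU per rank
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)

    world_size = dist.get_world_size()
    shard_seq_len, heads, hidden_dims = 8, 4 * world_size, 16
    x = torch.randn(shard_seq_len, heads, hidden_dims, device="cuda")

    # [seq_len/N, heads, D] -> [seq_len, heads/N, D] ...
    y = all2all_seq2head(x)
    assert y.shape == (shard_seq_len * world_size, heads // world_size, hidden_dims)

    # ... and back again: the two collectives are inverses of each other
    z = all2all_head2seq(y)
    assert torch.equal(x, z), "round trip should reproduce the local shard"

    dist.destroy_process_group()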