import torch
import torch._dynamo as dynamo
import torch.distributed as dist


@dynamo.disable
def all2all_seq2head(input):
    """
    Convert an input tensor of shape [seq_len/N, heads, hidden_dims] into shape [seq_len, heads/N, hidden_dims].

    Args:
        input (torch.Tensor): Input tensor of shape [seq_len/N, heads, hidden_dims]

    Returns:
        torch.Tensor: Output tensor of shape [seq_len, heads/N, hidden_dims]
    """
    # Make sure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"
helloyongyang's avatar
helloyongyang committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

    # Get the world size of the current process group
    world_size = dist.get_world_size()

    # Unpack the input tensor's shape
    shard_seq_len, heads, hidden_dims = input.shape
    seq_len = shard_seq_len * world_size  # total sequence length
    shard_heads = heads // world_size  # number of heads handled by each rank

    # Reshape the input tensor for the all-to-all operation
    input_t = (
        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # reshape to [shard_seq_len, world_size, shard_heads, hidden_dims]
        .transpose(0, 1)  # move the world_size dimension to the front for the all-to-all
        .contiguous()  # ensure contiguous memory
    )

    # Allocate an output tensor with the same shape as the reshaped input
    output = torch.empty_like(input_t)

    # Perform the all-to-all, exchanging the input tensor's chunks across all ranks
    dist.all_to_all_single(output, input_t)
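    # Chunk i of `output` along dim 0 is the sequence shard received from rank i,
    # so concatenating the chunks in order rebuilds the full sequence for this
    # rank's group of heads.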

    # Reshape the output tensor to [seq_len, heads/N, hidden_dims]
    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()

    return output  # return the converted output tensor


@dynamo.disable
def all2all_head2seq(input):
    """
    Convert an input tensor of shape [seq_len, heads/N, hidden_dims] into shape [seq_len/N, heads, hidden_dims].

    Args:
        input (torch.Tensor): Input tensor of shape [seq_len, heads/N, hidden_dims]

    Returns:
        torch.Tensor: Output tensor of shape [seq_len/N, heads, hidden_dims]
    """
    # Make sure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"

    # Get the world size of the current process group
    world_size = dist.get_world_size()

    # Unpack the input tensor's shape
    seq_len, shard_heads, hidden_dims = input.shape
    heads = shard_heads * world_size  # total number of heads
    shard_seq_len = seq_len // world_size  # sequence length handled by each rank

    # Reshape the input tensor for the all-to-all operation
    input_t = (
        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # reshape to [world_size, shard_seq_len, shard_heads, hidden_dims]
        .transpose(1, 2)  # transpose for the all-to-all operation
        .contiguous()  # ensure contiguous memory
        .reshape(world_size, shard_heads, shard_seq_len, hidden_dims)  # reshape again to [world_size, shard_heads, shard_seq_len, hidden_dims]
    )

    # Allocate an output tensor with the same shape as the reshaped input
    output = torch.empty_like(input_t)

    # Perform the all-to-all, exchanging the input tensor's chunks across all ranks
    dist.all_to_all_single(output, input_t)
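    # Chunk i of `output` along dim 0 is head group i (received from rank i)
    # for this rank's sequence shard.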

    # Reshape the output tensor to [heads, shard_seq_len, hidden_dims]
    output = output.reshape(heads, shard_seq_len, hidden_dims)

    # Transpose and reshape the output tensor to [shard_seq_len, heads, hidden_dims]
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)

    return output  # return the converted output tensor
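

# Minimal round-trip usage sketch (illustrative only): it assumes the script is
# launched with torchrun so a default process group can be initialized on the
# NCCL backend, that one CUDA device is available per rank, and that `heads` and
# `seq_len` are divisible by the world size. All shapes are hypothetical.
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    seq_len, heads, hidden_dims = 1024, 16, 64

    # Each rank starts with its shard of the sequence and the full set of heads.
    local_seq = torch.randn(seq_len // world_size, heads, hidden_dims, device="cuda")

    # Redistribute to full sequence / sharded heads, then back again.
    seq_full = all2all_seq2head(local_seq)  # [seq_len, heads // world_size, hidden_dims]
    restored = all2all_head2seq(seq_full)   # [seq_len // world_size, heads, hidden_dims]

    assert restored.shape == local_seq.shape
    dist.destroy_process_group()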