"...decoding_attention_triton/triton_flashinfer_cudnn.py" did not exist on "bd6196163ec3293b5254ecb5c6f14c16cb3577b6"
all2all.py 3.43 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
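# Distributed all-to-all helpers for switching attention activations between a
# sequence-sharded layout ([seq_len/N, heads, hidden_dims]) and a head-sharded
# layout ([seq_len, heads/N, hidden_dims]) across N ranks, typically used so
# each rank can attend over the full sequence for a subset of heads.
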
import torch
import torch._dynamo as dynamo
import torch.distributed as dist


@dynamo.disable
def all2all_seq2head(input, group=None):
    """
    Convert the input tensor from [seq_len/N, heads, hidden_dims] layout to [seq_len, heads/N, hidden_dims].

    Args:
        input (torch.Tensor): Input tensor of shape [seq_len/N, heads, hidden_dims].
        group (ProcessGroup, optional): Process group used for the all-to-all; defaults to the default group.

    Returns:
        torch.Tensor: Output tensor of shape [seq_len, heads/N, hidden_dims].
    """
    # Ensure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"

    # Get the world size of the process group
    world_size = dist.get_world_size(group=group)

    # Read the shape of the sharded input tensor
    shard_seq_len, heads, hidden_dims = input.shape
    seq_len = shard_seq_len * world_size  # total sequence length across all ranks
    shard_heads = heads // world_size  # number of heads each rank keeps after the exchange

    # Reshape the input tensor so it can be split per rank for the all-to-all
    input_t = (
        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # [shard_seq_len, world_size, shard_heads, hidden_dims]
        .transpose(0, 1)  # bring the world_size dim to the front so dim 0 can be split across ranks
        .contiguous()  # ensure contiguous memory for the collective
    )

    # Allocate an output buffer with the same shape as the reshaped input
    output = torch.empty_like(input_t)

    # All-to-all: exchange sequence shards for head shards across ranks
    dist.all_to_all_single(output, input_t, group=group)

    # Reshape the gathered result to [seq_len, heads/N, hidden_dims]
    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()

    return output


@dynamo.disable
def all2all_head2seq(input, group=None):
    """
    Convert the input tensor from [seq_len, heads/N, hidden_dims] layout to [seq_len/N, heads, hidden_dims].

    Args:
        input (torch.Tensor): Input tensor of shape [seq_len, heads/N, hidden_dims].
        group (ProcessGroup, optional): Process group used for the all-to-all; defaults to the default group.

    Returns:
        torch.Tensor: Output tensor of shape [seq_len/N, heads, hidden_dims].
    """
    # Ensure the input is a 3D tensor
    assert input.dim() == 3, "input must be a 3D tensor"

    # Get the world size of the process group
    world_size = dist.get_world_size(group=group)

    # Read the shape of the input tensor
    seq_len, shard_heads, hidden_dims = input.shape
    heads = shard_heads * world_size  # total number of heads across all ranks
    shard_seq_len = seq_len // world_size  # sequence length each rank keeps after the exchange

    # Reshape the input tensor so it can be split per rank for the all-to-all
    input_t = (
        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # [world_size, shard_seq_len, shard_heads, hidden_dims]
        .transpose(1, 2)  # swap the sequence and head dims before splitting
        .contiguous()  # ensure contiguous memory for the collective
        .reshape(world_size, shard_heads, shard_seq_len, hidden_dims)  # [world_size, shard_heads, shard_seq_len, hidden_dims]
    )

    # Allocate an output buffer with the same shape as the reshaped input
    output = torch.empty_like(input_t)

    # All-to-all: exchange head shards for sequence shards across ranks
    dist.all_to_all_single(output, input_t, group=group)

    # Reshape the gathered result to [heads, shard_seq_len, hidden_dims]
    output = output.reshape(heads, shard_seq_len, hidden_dims)

    # Transpose and reshape to the final [shard_seq_len, heads, hidden_dims] layout
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)

    return output
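

# --- Illustrative usage sketch (not part of the original file) ---
# A minimal round-trip check, assuming the module is launched with
# `torchrun --nproc_per_node=N <this file>` on a single node with one GPU per
# rank. NCCL is assumed here because the gloo backend generally lacks
# all_to_all support; `heads` is chosen divisible by N, as both helpers require.
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # single-node launch assumed, so rank == local GPU index

    shard_seq_len, heads, hidden_dims = 4, 2 * world_size, 8
    x = torch.randn(shard_seq_len, heads, hidden_dims, device="cuda")

    # Sequence-sharded -> head-sharded, then back again.
    head_sharded = all2all_seq2head(x)         # [shard_seq_len * N, heads // N, hidden_dims]
    restored = all2all_head2seq(head_sharded)  # [shard_seq_len, heads, hidden_dims]

    # The two transforms only move data, so the original local shard is recovered exactly.
    assert torch.equal(x, restored), "all2all_head2seq should invert all2all_seq2head"
    dist.destroy_process_group()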