processor.py 965 Bytes
Newer Older
Xinchi Huang's avatar
Xinchi Huang committed
1
2
3
from re import split
import torch
import torch.distributed as dist
4
5
6
import torch.nn.functional as F

PADDING_SIZE = None
Xinchi Huang's avatar
Xinchi Huang committed
7
8
9
10
11
12


def pre_process(x, world_size=None, cur_rank=None):
    """Pad ``x`` along dim 0 so it splits evenly, then return this rank's chunk.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor; dim 0 is the dimension that gets sharded across ranks.
    world_size : int, optional
        Number of ranks. Defaults to ``dist.get_world_size()``.
    cur_rank : int, optional
        This process's rank. Defaults to ``dist.get_rank()``.

    Returns
    -------
    torch.Tensor
        The chunk of ``x`` (zero-padded at the trailing end of dim 0 if
        needed) that belongs to ``cur_rank``.
    """
    if world_size is None:
        world_size = dist.get_world_size()
    if cur_rank is None:
        cur_rank = dist.get_rank()

    # Zero padding needed so dim 0 is divisible by world_size.
    padding_size = (world_size - (x.shape[0] % world_size)) % world_size

    if padding_size > 0:
        # F.pad consumes (before, after) pairs starting from the LAST dim,
        # so the original hard-coded (0, 0, 0, padding_size) only padded
        # dim 0 for 2-D inputs. This spec pads only the trailing end of
        # dim 0 for any rank >= 1.
        pad_spec = [0, 0] * (x.dim() - 1) + [0, padding_size]
        x = F.pad(x, pad_spec)

    # NOTE(review): the padding amount is not recorded anywhere (the unused
    # module-level PADDING_SIZE suggests it was meant to be), so the output
    # of post_process may carry trailing zero rows — confirm callers trim.
    return torch.chunk(x, world_size, dim=0)[cur_rank]

Dongz's avatar
Dongz committed
23

Xinchi Huang's avatar
Xinchi Huang committed
24
25
26
27
28
29
def post_process(x):
    """All-gather ``x`` from every rank and concatenate along dim 0.

    Parameters
    ----------
    x : torch.Tensor
        This rank's local shard; every rank must pass a same-shaped tensor.

    Returns
    -------
    torch.Tensor
        All ranks' shards concatenated along dim 0. Any zero padding added
        upstream is NOT stripped here.
    """
    world_size = dist.get_world_size()

    # One receive buffer per rank, each shaped like the local shard.
    buffers = [torch.empty_like(x) for _ in range(world_size)]

    # Collect every rank's shard into the buffers.
    dist.all_gather(buffers, x)

    # Stitch the per-rank shards back together along the sharded dimension.
    return torch.cat(buffers, dim=0)