processor.py 689 Bytes
Newer Older
Xinchi Huang's avatar
Xinchi Huang committed
1
2
3
4
5
6
7
8
9
from re import split
import torch
import torch.distributed as dist


def pre_process(x: torch.Tensor) -> torch.Tensor:
    """Return the current rank's chunk of ``x`` along dimension 0.

    ``x`` is split along dim 0 with ``torch.chunk`` into (up to)
    ``world_size`` pieces, and the piece whose index equals this process's
    rank is returned. The default process group must already be
    initialized, because the world size and rank are read from
    ``torch.distributed``.

    Args:
        x: Tensor to shard along its first dimension.
            NOTE(review): presumably every rank passes the same full
            tensor -- confirm against the callers.

    Returns:
        The chunk of ``x`` selected by the current rank. ``torch.chunk``
        returns views, so the result shares storage with ``x``.

    Raises:
        IndexError: If ``torch.chunk`` produces fewer chunks than there
            are ranks (it may do so when ``x.size(0)`` is small or not
            divisible by the world size) and this rank has no chunk.
    """
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()

    # torch.chunk may yield fewer than ``world_size`` chunks, and the last
    # chunk may be smaller, when dim 0 is not evenly divisible.
    x = torch.chunk(x, world_size, dim=0)[cur_rank]

    return x

Dongz's avatar
Dongz committed
14

Xinchi Huang's avatar
Xinchi Huang committed
15
16
17
18
19
20
def post_process(x: torch.Tensor) -> torch.Tensor:
    """Gather ``x`` from every rank and concatenate the results along dim 0.

    Performs a ``torch.distributed.all_gather`` over the default process
    group, so every participating rank must call this function. The
    default process group must already be initialized.

    Args:
        x: This rank's local tensor.
            NOTE(review): the receive buffers are sized from the local
            ``x`` via ``torch.empty_like``, so this assumes every rank
            passes a tensor of the same shape and dtype -- confirm.

    Returns:
        A new tensor made by concatenating the gathered per-rank tensors
        along dimension 0, in the order of the gathered list.
    """
    # Number of processes in the default process group.
    world_size = dist.get_world_size()

    # One receive buffer per process, shaped like the local tensor.
    gathered_x = [torch.empty_like(x) for _ in range(world_size)]

    # Collect the tensor from every process into the buffers.
    dist.all_gather(gathered_x, x)

    # Merge all per-process outputs along dimension 0.
    combined_output = torch.cat(gathered_x, dim=0)

    return combined_output