# tutorial_pipe_multiprocess.py
import os

from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim

from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import MultiProcessPipe

# Select the accelerator once at import time; the whole example follows it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Fixed CUDA device index used when building torch.device below.
# NOTE(review): hard-coded to 0 for the example — every spawned process
# would share GPU 0; confirm this is intended for multi-GPU hosts.
RANK = 0  # example

def run(rank, world_size):
    """Run one pipeline-parallel training step on this process.

    Spawned once per rank by ``mp.spawn``. Sets up the distributed process
    group and the RPC framework, wraps the example model in a
    ``MultiProcessPipe`` split across ranks, and executes a single
    forward/backward/optimizer step.

    Args:
        rank: This process's rank in ``[0, world_size)``.
        world_size: Total number of participating processes.
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
    # The RPC layer needs its own rendezvous port, distinct from the one
    # the process group above already bound.
    os.environ["MASTER_PORT"] = "10639"
    dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    initialize_model_parallel(1, world_size)

    model = getModel()
    data, target = getData()[0]
    loss_fn = getLossFun()

    device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")

    # Split the model across ranks: the first 2 layers on rank 0, the last
    # layer on rank 1 (balance=[2, 1]).
    # NOTE(review): balance and worker_map both assume world_size == 2.
    model = MultiProcessPipe(
        model,
        balance=[2, 1],
        style=MultiProcessPipe.MultiProcess,
        worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
    if rank == 1:
        # Rank 1 holds the final pipeline stage, so only it sees the real
        # outputs and can drive the backward pass.
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        # Earlier stages participate in the backward pass via the pipe's
        # helper instead of calling loss.backward() themselves.
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")

    del model


if __name__ == "__main__":
    # Two processes: one per pipeline stage (see balance=[2, 1] in run()).
    WORLD_SIZE = 2
    mp.spawn(run, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)