# tutorial_pipe_multiprocess.py
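# A minimal one-step training example for fairscale's multiprocess Pipe: two
# processes each host one stage of a small sequential model and exchange
# activations and gradients over torch RPC.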
import os

from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim

import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def run(rank, world_size):
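    # standard env:// rendezvous variables, presumably consumed by dist_init
    # (from the tutorial helpers) to create the default process group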
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
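    # switch ports so the RPC layer's rendezvous does not collide with the
    # process group created above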
    os.environ["MASTER_PORT"] = "10639"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
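    # model-parallel groups of size 1 (no tensor parallelism); the second
    # argument lets the pipeline span all world_size ranks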
    initialize_model_parallel(1, world_size)

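    # toy model, one (input, target) batch, and the loss function all come
    # from the tutorial's helpers module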
    model = getModel()
    data, target = getData()[0]
    loss_fn = getLossFun()

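    # each process drives its own GPU when CUDA is available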
    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
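    # wrap the sequential model in a multiprocess Pipe: balance=[2, 1] puts the
    # first two layers on rank 0 and the last layer on rank 1, and each rank
    # keeps only its own partition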

    model = fairscale.nn.Pipe(
        model,
        balance=[2, 1],
        style=fairscale.nn.Pipe.MultiProcess,
        worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define the optimizer (the loss function was fetched above)
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
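    # only rank 1 (the last stage) receives the final outputs, so it alone can
    # evaluate the loss and start the backward pass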
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        model.back_helper(outputs)  # run this stage's backward once grads arrive over RPC
        optimizer.step()  # this rank must also update its own partition's weights

    print(f"Finished Training Step on {rank}")

    del model  # drop the pipe so its RPC references are released before exit


if __name__ == "__main__":
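    # two processes, one per pipeline stage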
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)