# tutorial_pipe_multiprocess.py
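"""Minimal multiprocess pipeline-parallel training step with fairscale.nn.Pipe.

Two processes are spawned; the model is split across them according to
``balance`` and a single training step is performed. Run it as
``python tutorial_pipe_multiprocess.py`` (the tutorial's companion
``helpers.py`` is assumed to be importable from the same directory).
"""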
import os

import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.optim as optim

import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
from helpers import dist_init, getModel, getData, getLossFun
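# The `helpers` module imported above is a small companion file that is not
# reproduced here. As a rough, illustrative sketch only (these definitions are
# assumptions, not the tutorial's actual helpers), it could look like:
#
#     import torch
#     import torch.distributed as dist
#     import torch.nn as nn
#
#     def dist_init(rank, world_size):
#         # join the default process group advertised via MASTER_ADDR/MASTER_PORT
#         dist.init_process_group("gloo", rank=rank, world_size=world_size)
#
#     def getModel():
#         # three modules, matching balance=[2, 1] below
#         return nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))
#
#     def getData():
#         # a list of (input, target) batches
#         return [(torch.randn(20, 10), torch.randint(0, 5, (20,)))]
#
#     def getLossFun():
#         return nn.CrossEntropyLoss()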


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def run(rank, world_size):
    # initialize the default process group on one port and the RPC framework on
    # another, so the two rendezvous do not collide
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
    os.environ["MASTER_PORT"] = "10639"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    # model-parallel size 1, pipeline length equal to world_size
    initialize_model_parallel(1, world_size)

    # build the model, grab a single (input, target) batch, and get the loss
    # function from the tutorial helpers
    model = getModel()
    data, target = getData()[0]
    loss_fn = getLossFun()

    # each rank drives the pipeline stage placed on its own GPU (or on the CPU)
    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")

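    # wrap the sequential model in a multiprocess pipeline: balance=[2, 1] gives the
    # first two modules to rank 0 and the last module to rank 1, and worker_map tells
    # the pipeline which RPC worker to talk to for each rank when moving activations
    # between stages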
    model = fairscale.nn.Pipe(
        model,
        balance=[2, 1],
        style=fairscale.nn.Pipe.MultiProcess,
        worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define the optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step; the final outputs of the pipeline end up on the last rank
    outputs = model(data.to(device))
    # compute loss on the last stage only, since that is where the final outputs live
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        # earlier stages call back_helper so their partitions participate in the
        # backward pass
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")

    del model


if __name__ == "__main__":
    # spawn one process per pipeline stage; with CUDA available, rank 0 runs on
    # cuda:0 and rank 1 on cuda:1
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)