tutorial_pipe.py 829 Bytes
Newer Older
1
from helpers import get_data, get_loss_fun, get_model
2
3
4
5
6
import torch
import torch.optim as optim

import fairscale

7
# Use the GPU backend when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
8
# CUDA device index used when the target `device` is built further down;
# 0 is only an illustrative value.
RANK = 0  # example
9

10
11
12
# Everything the training step needs comes from the tutorial's helpers module.
model = get_model()
first_item = get_data()[0]  # unpacked below into an (input, target) pair
data, target = first_item
loss_fn = get_loss_fun()
13
14
15
16
17
18
19
20
21
22

# Wrap the model in a fairscale pipeline; `balance` lists how many layers go
# into each partition (two in the first, one in the second).
model = fairscale.nn.Pipe(model, balance=[2, 1])

# Plain SGD over the pipelined model's parameters (the loss function was
# already obtained above).
optimizer = optim.SGD(params=model.parameters(), lr=1e-3)

# Start the step from cleared gradients.
optimizer.zero_grad()

Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
23
# Inputs and targets go to CUDA device number RANK on GPU runs, else to the CPU.
if DEVICE == "cuda":
    device = torch.device("cuda", RANK)
else:
    device = torch.device("cpu")
24
25
26

# outputs and target need to be on the same device
# forward step
27
# Move the input to `device`, flag it in place as requiring gradients, then
# run it through the pipelined model.
inputs = data.to(device).requires_grad_()
outputs = model(inputs)
28
29
30
31
32
33
34
35
36
# Bring predictions and targets onto the same device before computing the loss.
predictions = outputs.to(device)
expected = target.to(device)
loss = loss_fn(predictions, expected)

# Backpropagate, then apply one optimizer update.
loss.backward()
optimizer.step()

print("Finished Training Step")

Benjamin Lefaudeux's avatar
Benjamin Lefaudeux committed
37

38
# Remove the module-level reference to the Pipe-wrapped model.
# NOTE(review): presumably so the pipeline is released explicitly before the
# script exits — confirm against fairscale's Pipe teardown behaviour.
del model