Unverified Commit 02478eb3 authored by jessijzhao, committed by GitHub

[feat] add CPU support to tutorials in examples + factorize tutorials (#247)

* [feat] add CPU support to tutorials in examples

* now works on a machine without CUDA
* fixes some minor typos

* [cleanup] factorize tutorials in examples

* collects duplicate code across tutorials in helpers.py

* [fix] getData in tutorials now returns an iterable
parent 7e5ddcd2
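The CPU fallback added across these tutorials boils down to two runtime decisions: pick GLOO instead of NCCL as the collective backend when no GPU is present, and place tensors on a CPU device instead of the CUDA device for the rank. A minimal, self-contained sketch of that pattern (illustrative only, not the tutorial code itself; the actual helpers follow in the diff below):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def worker(rank: int, world_size: int) -> None:
    # NCCL requires GPUs; GLOO works on a CPU-only machine.
    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
    dist.init_process_group(
        backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )
    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
    print(f"[{rank}] backend={backend}, device={device}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2, join=True)
```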
@@ -109,8 +109,8 @@ class LinearLayer(nn.Linear):
         self.weight.data.uniform_(-initrange, initrange)
 
-class TransformerLMSequntial(nn.Sequential):
-    """A small language model based on the design of GPT-2 using nn.Sequeitnal
+class TransformerLMSequential(nn.Sequential):
+    """A small language model based on the design of GPT-2 using nn.Sequential
     for compatability with Pipe"""
 
     def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder):
@@ -122,7 +122,7 @@ class TransformerLMSequntial(nn.Sequential):
         layers.append(TransformerDecoderLayer(ninp, nhead, nhid, dropout))
 
         layers.append(LinearLayer(ninp, ntokens, initrange))
-        super(TransformerLMSequntial, self).__init__(*layers)
+        super(TransformerLMSequential, self).__init__(*layers)
 
 
 def get_data(device):
@@ -177,7 +177,7 @@ def make_model(args, device, ntokens):
         layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
         model = layers
     else:
-        model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
+        model = TransformerLMSequential(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
 
     criterion = nn.CrossEntropyLoss()
     lr = 0.01  # learning rate
...... (new file: helpers.py)
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def dist_init(rank, world_size):
+    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+    print(f"Using backend: {backend}")
+    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
+
+
+def getModel():
+    return nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
+
+
+def getData(n_batches=1):
+    return [(torch.randn(20, 10), torch.randint(0, 2, size=(20, 1)).squeeze()) for i in range(n_batches)]
+
+
+def getLossFun():
+    return F.nll_loss
\ No newline at end of file
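For orientation, the helpers above are meant to be consumed like this; a hypothetical single-process loop (CPU or GPU) showing that getData() now returns an iterable of (data, target) batches:

```python
import torch

from helpers import getModel, getData, getLossFun

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = getModel().to(device)
loss_fn = getLossFun()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# getData() returns a list of (data, target) tuples, so it can be iterated
# like a dataloader; n_batches controls how many random batches it yields.
for data, target in getData(n_batches=4):
    optimizer.zero_grad()
    outputs = model(data.to(device))
    loss = loss_fn(outputs, target.to(device))
    loss.backward()
    optimizer.step()
```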
@@ -4,10 +4,9 @@ from typing import Optional, Union, cast
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.nn.functional as F
 
 from fairscale.optim.oss import OSS
+from helpers import dist_init, getModel, getData, getLossFun
 
 WORLD_SIZE = 2
 EPOCHS = 3
@@ -15,34 +14,15 @@ EPOCHS = 3
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-
-def dist_init(rank, world_size):
-    backend = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
-    print(f"Using backend: {backend}")
-    dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)
-
-
-def getModel():
-    return nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
-
-
-def getData():
-    target = torch.randint(0, 2, size=(20, 1)).squeeze()
-    data = torch.randn(20, 10)
-    return [(data, target)]
-
-
-def getLossFun():
-    return F.nll_loss
-
-
 def train(rank: int, world_size: int, epochs: int, use_oss: bool):
 
     # DDP
     dist_init(rank, world_size)
+    rank = torch.device("cpu") if DEVICE == "cpu" else rank
 
     # Problem statement
     model = getModel().to(rank)
-    dataloader = getData()
+    dataloader = getData(n_batches=1)
     loss_fn = getLossFun()
 
     optimizer: Optional[Union[OSS, torch.optim.SGD]] = None
@@ -52,7 +32,7 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     else:
         base_optimizer = torch.optim.SGD
         base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
-        optimizer = OSS(params=model.parameters(), optim=base_optimizer, default=base_optimizer_arguments)
+        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
 
     training_start = time.monotonic()
     # Any relevant training loop, nothing specific to OSS. For example:
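The small-looking change above matters: OSS appears to forward extra keyword arguments to the optimizer class it wraps, so the hyperparameters have to be unpacked with `**` rather than passed under a `default=` keyword. A hedged sketch of that call pattern (a single-process GLOO group only to satisfy OSS's need for an initialized process group; verify the exact OSS signature for your fairscale version):

```python
import torch
import torch.distributed as dist

from fairscale.optim.oss import OSS

# OSS shards optimizer state across a torch.distributed process group, so the
# group must exist before constructing it (world_size=1 is enough for a demo).
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)

model = torch.nn.Linear(10, 5)
base_optimizer = torch.optim.SGD
base_optimizer_arguments = {"lr": 1e-4, "momentum": 0.9}

# Assumption based on the fix above: hyperparameters are plain **kwargs that
# OSS hands to the wrapped SGD, not a default=... argument.
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
```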
@@ -82,9 +62,10 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
            print(f"Loss: {loss.item()}")
 
    training_end = time.monotonic()
-    max_memory = torch.cuda.max_memory_allocated(rank)
-
    print(f"[{dist.get_rank()}] : Training done. {training_end-training_start:.2f} sec")
-    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
+
+    if DEVICE == "cuda":
+        max_memory = torch.cuda.max_memory_allocated(rank)
+        print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
......
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 
 import fairscale
+from helpers import getModel, getData, getLossFun
 
-model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
-target = torch.randint(0, 2, size=(20, 1)).squeeze()
-data = torch.randn(20, 10)
-loss_fn = F.nll_loss
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+model = getModel()
+data, target = getData()[0]
+loss_fn = getLossFun()
 
 model = fairscale.nn.Pipe(model, balance=[2, 1])
@@ -19,11 +19,11 @@ optimizer = optim.SGD(model.parameters(), lr=0.001)
 # zero the parameter gradients
 optimizer.zero_grad()
 
-device = model.devices[0]
+device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
 
 # outputs and target need to be on the same device
 # forward step
-outputs = model(data.to(device))
+outputs = model(data.to(device).requires_grad_())
 # compute loss
 loss = loss_fn(outputs.to(device), target.to(device))
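Put together, the single-process pipe tutorial now reads roughly as follows. This is a paraphrase of the diff rather than a verified script: it uses model.devices[0] (the device of the first pipeline partition) to place the input, where the tutorial derives the device from the DEVICE flag instead.

```python
import torch
import torch.optim as optim

import fairscale
from helpers import getModel, getData, getLossFun

# balance=[2, 1] splits the three-layer helper model into two pipeline stages.
model = fairscale.nn.Pipe(getModel(), balance=[2, 1])
optimizer = optim.SGD(model.parameters(), lr=0.001)
loss_fn = getLossFun()
data, target = getData()[0]

# Inputs go to the first partition's device; outputs and target must end up
# on the same device before the loss is computed.
device = model.devices[0]

optimizer.zero_grad()
outputs = model(data.to(device).requires_grad_())  # forward through the pipe
loss = loss_fn(outputs.to(device), target.to(device))
loss.backward()  # backward propagates through all partitions
optimizer.step()
```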
......
 import os
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 
 import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
+from helpers import dist_init, getModel, getData, getLossFun
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 def run(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "10638"
-    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    dist_init(rank, world_size)
     os.environ["MASTER_PORT"] = "10639"
-    torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
+    dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
     initialize_model_parallel(1, world_size)
 
-    model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
-    target = torch.randint(0, 2, size=(20, 1)).squeeze()
-    data = torch.randn(20, 10)
-    loss_fn = F.nll_loss
+    model = getModel()
+    data, target = getData()[0]
+    loss_fn = getLossFun()
 
-    device = torch.device("cuda", rank)
+    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
 
     model = fairscale.nn.Pipe(
         model,
......
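The Pipe(...) call above is truncated in this view; what follows it is the multi-process pipeline construction and training step. The launch scaffolding around run() is the usual mp.spawn pattern, roughly as follows (the module name here is an assumption, not the actual file name):

```python
import torch.multiprocessing as mp

# Hypothetical module name for the multi-process tutorial shown above.
from tutorial_pipe_multiprocess import run

WORLD_SIZE = 2  # one process per pipeline worker


if __name__ == "__main__":
    mp.spawn(run, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
```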
@@ -4,13 +4,12 @@
 
 import os
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 import torch.optim as optim
 import torch_pg
 
 import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
+from helpers import dist_init, getModel, getData, getLossFun
 
 
 def register_optimizer(ctx, model):
@@ -28,7 +27,7 @@ def run(rank, world_size):
     torch_pg.init_mpi()
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "10638"
-    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    dist_init(rank, world_size)  # FIXME (supports gloo)
     os.environ["MASTER_PORT"] = "10639"
     torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
     initialize_model_parallel(1, world_size, pipeline_backend="mpi")
@@ -38,10 +37,9 @@ def run(rank, world_size):
         torch.distributed.rpc.shutdown()
         return
 
-    model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
-    target = torch.randint(0, 2, size=(20, 1)).squeeze()
-    data = torch.randn(20, 10)
-    loss_fn = F.nll_loss
+    model = getModel()
+    data, target = getData()[0]
+    loss_fn = getLossFun()
 
     device = torch.device("cuda", rank)
......
@@ -37,8 +37,6 @@ class ShardedDataParallel(nn.Module):
             the sharded optimizer(s) which will decide the gradient partitioning
 
     Keyword Args:
-        process_group (torch.nn.Optimizer):
-            Optimizer to shard (default: SGD)
         process_group (group):
             torch.distributed group (default: group.WORLD)
         broadcast_buffers (bool):
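For context on the docstring being fixed above: ShardedDataParallel is constructed from the module plus the sharded optimizer(s), with process_group and broadcast_buffers as keyword arguments. A hedged sketch of that interface (import path and argument order assumed from this docstring; check them against your fairscale version):

```python
import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# A process group is required; a single-process GLOO group works for a demo.
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)

model = torch.nn.Linear(10, 5)
# Per the docstring, the sharded optimizer decides the gradient partitioning.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)

# Keyword args per the docstring: process_group defaults to group.WORLD,
# broadcast_buffers is a bool flag.
sharded_model = ShardedDataParallel(model, optimizer, broadcast_buffers=True)
```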
......