Unverified Commit 19cb5938 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[docs] lintfixes (#255)



* lintfixes

* come on black

* Update tutorial_pipe_multiprocess.py

make RANK global like the other tutorials
Co-authored-by: Vittorio Caggiano <caggiano@gmail.com>
parent 550f1ab7
......@@ -19,4 +19,4 @@ def getData(n_batches=1):
def getLossFun():
    """Return the loss function shared by the tutorials.

    All of the tutorial scripts train with negative log-likelihood, so the
    helper simply hands back ``torch.nn.functional.nll_loss`` (uncalled) for
    the caller to apply to (output, target) pairs.
    """
    loss_function = F.nll_loss
    return loss_function
\ No newline at end of file
return F.nll_loss
import time
from typing import Optional, Union, cast
from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from helpers import dist_init, getModel, getData, getLossFun
WORLD_SIZE = 2
EPOCHS = 3
......@@ -18,10 +18,10 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
# DDP
dist_init(rank, world_size)
rank = torch.device("cpu") if DEVICE == "cpu" else rank
device = torch.device("cpu") if DEVICE == "cpu" else rank # type:ignore
# Problem statement
model = getModel().to(rank)
model = getModel().to(device)
dataloader = getData(n_batches=1)
loss_fn = getLossFun()
......@@ -32,7 +32,9 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
else:
base_optimizer = torch.optim.SGD
base_optimizer_arguments = {"lr": 1e-4} # any optimizer specific arguments, LR, momentum, etc...
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
optimizer = OSS(
params=model.parameters(), optim=base_optimizer, broadcast_buffer_size=2 ** 17, **base_optimizer_arguments
)
training_start = time.monotonic()
# Any relevant training loop, nothing specific to OSS. For example:
......@@ -40,7 +42,7 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
for _ in range(epochs):
for (data, target) in dataloader:
data, target = data.to(rank), target.to(rank)
data, target = data.to(device), target.to(device)
# Train
model.zero_grad()
......
from helpers import getData, getLossFun, getModel
import torch
import torch.optim as optim
import fairscale
from helpers import getModel, getData, getLossFun
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 0 # example
model = getModel()
data, target = getData()[0]
......@@ -19,7 +20,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.001)
# zero the parameter gradients
optimizer.zero_grad()
device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
# outputs and target need to be on the same device
# forward step
......@@ -33,4 +34,5 @@ optimizer.step()
print("Finished Training Step")
del model
import os
from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
......@@ -7,10 +8,9 @@ import torch.optim as optim
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
from helpers import dist_init, getModel, getData, getLossFun
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 0 # example
def run(rank, world_size):
......@@ -25,7 +25,7 @@ def run(rank, world_size):
data, target = getData()[0]
loss_fn = getLossFun()
device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
model = fairscale.nn.Pipe(
model,
......
......@@ -3,13 +3,13 @@
import os
from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.optim as optim
import torch_pg
import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel
from helpers import dist_init, getModel, getData, getLossFun
def register_optimizer(ctx, model):
......@@ -27,7 +27,7 @@ def run(rank, world_size):
torch_pg.init_mpi()
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
dist_init(rank, world_size) # FIXME (supports gloo)
dist_init(rank, world_size) # FIXME (supports gloo)
os.environ["MASTER_PORT"] = "10639"
torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
initialize_model_parallel(1, world_size, pipeline_backend="mpi")
......
......@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment