Unverified Commit 19cb5938 authored by Benjamin Lefaudeux, committed by GitHub

[docs] lintfixes (#255)



* lintfixes

* come on black

* Update tutorial_pipe_multiprocess.py

make RANK global like the other tutorials
Co-authored-by: Vittorio Caggiano <caggiano@gmail.com>
parent 550f1ab7
 import time
 from typing import Optional, Union, cast

+from helpers import dist_init, getData, getLossFun, getModel
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp

 from fairscale.optim.oss import OSS
-from helpers import dist_init, getModel, getData, getLossFun

 WORLD_SIZE = 2
 EPOCHS = 3
@@ -18,10 +18,10 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     # DDP
     dist_init(rank, world_size)
-    rank = torch.device("cpu") if DEVICE == "cpu" else rank
+    device = torch.device("cpu") if DEVICE == "cpu" else rank  # type:ignore

     # Problem statement
-    model = getModel().to(rank)
+    model = getModel().to(device)
     dataloader = getData(n_batches=1)
     loss_fn = getLossFun()
@@ -32,7 +32,9 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     else:
         base_optimizer = torch.optim.SGD
         base_optimizer_arguments = {"lr": 1e-4}  # any optimizer specific arguments, LR, momentum, etc...
-        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
+        optimizer = OSS(
+            params=model.parameters(), optim=base_optimizer, broadcast_buffer_size=2 ** 17, **base_optimizer_arguments
+        )

     training_start = time.monotonic()
     # Any relevant training loop, nothing specific to OSS. For example:
@@ -40,7 +42,7 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     for _ in range(epochs):
         for (data, target) in dataloader:
-            data, target = data.to(rank), target.to(rank)
+            data, target = data.to(device), target.to(device)

             # Train
             model.zero_grad()
...
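For context, this tutorial wraps a plain torch optimizer in fairscale's OSS so that optimizer state is sharded across ranks. Below is a minimal, self-contained sketch of the same pattern without the tutorial's `helpers` module; the toy model, random data, gloo process group, and port number are illustrative assumptions, and the `broadcast_buffer_size` knob from the diff is left out since it is version-dependent.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS


def train(rank: int, world_size: int) -> None:
    # Illustrative stand-in for the tutorial's dist_init() helper.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )

    device = torch.device("cpu")  # swap for torch.device("cuda", rank) on GPU hosts
    model = torch.nn.Linear(16, 4).to(device)  # toy stand-in for getModel()

    # Same wrapping as in the diff: OSS shards the SGD state across ranks.
    base_optimizer = torch.optim.SGD
    base_optimizer_arguments = {"lr": 1e-4}
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    for _ in range(3):
        data = torch.randn(8, 16, device=device)
        target = torch.randn(8, 4, device=device)

        model.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()  # each rank updates its shard, then the updated params are synced

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(train, args=(2,), nprocs=2, join=True)
```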
+from helpers import getData, getLossFun, getModel
 import torch
 import torch.optim as optim

 import fairscale
-from helpers import getModel, getData, getLossFun

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+RANK = 0  # example

 model = getModel()
 data, target = getData()[0]
@@ -19,7 +20,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.001)
 # zero the parameter gradients
 optimizer.zero_grad()

-device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
+device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")
 # outputs and target need to be on the same device

 # forward step
@@ -33,4 +34,5 @@ optimizer.step()
 print("Finished Training Step")

 del model
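The change above introduces a module-level `RANK` constant so this single-process tutorial selects its device the same way the multiprocess ones do. A short sketch of that device-selection pattern, using a toy model and random data in place of the tutorial's helpers:

```python
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
RANK = 0  # single process, so rank 0 is the only rank

# Same selection logic as in the diff: pick the GPU indexed by RANK, else fall back to CPU.
device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")

model = torch.nn.Linear(8, 2).to(device)  # stand-in for getModel()
data, target = torch.randn(4, 8), torch.randn(4, 2)  # stand-in for getData()[0]

# Inputs, targets, and model must live on the same device before the forward step.
outputs = model(data.to(device))
loss = torch.nn.functional.mse_loss(outputs, target.to(device))
loss.backward()
```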
 import os

+from helpers import dist_init, getData, getLossFun, getModel
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -7,10 +8,9 @@ import torch.optim as optim

 import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
-from helpers import dist_init, getModel, getData, getLossFun

 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+RANK = 0  # example


 def run(rank, world_size):
@@ -25,7 +25,7 @@ def run(rank, world_size):
     data, target = getData()[0]
     loss_fn = getLossFun()

-    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
+    device = torch.device("cuda", RANK) if DEVICE == "cuda" else torch.device("cpu")

     model = fairscale.nn.Pipe(
         model,
...
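In this multiprocess pipe tutorial, `run(rank, world_size)` is launched once per rank and then wraps the model in `fairscale.nn.Pipe`. The following is a minimal sketch of just the launch pattern; the environment variables, port, and gloo backend stand in for the tutorial's `dist_init` helper and are assumptions, and the Pipe construction itself is left as a placeholder comment.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

WORLD_SIZE = 2


def run(rank: int, world_size: int) -> None:
    # Illustrative stand-in for the tutorial's dist_init(): wire up the default process group.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29502"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # With a global RANK the device choice reads the same in every tutorial;
    # here the per-process rank argument is used directly for clarity.
    device = torch.device("cuda", rank) if torch.cuda.is_available() else torch.device("cpu")
    print(f"rank {rank} using {device}")

    # ... build the model, wrap it in fairscale.nn.Pipe, and train here ...

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)
```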
@@ -3,13 +3,13 @@

 import os

+from helpers import dist_init, getData, getLossFun, getModel
 import torch
 import torch.optim as optim
 import torch_pg

 import fairscale
 from fairscale.nn.model_parallel import initialize_model_parallel
-from helpers import dist_init, getModel, getData, getLossFun


 def register_optimizer(ctx, model):
...
@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
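This config change is what drives the import reshuffling in the tutorials above: once `helpers` is listed in `known_third_party`, isort groups it with the other third-party imports, and `force_sort_within_sections` keeps plain `import` and `from ... import` lines sorted together alphabetically within each section. Roughly, each tutorial's header then looks like the sketch below; the exact grouping depends on the isort version and the rest of the config.

```python
# Standard library
import os
import time

# Third-party section: helpers now sorts here alongside torch because of known_third_party,
# and "from helpers" precedes "import torch" since the statements are sorted together.
from helpers import dist_init, getData, getLossFun, getModel
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# First-party section (fairscale is the project itself in this repo)
import fairscale
from fairscale.optim.oss import OSS
```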