Unverified Commit 65ca68a9 authored by Siddharth Goyal, committed by GitHub

[fix] examples: fix naming style of helper functions (#334)

parent 73221557
@@ -10,13 +10,13 @@ def dist_init(rank, world_size):
     dist.init_process_group(backend=backend, init_method="tcp://localhost:29501", rank=rank, world_size=world_size)


-def getModel():
+def get_model():
     return nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))


-def getData(n_batches=1):
+def get_data(n_batches=1):
     return [(torch.randn(20, 10), torch.randint(0, 2, size=(20, 1)).squeeze()) for i in range(n_batches)]


-def getLossFun():
+def get_loss_fun():
     return F.nll_loss
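
For reference, the renamed helpers compose into a quick smoke test like the sketch below; only get_model, get_data, and get_loss_fun come from the diff above, while the single forward/backward pass is an illustrative assumption, not code from this commit.

# Illustrative sketch, not part of this diff: exercise the renamed helpers once.
import torch.nn.functional as F

from helpers import get_data, get_loss_fun, get_model

model = get_model()                       # small Sequential MLP
data, target = get_data(n_batches=1)[0]   # one random batch: inputs (20, 10), targets (20,)
loss_fn = get_loss_fun()                  # F.nll_loss

out = F.log_softmax(model(data), dim=-1)  # nll_loss expects log-probabilities
loss = loss_fn(out, target)
loss.backward()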
 import time
 from typing import Optional, Union, cast

-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -21,9 +21,9 @@ def train(rank: int, world_size: int, epochs: int, use_oss: bool):
     device = torch.device("cpu") if DEVICE == "cpu" else rank  # type:ignore

     # Problem statement
-    model = getModel().to(device)
-    dataloader = getData(n_batches=1)
-    loss_fn = getLossFun()
+    model = get_model().to(device)
+    dataloader = get_data(n_batches=1)
+    loss_fn = get_loss_fun()

     optimizer: Optional[Union[OSS, torch.optim.SGD]] = None
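
This hunk ends just before the optimizer is selected; a rough sketch of what could fill that Optional[Union[OSS, torch.optim.SGD]] slot is shown below. The OSS constructor call follows fairscale's documented API, but the import path and learning rate are assumptions rather than values taken from this commit.

# Illustrative sketch, not part of this diff: choose between a sharded and a plain optimizer.
from fairscale.optim.oss import OSS

if use_oss:
    # OSS shards optimizer state across ranks while keeping the usual step()/zero_grad() API.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)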
...
-from helpers import getData, getLossFun, getModel
+from helpers import get_data, get_loss_fun, get_model
 import torch
 import torch.optim as optim
@@ -7,9 +7,9 @@ import fairscale
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 RANK = 0  # example

-model = getModel()
-data, target = getData()[0]
-loss_fn = getLossFun()
+model = get_model()
+data, target = get_data()[0]
+loss_fn = get_loss_fun()

 model = fairscale.nn.Pipe(model, balance=[2, 1])
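
After wrapping the model in Pipe, a training step with the renamed helpers might look roughly like the sketch below; the optimizer, learning rate, and device placement are assumptions for illustration, since this hunk does not show them.

# Illustrative sketch, not part of this diff: one optimization step through the
# Pipe-wrapped model. On GPU, `data` would need to live on the first partition's
# device and `target` on the last; this sketch assumes a CPU fallback.
import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=1e-3)

output = model(data)                                 # runs both pipeline partitions in order
loss = loss_fn(output.log_softmax(dim=-1), target)   # nll_loss expects log-probabilities
optimizer.zero_grad()
loss.backward()
optimizer.step()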
...
 import os

-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -20,9 +20,9 @@ def run(rank, world_size):
     dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
     initialize_model_parallel(1, world_size)

-    model = getModel()
-    data, target = getData()[0]
-    loss_fn = getLossFun()
+    model = get_model()
+    data, target = get_data()[0]
+    loss_fn = get_loss_fun()

     device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")
...
@@ -3,7 +3,7 @@
 import os

-from helpers import dist_init, getData, getLossFun, getModel
+from helpers import dist_init, get_data, get_loss_fun, get_model
 import torch
 import torch.optim as optim
 import torch_pg
@@ -37,9 +37,9 @@ def run(rank, world_size):
         torch.distributed.rpc.shutdown()
         return

-    model = getModel()
-    data, target = getData()[0]
-    loss_fn = getLossFun()
+    model = get_model()
+    data, target = get_data()[0]
+    loss_fn = get_loss_fun()

     device = torch.device("cuda", rank)
...