Unverified commit 58e97aa6 authored by Benjamin Lefaudeux, committed by GitHub

[docs][minor] fixing the readme example for oss (#147)

* fixing the readme for oss
parent 10062e58
@@ -56,6 +56,7 @@ See a more complete example [here](https://github.com/facebookresearch/fairscale
 import torch
 import torch.multiprocessing as mp
 from fairscale.optim.oss import OSS
+from torch.nn.parallel import DistributedDataParallel as DDP

 def train(
     rank: int,
@@ -66,7 +67,8 @@ def train(
     dist_init(rank, world_size)

     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
+    model = DDP(model, device_ids=[rank])
     dataloader = mySuperFastDataloader()
     loss_fn = myVeryRelevantLoss()
     base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
@@ -82,11 +84,11 @@ def train(
             model.zero_grad()
             outputs = model(batch["inputs"])
             loss = loss_fn(outputs, batch["label"])
-            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-            loss /= world_size
             loss.backward()
             optimizer.step()

+    dist.destroy_process_group()
+
 if __name__ == "__main__":
     # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
     mp.spawn(
...
@@ -40,7 +40,6 @@ Let's suppose that your trainer looks like
     model.zero_grad()
     outputs = model(data)
     loss = loss_fn(outputs, target)
-    loss /= world_size
     loss.backward()
     optimizer.step()
...
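For reference, below is a minimal, self-contained sketch of how the fixed OSS example reads end to end. The tiny Linear model, inline dataloader, MSE loss, and dist_init helper are illustrative stand-ins for the README's placeholders (myAwesomeModel, mySuperFastDataloader, myVeryRelevantLoss, dist_init), and the sketch assumes one GPU per rank with the nccl backend; it is not a verbatim copy of the repository's example.

```python
# Hedged sketch of the fixed README example; helper names below are placeholders.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.optim.oss import OSS

WORLD_SIZE = 2  # assumed for this sketch
EPOCHS = 1      # assumed for this sketch


def dist_init(rank: int, world_size: int) -> None:
    # Stand-in for the README's dist_init helper.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def train(rank: int, world_size: int, epochs: int) -> None:
    dist_init(rank, world_size)

    # Problem statement: stand-ins for myAwesomeModel / mySuperFastDataloader / myVeryRelevantLoss
    model = torch.nn.Linear(10, 10).to(rank)
    model = DDP(model, device_ids=[rank])
    dataloader = [
        {"inputs": torch.randn(8, 10), "label": torch.randn(8, 10)} for _ in range(4)
    ]
    loss_fn = torch.nn.MSELoss()
    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {"lr": 1e-3}

    # Wrap the base optimizer in OSS so that its state is sharded across ranks
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    for _ in range(epochs):
        for batch in dataloader:
            model.zero_grad()
            outputs = model(batch["inputs"].to(rank))
            loss = loss_fn(outputs, batch["label"].to(rank))
            # DDP averages gradients across ranks during backward,
            # so no manual all_reduce of the loss is needed.
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(train, args=(WORLD_SIZE, EPOCHS), nprocs=WORLD_SIZE, join=True)
```

With the model wrapped in DistributedDataParallel, gradient averaging across ranks happens during the backward pass, which is why the manual loss all_reduce and the `loss /= world_size` scaling are dropped from both snippets in this commit.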