Unverified Commit 58e97aa6 authored by Benjamin Lefaudeux, committed by GitHub

[docs][minor] fixing the readme example for oss (#147)

* fixing the readme for oss
parent 10062e58
@@ -56,6 +56,7 @@ See a more complete example [here](https://github.com/facebookresearch/fairscale
 import torch
 import torch.multiprocessing as mp
 from fairscale.optim.oss import OSS
+from torch.nn.parallel import DistributedDataParallel as DDP
 def train(
     rank: int,
@@ -66,7 +67,8 @@ def train(
     dist_init(rank, world_size)
     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
+    model = DDP(model, device_ids=[rank])
     dataloader = mySuperFastDataloader()
     loss_fn = myVeryRelevantLoss()
     base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
@@ -82,11 +84,11 @@ def train(
             model.zero_grad()
             outputs = model(batch["inputs"])
             loss = loss_fn(outputs, batch["label"])
-            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-            loss /= world_size
             loss.backward()
             optimizer.step()
     dist.destroy_process_group()
 if __name__ == "__main__":
     # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
     mp.spawn(
......
@@ -40,7 +40,6 @@ Let's suppose that your trainer looks like
         model.zero_grad()
         outputs = model(data)
         loss = loss_fn(outputs, target)
-        loss /= world_size
         loss.backward()
         optimizer.step()
......
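For context, here is a minimal, self-contained sketch of what the corrected README example boils down to after this patch: the model is wrapped in DDP, and the manual `all_reduce` / `loss /= world_size` averaging is gone because DDP already averages gradients during `backward()`. The toy `nn.Linear` model, random tensors, `gloo` backend, `MASTER_ADDR`/`MASTER_PORT` values, and learning rate are illustrative stand-ins (not part of the diff) for the README's placeholder helpers (`myAwesomeModel`, `mySuperFastDataloader`, `myVeryRelevantLoss`, `dist_init`).

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.optim.oss import OSS

WORLD_SIZE = 2
EPOCHS = 2


def train(rank: int, world_size: int, epochs: int):
    # Stand-in for the README's dist_init helper: one process group per worker.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Toy problem statement; on GPU, use .to(rank) and DDP(model, device_ids=[rank])
    # as the patched README does.
    model = DDP(torch.nn.Linear(16, 4))
    dataloader = [(torch.randn(8, 16), torch.randn(8, 4)) for _ in range(4)]
    loss_fn = torch.nn.MSELoss()

    # Wrap any PyTorch-compliant optimizer in OSS to shard its state across ranks.
    base_optimizer = torch.optim.SGD
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-3)

    model.train()
    for _ in range(epochs):
        for inputs, labels in dataloader:
            model.zero_grad()
            loss = loss_fn(model(inputs), labels)
            # No manual all_reduce or division by world_size here:
            # DDP averages the gradients across ranks during backward().
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(train, args=(WORLD_SIZE, EPOCHS), nprocs=WORLD_SIZE, join=True)
```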