Unverified Commit 5220f89b authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub

[minor] OSS doc fix - add the DDP wrap (#131)

* wrapping the model in DDP in the tutorial

* typo
parent bfd88cad
Optimizer state sharding
========================
Using torch.nn.parallel.DistributedDataParallel in combination with OSS leads to some wasted communications, but it is possible, and makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks like:
.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        ...
    ):
        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])

        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in OSS:

.. code-block:: python

    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        ...
    ):
        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])

        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()
        ...