Unverified Commit 5220f89b authored by Benjamin Lefaudeux, committed by GitHub

[minor] OSS doc fix - add the DDP wrap (#131)

* wrapping the model in DDP in the tutorial

* typo
parent bfd88cad
Optimizer state sharding
========================
Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.

Let's suppose that your trainer looks like this:

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
@@ -19,11 +21,12 @@ Let's suppose that your trainer looks like
        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)
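
The diff hunks above skip over the distributed setup and the training loop. For context, here is a minimal sketch of what that elided boilerplate typically looks like; the backend, rendezvous address, the extra ``world_size``/``epochs`` arguments, and the loop body are illustrative assumptions, not the actual contents of the hidden lines.

.. code-block:: python

    # Hypothetical sketch of the parts the hunks elide: process group
    # initialization and a plain training loop (assumed settings).
    import torch.distributed as dist

    dist.init_process_group(
        backend="nccl",                        # assumed backend
        init_method="tcp://localhost:29501",   # assumed rendezvous address
        rank=rank,
        world_size=world_size)

    model.train()
    for _ in range(epochs):
        for batch, target in dataloader:
            optimizer.zero_grad()
            loss = loss_ln(model(batch), target)
            loss.backward()
            optimizer.step()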
@@ -50,6 +53,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
@@ -61,11 +65,12 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
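
The elided hunk that follows is where the sharded optimizer is actually constructed. A minimal sketch of that step, assuming the ``params``/``optim`` keyword names of ``fairscale.optim.OSS``, could look like this:

.. code-block:: python

    # Sketch: wrap the base optimizer class in OSS so its state is sharded
    # across ranks (keyword names assumed from fairscale.optim.OSS).
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,
        **base_optimizer_arguments)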
@@ -100,5 +105,5 @@ The above `train` function will then need to be run via a `multiprocessing.spawn
        nprocs=WORLD_SIZE,
        join=True
    )
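
Putting the launcher together, a minimal version could look like the sketch below; ``WORLD_SIZE`` and ``EPOCHS`` are assumed placeholder constants, and ``torch.multiprocessing`` provides the spawn call.

.. code-block:: python

    # Sketch of a complete launcher; WORLD_SIZE and EPOCHS are assumed constants.
    import torch.multiprocessing as mp

    if __name__ == "__main__":
        WORLD_SIZE = 2   # e.g. two GPUs on a single machine
        EPOCHS = 5
        mp.spawn(
            train,                      # each spawned process receives its rank first
            args=(WORLD_SIZE, EPOCHS),  # remaining arguments of train()
            nprocs=WORLD_SIZE,
            join=True
        )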
To see it in action, you can test it with the following script: `tutorial_oss.py <../../../examples/tutorial_oss.py>`_