OpenDAS / fairscale

Unverified commit 5220f89b, authored Oct 09, 2020 by Benjamin Lefaudeux, committed by GitHub on Oct 09, 2020

[minor] OSS doc fix - add the DDP wrap (#131)

* wrapping the model in DDP in the tutorial
* typo

Parent: bfd88cad
Showing 1 changed file with 9 additions and 4 deletions (+9 -4)

docs/source/tutorials/oss.rst
Optimizer state sharding
========================

-Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications, but it is possible and makes OSS a drop in solution in your existing torch distributed code.
+Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS, but it is possible and makes OSS a drop in solution in your existing torch distributed code.
Let's suppose that your trainer looks like

.. code-block:: python

    import torch
+    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,

@@ -19,11 +21,12 @@ Let's suppose that your trainer looks like
        # Problem statement
        model = myAwesomeModel().to(rank)
+        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)
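The DDP wrap added above assumes that a distributed process group has already been initialized on each rank before `train` runs; the diff does not show that part. As a minimal sketch, not lines from the tutorial, and with the `gloo` backend and the MASTER_ADDR/MASTER_PORT values being illustrative assumptions, the setup typically looks like:

.. code-block:: python

    import os
    import torch.distributed as dist

    def setup_process_group(rank: int, world_size: int) -> None:
        # Every rank must join the same process group before DDP can wrap the model.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)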
@@ -50,6 +53,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer

    import torch
    from fairscale.optim.oss import OSS
+    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,

@@ -61,11 +65,12 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
        # Problem statement
        model = myAwesomeModel().to(rank)
+        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_ln = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
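The hunk ends just before the actual OSS wrap. Based on the `fairscale.optim.oss.OSS` constructor imported above, which takes the parameters, the base optimizer class, and the base optimizer's keyword arguments, the next step looks roughly like the sketch below; this is a sketch of the API as understood at this version, not a line from the diff:

.. code-block:: python

        # Sketch: OSS shards the wrapped optimizer's state across the ranks of
        # the process group instead of replicating it on every rank.
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)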
@@ -100,5 +105,5 @@ The above `train` function will then need to be run via a `multiprocessing.spawn

            nprocs=WORLD_SIZE,
            join=True
        )

    to see it in action, you can test it with the following script _`tutorial_oss.py <../../../examples/tutorial_oss.py>`_
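The last hunk shows only the tail of the `multiprocessing.spawn` call. For context, a minimal sketch of a complete entry point, assuming `train` has the signature shown above, that `WORLD_SIZE` is the number of worker processes, and that the `args` contents are illustrative rather than taken from the tutorial:

.. code-block:: python

    import torch.multiprocessing as mp

    WORLD_SIZE = 2  # assumption: number of worker processes

    if __name__ == "__main__":
        # Launch one training process per rank; spawn passes the rank as the
        # first positional argument to `train`.
        mp.spawn(
            train,
            args=(WORLD_SIZE,),
            nprocs=WORLD_SIZE,
            join=True)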