OpenDAS / fairscale · Commit 8b5b9540 (Unverified)
Authored Dec 01, 2020 by Benjamin Lefaudeux, committed via GitHub on Dec 01, 2020

[docs] Minor refactor, trying to improve a little bit the html (#220)
Parent: e83da060

Showing 3 changed files with 20 additions and 12 deletions (+20 -12)
README.md (+2, -2)
docs/source/index.rst (+10, -5)
docs/source/tutorials/oss.rst (+8, -5)
README.md
@@ -57,7 +57,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from fairscale.optim.oss import OSS
-from torch.nn.parallel import DistributedDataParallel as DDP
+from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

 def train(
     rank: int,
@@ -69,7 +69,7 @@ def train(
     # Problem statement
     model = myAwesomeModel().to(rank)
-    model = DDP(model, device_ids=[rank])
+    model = ShardedDDP(model, device_ids=[rank])  # this will handle the gradient reduce automatically
     dataloader = mySuperFastDataloader()
     loss_fn = myVeryRelevantLoss()
     base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
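The two README hunks only show the changed import and the model wrap. For context, here is a hedged, self-contained sketch of the kind of training script they belong to: the README's `myAwesomeModel` / `mySuperFastDataloader` / `myVeryRelevantLoss` placeholders are replaced by toy stand-ins, the rendezvous settings and hyper-parameters are illustrative, and the OSS wrap plus the training loop reconstruct the part the diff elides. Note that the `ShardedDDP` constructor arguments vary across fairscale versions: this commit's README passes `device_ids=[rank]`, while the sketch below passes the sharded optimizer, as the later stable API expects.

.. code-block:: python

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


    def train(rank: int, world_size: int, epochs: int):
        # Process group init: illustrative rendezvous settings, one GPU per rank
        dist.init_process_group(
            backend="nccl", init_method="tcp://localhost:29501",
            rank=rank, world_size=world_size,
        )

        # Problem statement: toy stand-ins for the README placeholders
        model = torch.nn.Linear(32, 2).to(rank)
        dataloader = [
            (torch.randn(16, 32).to(rank), torch.randint(0, 2, (16,)).to(rank))
            for _ in range(10)
        ]
        loss_fn = torch.nn.CrossEntropyLoss()

        # Shard the optimizer state across ranks; any PyTorch-compliant optimizer class works
        base_optimizer = torch.optim.SGD
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-3)

        # ShardedDDP takes care of the gradient reduction automatically
        model = ShardedDDP(model, optimizer)

        for _ in range(epochs):
            for batch, target in dataloader:
                optimizer.zero_grad()
                loss = loss_fn(model(batch), target)
                loss.backward()
                optimizer.step()

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size, epochs = 2, 5
        mp.spawn(train, args=(world_size, epochs), nprocs=world_size, join=True)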
docs/source/index.rst
@@ -25,15 +25,20 @@ Components
 ----------

 * Parallelism:
-    * `pipeline parallelism <../../en/latest/api/nn/pipe.html>`_
-    * `sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
+    * `Pipeline parallelism <../../en/latest/api/nn/pipe.html>`_

-* Optimization:
-    * `optimizer state sharding <../../en/latest/api/optim/oss.html>`_
-    * `sharded grad scaler - AMP <../../en/latest/api/optim/grad_scaler.html>`_
+* Sharded training:
+    * `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_
+    * `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_
+    * `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_

+* Optimization at scale:
     * `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_

+* `Tutorials <../../en/latest/tutorials/index.html>`_

 .. warning::
     This library is under active development.
     Please be mindful and create an
docs/source/tutorials/oss.rst
@@ -16,7 +16,7 @@ Let's suppose that your trainer looks like
         world_size: int,
         epochs: int):

-        # DDP
+        # process group init
         dist_init(rank, world_size)

         # Problem statement
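The trainer skeleton above calls a `dist_init(rank, world_size)` helper that the excerpt never shows. A minimal sketch of such a helper, assuming the standard `torch.distributed.init_process_group` call and an illustrative localhost rendezvous (neither is spelled out in the tutorial):

.. code-block:: python

    import os
    import torch.distributed as dist


    def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
        # Illustrative rendezvous settings; any init_method supported by
        # torch.distributed works here.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)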
@@ -44,26 +44,28 @@ Let's suppose that your trainer looks like
         optimizer.step()

-Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
+Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows.
+DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced (the gradients are not as efficiently sharded)

 .. code-block:: python

     import torch
     from fairscale.optim.oss import OSS
-    from torch.nn.parallel import DistributedDataParallel as DDP
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

     def train(
         rank: int,
         world_size: int,
         epochs: int):

-        # DDP
+        # process group init
         dist_init(rank, world_size)

         # Problem statement
         model = myAwesomeModel().to(rank)
-        model = DDP(model, device_ids=[rank])
+        model = ShardedDDP(model, device_ids=[rank])
         dataloader = mySuperFastDataloader()
         loss_ln = myVeryRelevantLoss()
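The excerpt cuts off before the OSS wrap that the sentence above announces. A minimal sketch of how that wrapping typically continues inside `train`, assuming the `OSS(params=..., optim=..., **kwargs)` constructor from `fairscale.optim.oss`, illustrative hyper-parameters of my choosing, and the model / dataloader / loss placeholders already defined in the snippet above:

.. code-block:: python

        # ... continues inside train(), after the model / dataloader setup above
        base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer class
        base_optimizer_arguments = {"lr": 1e-4, "momentum": 0.9}  # illustrative values
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

        # From here on it behaves like a regular optimizer, except that each rank
        # only keeps the shard of the optimizer state it owns.
        for batch, target in dataloader:
            optimizer.zero_grad()
            loss = loss_ln(model(batch), target)  # loss_ln is the tutorial's loss placeholder
            loss.backward()
            optimizer.step()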
@@ -105,6 +107,7 @@ The above `train` function will then need to be run via a `multiprocessing.spawn
 to see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_.

 Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware GradScaler, which is available in
 `fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
+See [the original documentation] (https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision)
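As a rough illustration of the AMP note above: a sketch of a sharded mixed-precision training step, assuming the `ShardedGradScaler` class exposed by `fairscale.optim.grad_scaler` and reusing the model / optimizer / dataloader placeholders from the tutorial snippet:

.. code-block:: python

    from fairscale.optim.grad_scaler import ShardedGradScaler
    from torch.cuda.amp import autocast

    # Drop-in replacement for torch.cuda.amp.GradScaler: the scaler has to be
    # shard-aware because each rank only holds a slice of the optimizer state.
    scaler = ShardedGradScaler()

    for batch, target in dataloader:
        optimizer.zero_grad()
        with autocast():  # autocast itself is used exactly as in plain PyTorch AMP
            loss = loss_ln(model(batch), target)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()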