Unverified Commit 8b5b9540 authored by Benjamin Lefaudeux, committed by GitHub

[docs] Minor refactor, trying to improve a little bit the html (#220)

parent e83da060
@@ -57,7 +57,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS
from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
@@ -69,7 +69,7 @@ def train(
# Problem statement
model = myAwesomeModel().to(rank)
model = DDP(model, device_ids=[rank])
model = ShardedDDP(model, device_ids=[rank]) # this will handle the gradient reduction automatically
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
@@ -25,15 +25,20 @@ Components
----------
* Parallelism:
* `pipeline parallelism <../../en/latest/api/nn/pipe.html>`_
* `sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
* `Pipeline parallelism <../../en/latest/api/nn/pipe.html>`_
* Optimization:
* `optimizer state sharding <../../en/latest/api/optim/oss.html>`_
* `sharded grad scaler - AMP <../../en/latest/api/optim/grad_scaler.html>`_
* Sharded training:
* `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_
* `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_
* `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
* Optimization at scale:
* `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_
* `Tutorials <../../en/latest/tutorials/index.html>`_
.. warning::
This library is under active development.
Please be mindful and create an
@@ -16,7 +16,7 @@ Let's suppose that your trainer looks like
world_size: int,
epochs: int):
# DDP
# process group init
dist_init(rank, world_size)
# Problem statement
@@ -44,26 +44,28 @@ Let's suppose that your trainer looks like
optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows.
DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced (the gradients are not sharded as efficiently).
.. code-block:: python
import torch
from fairscale.optim.oss import OSS
from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
def train(
rank: int,
world_size: int,
epochs: int):
# DDP
# process group init
dist_init(rank, world_size)
# Problem statement
model = myAwesomeModel().to(rank)
model = DDP(model, device_ids=[rank])
model = ShardedDDP(model, device_ids=[rank])
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
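# --- editorial sketch, not part of the original diff: the hunk above is cut
# off before the optimizer wrap that the surrounding text describes. Assuming
# OSS keeps its usual signature (a parameter iterable, the base optimizer
# class, and that optimizer's keyword arguments), the example would continue
# roughly as follows:
base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer
base_optimizer_arguments = {"lr": 1e-4}  # illustrative values only

# OSS shards the optimizer state across the ranks of the process group
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

model.train()
for _ in range(epochs):
    for inputs, labels in dataloader:  # hypothetical batch layout
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()  # ShardedDDP reduces each gradient shard to its owning rank
        optimizer.step()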
@@ -105,6 +107,7 @@ The above `train` function will then need to be run via a `multiprocessing.spawn` function
to see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_.
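For reference, a minimal, hypothetical launcher for such a per-rank `train` function, using the standard `torch.multiprocessing.spawn` API, could look like the sketch below. The `dist_init` helper is assumed to wrap `torch.distributed.init_process_group`, and the world size, epoch count, backend and rendezvous address are placeholders.

.. code-block:: python

    import torch.distributed as dist
    import torch.multiprocessing as mp

    def dist_init(rank: int, world_size: int) -> None:
        # assumed helper, not shown in the tutorial: join the default process group
        dist.init_process_group(
            backend="nccl",  # placeholder backend; use "gloo" on CPU-only hosts
            init_method="tcp://localhost:29501",  # placeholder rendezvous address
            rank=rank,
            world_size=world_size,
        )

    if __name__ == "__main__":
        world_size = 2  # placeholder: one process per GPU
        epochs = 10     # placeholder
        # spawn() passes the process rank as the first positional argument to train()
        mp.spawn(train, args=(world_size, epochs), nprocs=world_size, join=True)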
Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware GradScaler, which is available in
`fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_.
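As an illustration, and assuming the shard-aware scaler is exposed as `ShardedGradScaler` in `fairscale.optim.grad_scaler` (check the installed version), the AMP loop mirrors the stock `torch.cuda.amp` recipe:

.. code-block:: python

    from torch.cuda.amp import autocast
    from fairscale.optim.grad_scaler import ShardedGradScaler  # assumed import path

    scaler = ShardedGradScaler()

    for inputs, labels in dataloader:  # hypothetical batch layout
        optimizer.zero_grad()
        with autocast():
            loss = loss_fn(model(inputs), labels)
        # the shard-aware scaler unscales and steps the sharded optimizer transparently
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()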