Unverified commit 8b5b9540, authored by Benjamin Lefaudeux, committed by GitHub

[docs] Minor refactor, trying to improve the html a little bit (#220)

parent e83da060
@@ -57,7 +57,7 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from fairscale.optim.oss import OSS
-from torch.nn.parallel import DistributedDataParallel as DDP
+from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

 def train(
     rank: int,
@@ -69,7 +69,7 @@ def train(
     # Problem statement
     model = myAwesomeModel().to(rank)
-    model = DDP(model, device_ids=[rank])
+    model = ShardedDDP(model, device_ids=[rank])  # this will handle the gradient reduce automatically
     dataloader = mySuperFastDataloader()
     loss_fn = myVeryRelevantLoss()
     base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
...
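For readers skimming the diff, the snippet above stops just before the optimizer is wrapped. A minimal sketch of the step that typically follows, assuming the placeholder names from the example (`model`, `base_optimizer`) and an arbitrary learning rate:

.. code-block:: python

    # Shard the optimizer state across ranks by wrapping the base optimizer in OSS.
    # The learning rate here is an arbitrary example value, not part of the diff.
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-4)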
@@ -25,15 +25,20 @@ Components
 ----------

 * Parallelism:
-    * `pipeline parallelism <../../en/latest/api/nn/pipe.html>`_
-    * `sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_
+    * `Pipeline parallelism <../../en/latest/api/nn/pipe.html>`_

-* Optimization:
-    * `optimizer state sharding <../../en/latest/api/optim/oss.html>`_
-    * `sharded grad scaler - AMP <../../en/latest/api/optim/grad_scaler.html>`_
+* Sharded training:
+    * `Optimizer state sharding <../../en/latest/api/optim/oss.html>`_
+    * `Sharded grad scaler - automatic mixed precision <../../en/latest/api/optim/grad_scaler.html>`_
+    * `Sharded distributed data parallel <../../en/latest/api/nn/sharded_ddp.html>`_

+* Optimization at scale:
     * `AdaScale SGD <../../en/latest/api/optim/adascale.html>`_

+* `Tutorials <../../en/latest/tutorials/index.html>`_

 .. warning::
     This library is under active development.
     Please be mindful and create an
...
@@ -16,7 +16,7 @@ Let's suppose that your trainer looks like
         world_size: int,
         epochs: int):

-        # DDP
+        # process group init
         dist_init(rank, world_size)

         # Problem statement
@@ -44,26 +44,28 @@ Let's suppose that your trainer looks like
         optimizer.step()

-Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
+Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows.
+DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced (the gradients are not as efficiently sharded).

 .. code-block:: python

     import torch
     from fairscale.optim.oss import OSS
-    from torch.nn.parallel import DistributedDataParallel as DDP
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

     def train(
         rank: int,
         world_size: int,
         epochs: int):

-        # DDP
+        # process group init
         dist_init(rank, world_size)

         # Problem statement
         model = myAwesomeModel().to(rank)
-        model = DDP(model, device_ids=[rank])
+        model = ShardedDDP(model, device_ids=[rank])
         dataloader = mySuperFastDataloader()
         loss_fn = myVeryRelevantLoss()
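The rest of the tutorial's code block is elided in this diff. As a rough sketch of how such a loop usually continues once the optimizer is wrapped in OSS (the batch/target unpacking and the learning rate are assumptions, not the tutorial's exact code):

.. code-block:: python

    # Sketch only: shard the optimizer state, then train as usual.
    # ShardedDDP reduces each gradient shard to the rank that owns the
    # corresponding optimizer state, so no manual gradient handling is needed.
    base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-4)

    model.train()
    for _ in range(epochs):
        for batch, target in dataloader:
            optimizer.zero_grad()
            loss = loss_fn(model(batch), target)
            loss.backward()
            optimizer.step()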
@@ -105,6 +107,7 @@
 The above `train` function will then need to be run via a `multiprocessing.spawn` call;
 to see it in action, you can test it with the following script `here <../../../examples/tutorial_oss.py>`_.
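For illustration only, such a launch typically looks like the following sketch; the world size and epoch count are placeholder values, not taken from the tutorial:

.. code-block:: python

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        world_size = 2  # assumption: one process per GPU
        # mp.spawn invokes train(rank, world_size, epochs) once per process,
        # filling in the rank argument automatically.
        mp.spawn(train, args=(world_size, 10), nprocs=world_size, join=True)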
 Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware GradScaler, which is available in
 `fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
+See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_.
...
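As a hedged illustration of the shard-aware scaler mentioned above: `ShardedGradScaler` (from `fairscale.optim.grad_scaler`) is used like `torch.cuda.amp.GradScaler`; the model, optimizer, dataloader and loss names below are the tutorial's placeholders, and the loop structure is an assumption:

.. code-block:: python

    import torch
    from fairscale.optim.grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler()
    for batch, target in dataloader:
        optimizer.zero_grad()
        # Run the forward pass in mixed precision; autocast needs no sharding-specific changes.
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(batch), target)
        # Scale the loss before backward, then step through the scaler,
        # exactly as with the stock GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()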