Unverified Commit 2d415f30 authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS: add a small sphinx tutorial, similar to README (#92)

Add a small tutorial, similar to the OSS README
parent 426d8449
@@ -11,7 +11,7 @@ fairscale supports:
Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
-```bash
+```python
import torch
import fairscale
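# The diff view elides the rest of this snippet; a minimal sketch of the model it
# wraps (the layer sizes below are illustrative assumptions, not the README's exact values):
model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 10),
    torch.nn.ReLU(),
)
# The first two layers then run on cuda:0, the next two on cuda:1:
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)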
@@ -23,7 +23,7 @@ model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/oss_async_broadcast/benchmarks/oss.py), but a minimal example could look like the following:
-```bash
+```python
import torch
from fairscale.optim.oss import OSS
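# The diff view elides the middle of this example; the OSS-specific lines
# (shown in full in the tutorial below) boil down to wrapping the base optimizer.
# model is assumed to be an existing nn.Module, and the lr value is illustrative:
base_optimizer_arguments = {"lr": 1e-4}  # any optimizer-specific arguments
base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)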
@@ -58,7 +58,7 @@ def train(
            optimizer.step()

if __name__ == "__main__":
-    # supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
+    # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
    mp.spawn(
        train,
        args=(
@@ -5,3 +5,5 @@ Tutorials
    :maxdepth: 1

    pipe
+    oss
Optimizer state sharding
========================
Using torch.nn.parallel.DistributedDataParallel with OSS leads to some wasted communication, but it works, and it makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks like the following:

.. code-block:: python

    import torch

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()
        base_optimizer_arguments = {}  # any optimizer-specific arguments, LR, momentum, etc.
        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
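
The snippet above assumes a ``dist_init`` helper that sets up the process group. A minimal sketch of such a helper using ``torch.distributed`` directly (the backend, address and port below are illustrative assumptions, not part of this tutorial) could be:

.. code-block:: python

    import os

    import torch.distributed as dist

    def dist_init(rank: int, world_size: int, backend: str = "gloo") -> None:
        # Illustrative rendezvous via environment variables; pick the backend
        # (gloo/nccl) and address/port that match your setup.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)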

Then sharding the optimizer state is merely a matter of wrapping your base optimizer in ``fairscale.optim.OSS``, as follows:

.. code-block:: python
    :emphasize-lines: 2, 17, 18

    import torch
    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()
        base_optimizer_arguments = {}  # pass any optimizer-specific arguments here, or directly below when instantiating OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
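
Each rank can then be launched with ``torch.multiprocessing.spawn``, as in the README snippet above; ``WORLD_SIZE`` and ``EPOCHS`` are assumed to be defined elsewhere:

.. code-block:: python

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        # WORLD_SIZE and EPOCHS are assumed to be defined elsewhere.
        mp.spawn(
            train,
            args=(WORLD_SIZE, EPOCHS),
            nprocs=WORLD_SIZE,
            join=True,
        )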
@@ -46,7 +46,9 @@ class OSS(Optimizer):
            torch.distributed group (default: group.WORLD)
    """

    #: The optimizer used for a given shard
    optim: Optimizer
    in_super_constructor: bool

    def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any = dist.group.WORLD, **defaults: Any):
@@ -61,10 +63,10 @@ class OSS(Optimizer):
        split_param_groups = self.partition_parameters()
        self.optim = optim(split_param_groups[self.rank], **defaults)

-        # Optional consolidated optimizer state
+        # Optional consolidated optimizer state
        self._all_states: List[Dict[str, Any]] = []

-        # Current device is set by the parameters allocated to this rank
+        # Current device is set by the parameters allocated to this rank
        self._device = split_param_groups[self.rank][0]["params"][0].device

        # Sync local and global param_groups keys
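For reference, the `__init__` signature shown above means the shard group can be chosen explicitly rather than defaulting to `dist.group.WORLD`. A minimal usage sketch (the sub-group, the `lr` value, and `model` below are illustrative assumptions, not part of this diff):

```python
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

# Hypothetical: shard the optimizer state across the first two ranks only;
# by default OSS shards across dist.group.WORLD. model is assumed to be an
# existing nn.Module on this rank; extra kwargs (here lr) are forwarded to
# the wrapped optimizer.
subgroup = dist.new_group(ranks=[0, 1])
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, group=subgroup, lr=1e-3)
```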