Unverified Commit 6851247a authored by Benjamin Lefaudeux, committed by GitHub

[feat] Adding an example in the README for OSS (#79)

parent 20278c0d
@@ -7,6 +7,7 @@ fairscale supports:
* optimizer state sharding (fairscale.optim.oss)
## Examples
### Pipe
Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
@@ -19,6 +20,57 @@ model = torch.nn.Sequential(a, b, c, d)
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```
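
Only the tail of the Pipe snippet is visible in this hunk; a self-contained sketch, with placeholder `torch.nn.Linear` / `torch.nn.ReLU` layers standing in for `a`, `b`, `c`, and `d` (assumptions, not part of the original example), could look like the following:

```python
import torch
import fairscale

# Placeholder layers; any nn.Module layers would do here.
a = torch.nn.Linear(32, 32)
b = torch.nn.ReLU()
c = torch.nn.Linear(32, 32)
d = torch.nn.ReLU()

model = torch.nn.Sequential(a, b, c, d)
# Split the 4-layer model across 2 GPUs: a, b on cuda:0 and c, d on cuda:1,
# with each input batch split into 8 micro-batches for pipelined execution.
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```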
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/oss_async_broadcast/benchmarks/oss.py), but a minimal example could look like the following:
```python
import torch
import torch.multiprocessing as mp

from fairscale.optim.oss import OSS


def train(
    rank: int,
    world_size: int,
    epochs: int,
):
    # DDP / process group setup (dist_init is a user-provided helper, see the sketch below)
    dist_init(rank, world_size)

    # Problem statement
    model = myAwesomeModel()
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS

    # Wrap the base optimizer so that its state is sharded across the ranks
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for batch in dataloader:
            # Train
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
            loss /= world_size
            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    # supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere
    mp.spawn(
        train,
        args=(
            WORLD_SIZE,
            EPOCHS,
        ),
        nprocs=WORLD_SIZE,
        join=True,
    )
```
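
`dist_init` above is a user-provided helper, not part of fairscale. A minimal sketch, assuming single-node training over the default env:// rendezvous and an arbitrarily chosen free port, could be:

```python
import os

import torch.distributed


def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
    # Hypothetical helper: initialize the default process group so that
    # torch.distributed collectives (and OSS state broadcasts) can run.
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # assumption: single-node training
    os.environ["MASTER_PORT"] = "29500"      # assumption: any free port works
    torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
```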
## Requirements
* PyTorch >= 1.4
...