Run a 4-layer model on 2 GPUs. The first two layers run on cuda:0 and the next two layers run on cuda:1.
```python
import torch
import fairscale

# a, b, c and d are the four layers (any torch.nn.Module instances defined elsewhere).
model = torch.nn.Sequential(a, b, c, d)

# Place two layers on cuda:0 and two on cuda:1; each mini-batch is pipelined as 8 micro-batches.
model = fairscale.nn.Pipe(model, balance=[2, 2], devices=[0, 1], chunks=8)
```
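Once wrapped, the pipelined model is called like any other `nn.Module`. A rough usage sketch continuing the snippet above (the batch and feature sizes are illustrative assumptions, taking layer `a` to accept 128-dimensional inputs): the input is placed on the first pipeline device, split into 8 micro-batches, and the output should come back on the last device.

```python
# Continues the Pipe example above; shapes are illustrative assumptions.
x = torch.randn(32, 128).cuda(0)  # input goes on the first pipeline device (cuda:0)
y = model(x)                      # output should come back on the last device (cuda:1)
```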
### Optimizer state sharding (ZeRO)
See a more complete example [here](https://github.com/facebookresearch/fairscale/blob/oss_async_broadcast/benchmarks/oss.py), but a minimal example could look like the following:
```python
import torch
from fairscale.optim.oss import OSS


def train(
    rank: int,
    world_size: int,
    epochs: int):

    # DDP
    dist_init(rank, world_size)

    # Problem statement
    model = myAwesomeModel()
    dataloader = mySuperFastDataloader()
    loss = myVeryRelevantLoss()

    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
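
    # Sketch of the remaining steps (an assumption, not part of the snippet above):
    # wrap the base optimizer in OSS so its state is sharded across the ranks,
    # then train as usual -- the wrapped optimizer behaves like a regular PyTorch optimizer.
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    model.train()
    for _ in range(epochs):
        for batch in dataloader:
            model.zero_grad()
            loss_value = loss(model(batch["inputs"]), batch["label"])
            loss_value.backward()
            optimizer.step()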