Unverified commit 290afecd authored by Benjamin Lefaudeux, committed by GitHub

[doc] better ShardedGradScaler example (#271)

parent 18455bf0
@@ -9,7 +9,7 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
 import torch
 from fairscale.optim.oss import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
+from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
 def train(
     rank: int,
@@ -21,7 +21,6 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
     # Problem statement
     model = myAwesomeModel().to(rank)
-    model = DDP(model, device_ids=[rank])
     dataloader = mySuperFastDataloader()
     loss_ln = myVeryRelevantLoss()
@@ -35,6 +34,10 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
         optim=base_optimizer,
         **base_optimizer_arguments)
+    # ** NEW ** Wrap the model into ShardedDDP
+    model = ShardedDDP(model, optimizer)
+    # ** NEW ** Use a ShardedGradScaler instead of the default Pytorch GradScaler
+    scaler = ShardedGradScaler()
     # Any relevant training loop, nothing specific to OSS. For example:
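For context, the training loop that the diff leaves collapsed is where the scaler actually comes into play. Below is a minimal, self-contained sketch of the pattern this commit documents: OSS shards the optimizer state, ShardedDDP handles the reduce operations, and ShardedGradScaler drives the usual AMP scale/step/update sequence. The tiny Linear model, the synthetic dataloader, the NCCL/TCP process-group setup, and all hyperparameters are illustrative placeholders standing in for the myAwesomeModel / mySuperFastDataloader / myVeryRelevantLoss names used in the doc; they are not part of this commit.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.optim.oss import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


def train(rank: int, world_size: int, epochs: int):
    # Process group init -- backend and init_method are placeholder choices
    dist.init_process_group(
        backend="nccl", init_method="tcp://localhost:29501",
        rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Illustrative stand-ins for the doc's model, dataloader and loss placeholders
    model = torch.nn.Linear(32, 2).to(rank)
    dataloader = [(torch.randn(16, 32), torch.randint(0, 2, (16,))) for _ in range(8)]
    loss_fn = torch.nn.CrossEntropyLoss()

    # Wrap a base optimizer into OSS, which shards the optimizer state across ranks
    base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=1e-4)

    # ** NEW ** Wrap the model into ShardedDDP
    model = ShardedDDP(model, optimizer)

    # ** NEW ** Use a ShardedGradScaler instead of the default Pytorch GradScaler
    scaler = ShardedGradScaler()

    # Standard AMP loop: scale the loss, step through the scaler, then update it
    model.train()
    for _ in range(epochs):
        for data, target in dataloader:
            data, target = data.to(rank), target.to(rank)
            model.zero_grad()
            with torch.cuda.amp.autocast():
                loss = loss_fn(model(data), target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(train, args=(2, 2), nprocs=2)  # world_size=2, epochs=2

The only sharding-specific pieces are the three wrappers (OSS, ShardedDDP, ShardedGradScaler); the loop itself is the standard torch.cuda.amp recipe.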