"docker/vscode:/vscode.git/clone" did not exist on "3b37fefee99425286984a9d5fa4f1850064d01eb"
Unverified commit 92210136 authored by Benjamin Lefaudeux, committed by GitHub

[doc] hotfixes, old documentation (#232)

Thanks Jessica for the heads up!
parent 47e57935
@@ -74,14 +74,17 @@ def train(
# Problem statement
model = myAwesomeModel().to(rank)
model = ShardedDDP(model, device_ids=[rank]) # this will handle the gradient reduction automatically
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
# Wrap the optimizer in its state sharding brethren
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
...
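The order this patch lands on in both files is: build the OSS optimizer first, then hand it to ShardedDDP. A minimal self-contained sketch of that order, assuming fairscale's `OSS` and `ShardedDataParallel` import paths, and substituting a toy `nn.Linear`, a CPU/gloo process group, and made-up hyperparameters for the doc's `myAwesomeModel` / `mySuperFastDataloader` / `myVeryRelevantLoss` placeholders:

```python
import torch
import torch.distributed as dist

from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


def train(rank: int, world_size: int, epochs: int = 2):
    # Process-group setup; gloo keeps the sketch CPU-only (an assumption, not from the doc)
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )

    # Toy stand-ins for the doc's placeholder model and loss
    model = torch.nn.Linear(16, 4)
    loss_fn = torch.nn.MSELoss()

    # 1) Wrap the optimizer in its state-sharding counterpart first
    base_optimizer = torch.optim.SGD          # pick any pytorch-compliant optimizer
    base_optimizer_arguments = {"lr": 1e-3}   # SGD needs an explicit lr
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    # 2) Then wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
    model = ShardedDDP(model, optimizer)
```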
@@ -65,7 +65,6 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savings
# Problem statement
model = myAwesomeModel().to(rank)
model = ShardedDDP(model, device_ids=[rank])
dataloader = mySuperFastDataloader()
loss_fn = myVeryRelevantLoss()
@@ -79,6 +78,9 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savings
optim=base_optimizer,
**base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
for e in range(epochs):
...
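As the comment in the hunk says, the training loop itself needs nothing OSS-specific. Continuing the sketch from the first hunk, still inside `train()`, with random tensors standing in for a real dataloader batch:

```python
    # Plain PyTorch loop, nothing specific to OSS
    model.train()
    for e in range(epochs):
        batch, target = torch.randn(8, 16), torch.randn(8, 4)  # fake data, no real dataloader
        optimizer.zero_grad()
        loss = loss_fn(model(batch), target)
        loss.backward()    # ShardedDDP reduces each gradient to the rank owning its optimizer state
        optimizer.step()   # each rank steps its own shard, then updated params are broadcast

    dist.destroy_process_group()
```

One way to drive it would be a launcher such as `torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)`, along the lines of the fuller examples in the fairscale docs.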