OpenDAS / fairscale
Commit 92210136 (unverified)
Authored Dec 04, 2020 by Benjamin Lefaudeux; committed by GitHub, Dec 04, 2020
[doc] hotfixes, old documentation (#232)
Thanks Jessica for the heads-up!
parent 47e57935
Showing 2 changed files with 7 additions and 2 deletions
README.md  +4 -1
docs/source/tutorials/oss.rst  +3 -1
README.md @ 92210136
@@ -74,14 +74,17 @@ def train(
     # Problem statement
     model = myAwesomeModel().to(rank)
-    model = ShardedDDP(model, device_ids=[rank])  # this will handle the gradient reduce automatically
     dataloader = mySuperFastDataloader()
     loss_fn = myVeryRelevantLoss()
     base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
     base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
     # Wrap the optimizer in its state sharding brethren
     optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
+    # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
+    model = ShardedDDP(model, optimizer)
     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
     for e in range(epochs):
...
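The "state sharding" comment in the diff above can be made concrete with a toy example. The following is a minimal sketch in plain Python, not fairscale code: it assumes a greedy rule that assigns each parameter to the currently lightest rank, and the function name `partition_by_size` and the parameter names and sizes are hypothetical, chosen only for illustration.

```python
from typing import Dict, List


def partition_by_size(param_sizes: Dict[str, int], world_size: int) -> List[List[str]]:
    """Greedily assign each parameter to the rank with the smallest load so far.

    This is an illustrative rule (an assumption), not a statement of how
    fairscale partitions parameters internally.
    """
    shards: List[List[str]] = [[] for _ in range(world_size)]
    loads = [0] * world_size
    for name, size in param_sizes.items():
        rank = loads.index(min(loads))  # first rank with the smallest load
        shards[rank].append(name)
        loads[rank] += size
    return shards


# Hypothetical parameter names and element counts.
sizes = {"embed.weight": 1000, "fc1.weight": 400, "fc1.bias": 20, "fc2.weight": 400}
shards = partition_by_size(sizes, world_size=2)
for rank, names in enumerate(shards):
    print(f"rank {rank} owns {names}")
```

Under this sketch, each rank would only keep optimizer state for the parameters in its own shard, which is where the memory saving mentioned in the tutorial comes from.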
docs/source/tutorials/oss.rst @ 92210136
@@ -65,7 +65,6 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
 # Problem statement
 model = myAwesomeModel().to(rank)
-model = ShardedDDP(model, device_ids=[rank])
 dataloader = mySuperFastDataloader()
 loss_ln = myVeryRelevantLoss()
...
@@ -79,6 +78,9 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
 optim=base_optimizer,
 **base_optimizer_arguments)
+# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
+model = ShardedDDP(model, optimizer)
 # Any relevant training loop, nothing specific to OSS. For example:
 model.train()
 for e in range(epochs):
...
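The added comment says ShardedDDP "will reduce gradients to the proper ranks". As a rough illustration of what that phrase could mean, here is a toy single-process simulation in plain Python. It is a sketch under stated assumptions, not fairscale's implementation: `reduce_to_owners`, the `owner` mapping, and the gradient values are hypothetical, and gradients are summed rather than averaged to keep the arithmetic simple.

```python
from typing import Dict, List


def reduce_to_owners(
    per_rank_grads: List[Dict[str, float]], owner: Dict[str, int]
) -> List[Dict[str, float]]:
    """Sum each parameter's gradient over all ranks and keep it only on the owner.

    Contrast with an all-reduce, where every rank would end up holding
    every reduced gradient.
    """
    world_size = len(per_rank_grads)
    reduced: List[Dict[str, float]] = [{} for _ in range(world_size)]
    for name, owner_rank in owner.items():
        reduced[owner_rank][name] = sum(grads[name] for grads in per_rank_grads)
    return reduced


# Two simulated ranks, each with its own local gradients (hypothetical values).
grads = [{"w": 1.0, "b": 0.5}, {"w": 3.0, "b": 1.5}]
owner = {"w": 0, "b": 1}  # assumed ownership: rank 0 updates "w", rank 1 updates "b"
reduced = reduce_to_owners(grads, owner)
print(reduced)
```

This also hints at why the wrapper in the diff takes the optimizer as an argument: in this sketch, the reduction step needs to know which rank owns which parameter, and that mapping belongs to the sharded optimizer.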