Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
92210136
Unverified
Commit
92210136
authored
Dec 04, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Dec 04, 2020
Browse files
[doc] hotfixes, old documentation (#232)
Thanks Jessica for the heads up !
parent
47e57935
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
2 deletions
+7
-2
README.md
README.md
+4
-1
docs/source/tutorials/oss.rst
docs/source/tutorials/oss.rst
+3
-1
No files found.
README.md
View file @
92210136
...
@@ -74,14 +74,17 @@ def train(
...
@@ -74,14 +74,17 @@ def train(
# Problem statement
# Problem statement
model
=
myAwesomeModel
().
to
(
rank
)
model
=
myAwesomeModel
().
to
(
rank
)
model
=
ShardedDDP
(
model
,
device_ids
=
[
rank
])
# this will handle the gradient reduce automatically
dataloader
=
mySuperFastDataloader
()
dataloader
=
mySuperFastDataloader
()
loss_fn
=
myVeryRelevantLoss
()
loss_fn
=
myVeryRelevantLoss
()
base_optimizer
=
torch
.
optim
.
SGD
# pick any pytorch compliant optimizer here
base_optimizer
=
torch
.
optim
.
SGD
# pick any pytorch compliant optimizer here
base_optimizer_arguments
=
{}
# pass any optimizer specific arguments here, or directly below when instantiating OSS
base_optimizer_arguments
=
{}
# pass any optimizer specific arguments here, or directly below when instantiating OSS
# Wrap the optimizer in its state sharding brethren
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
base_optimizer
,
**
base_optimizer_arguments
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
base_optimizer
,
**
base_optimizer_arguments
)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model
=
ShardedDDP
(
model
,
optimizer
)
# Any relevant training loop, nothing specific to OSS. For example:
# Any relevant training loop, nothing specific to OSS. For example:
model
.
train
()
model
.
train
()
for
e
in
range
(
epochs
):
for
e
in
range
(
epochs
):
...
...
docs/source/tutorials/oss.rst
View file @
92210136
...
@@ -65,7 +65,6 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
...
@@ -65,7 +65,6 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
# Problem statement
# Problem statement
model = myAwesomeModel().to(rank)
model = myAwesomeModel().to(rank)
model = ShardedDDP(model, device_ids=[rank])
dataloader = mySuperFastDataloader()
dataloader = mySuperFastDataloader()
loss_ln = myVeryRelevantLoss()
loss_ln = myVeryRelevantLoss()
...
@@ -79,6 +78,9 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
...
@@ -79,6 +78,9 @@ DDP can be used in place of ShardedDDP in the example below, but the memory savi
optim=base_optimizer,
optim=base_optimizer,
**base_optimizer_arguments)
**base_optimizer_arguments)
# Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
model = ShardedDDP(model, optimizer)
# Any relevant training loop, nothing specific to OSS. For example:
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
model.train()
for e in range(epochs):
for e in range(epochs):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment