OpenDAS / fairscale
Commit 290afecd (Unverified)
Authored Dec 27, 2020 by Benjamin Lefaudeux
Committed by GitHub, Dec 27, 2020
[doc] better ShardedGradScaler example (#271)
Parent: 18455bf0
Showing 1 changed file with 5 additions and 2 deletions
docs/source/api/optim/grad_scaler.rst
@@ -9,7 +9,7 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
     import torch
     from fairscale.optim.oss import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
-    from torch.nn.parallel import DistributedDataParallel as DDP
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

     def train(
         rank: int,
@@ -21,7 +21,6 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
         # Problem statement
         model = myAwesomeModel().to(rank)
-        model = DDP(model, device_ids=[rank])
         dataloader = mySuperFastDataloader()
         loss_ln = myVeryRelevantLoss()
@@ -35,6 +34,10 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
             optim=base_optimizer,
             **base_optimizer_arguments)

+        # ** NEW ** Wrap the model into ShardedDDP
+        model = ShardedDDP(model, optimizer)
+
+        # ** NEW ** Use a ShardedGradScaler instead of the default Pytorch GradScaler
         scaler = ShardedGradScaler()

         # Any relevant training loop, nothing specific to OSS. For example:
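For context on why the documented example swaps in a shard-aware scaler: with OSS, each rank only steps its own shard of the optimizer state, so an overflow check made on local gradients alone can give different answers on different ranks. The sketch below is a single-process, pure-Python simulation of that idea under stated assumptions. It is not fairscale code: `has_non_finite`, `local_decisions`, and `shard_aware_decisions` are hypothetical helper names, the gradient values are made up, and the `any(...)` over ranks merely stands in for whatever collective a real implementation would use.

```python
import math
from typing import List


def has_non_finite(grads: List[float]) -> bool:
    """Return True if any gradient in this shard is inf or NaN."""
    return any(not math.isfinite(g) for g in grads)


def local_decisions(shards: List[List[float]]) -> List[bool]:
    """Shard-unaware check: each rank decides from its own shard only."""
    return [has_non_finite(shard) for shard in shards]


def shard_aware_decisions(shards: List[List[float]]) -> List[bool]:
    """Shard-aware check: combine the per-rank flags (a stand-in for a
    collective reduction) so that every rank reaches the same decision."""
    any_overflow = any(local_decisions(shards))
    return [any_overflow] * len(shards)


# Hypothetical unscaled gradients for two ranks; rank 1's shard overflowed.
shards = [[0.1, -0.2], [float("inf"), 0.3]]

naive = local_decisions(shards)          # ranks disagree about skipping the step
synced = shard_aware_decisions(shards)   # every rank agrees to skip the step
```

In the naive case only one rank would skip its optimizer step, leaving the ranks out of sync; sharing the flag keeps them in lockstep, which is the behaviour the diff relies on `ShardedGradScaler` to provide.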