OpenDAS / fairscale · Commits

Commit 290afecd (unverified)
[doc] better ShardedGradScaler example (#271)

Authored Dec 27, 2020 by Benjamin Lefaudeux; committed via GitHub on Dec 27, 2020
Parent: 18455bf0
Changes: 1
Showing 1 changed file with 5 additions and 2 deletions (+5, -2):
docs/source/api/optim/grad_scaler.rst
docs/source/api/optim/grad_scaler.rst @ 290afecd

@@ -9,7 +9,7 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
     import torch
     from fairscale.optim.oss import OSS
     from fairscale.optim.grad_scaler import ShardedGradScaler
-    from torch.nn.parallel import DistributedDataParallel as DDP
+    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

     def train(
         rank: int,
@@ -21,7 +21,6 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
         # Problem statement
         model = myAwesomeModel().to(rank)
-        model = DDP(model, device_ids=[rank])
         dataloader = mySuperFastDataloader()
         loss_fn = myVeryRelevantLoss()
@@ -35,6 +34,10 @@ Make sure that you use `ShardedGradScaler` in that case, which is a shard-aware
             optim=base_optimizer, **base_optimizer_arguments)

+        # ** NEW ** Wrap the model into ShardedDDP
+        model = ShardedDDP(model, optimizer)
+
         # ** NEW ** Use a ShardedGradScaler instead of the default Pytorch GradScaler
         scaler = ShardedGradScaler()

         # Any relevant training loop, nothing specific to OSS. For example:
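For readers who want to run the documented pattern end to end, the sketch below shows the three pieces this diff ties together: OSS shards the optimizer state across ranks, ShardedDataParallel reduces each gradient only to the rank that owns the matching optimizer shard, and ShardedGradScaler replaces torch.cuda.amp.GradScaler so the inf/NaN check done during unscaling is synchronized across ranks. Everything beyond those three calls is an assumption of this sketch, not part of the commit: the toy Linear model, the random batch, SGD with lr=1e-3, the env-based rendezvous on 127.0.0.1:29500, and a world size of two GPUs.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.optim.oss import OSS


def train(rank: int, world_size: int) -> None:
    # Illustrative rendezvous; a real launcher would provide these values.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Placeholder model, loss and data standing in for myAwesomeModel & co.
    model = torch.nn.Linear(32, 4).to(rank)
    loss_fn = torch.nn.MSELoss()

    # Shard the optimizer state across ranks (OSS wraps a regular optimizer class) ...
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    # ... wrap the model so each gradient is reduced to the rank owning its shard ...
    model = ShardedDDP(model, optimizer)
    # ... and use the shard-aware scaler instead of torch.cuda.amp.GradScaler.
    scaler = ShardedGradScaler()

    for _ in range(5):
        optimizer.zero_grad()
        inputs = torch.randn(8, 32, device=rank)
        targets = torch.randn(8, 4, device=rank)
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # skipped consistently on every rank if any shard saw inf/NaN
        scaler.update()

    dist.destroy_process_group()


if __name__ == "__main__":
    # Assumes at least two CUDA devices; adjust nprocs to the local GPU count.
    mp.spawn(train, args=(2,), nprocs=2)

The ordering matters: OSS has to exist before the ShardedDDP wrap because ShardedDDP uses the optimizer's parameter partition to decide where each gradient is reduced, and the plain PyTorch GradScaler is not a safe substitute here since each rank would only inspect its own shard's gradients for inf/NaN and ranks could then disagree on whether to skip a step; ShardedGradScaler all-reduces that status before stepping.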