Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
cc766aa5
Unverified
Commit
cc766aa5
authored
Nov 05, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Nov 05, 2020
Browse files
[feature] Add a torch AMP benchmark option and test job (#175)
* oss benchmark: add an --amp option * add a circleCI test
parent
0d1f058b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
3 deletions
+17
-3
.circleci/config.yml
.circleci/config.yml
+7
-0
benchmarks/oss.py
benchmarks/oss.py
+10
-3
No files found.
.circleci/config.yml
View file @
cc766aa5
...
@@ -126,6 +126,11 @@ run_oss_gloo: &run_oss_gloo
...
@@ -126,6 +126,11 @@ run_oss_gloo: &run_oss_gloo
command
:
|
command
:
|
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
run_oss_amp
:
&run_oss_amp
-
run
:
name
:
Run OSS with Torch AMP
command
:
|
python benchmarks/oss.py --amp --epochs 3 --optim_type oss_sharded_ddp
# -------------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------
# Jobs to run
# Jobs to run
...
@@ -316,6 +321,8 @@ jobs:
...
@@ -316,6 +321,8 @@ jobs:
-
<<
:
*run_oss_gloo
-
<<
:
*run_oss_gloo
-
<<
:
*run_oss_amp
...
...
benchmarks/oss.py
View file @
cc766aa5
...
@@ -140,9 +140,15 @@ def train(
...
@@ -140,9 +140,15 @@ def train(
next
(
model
.
parameters
()).
norm
().
item
(),
next
(
model
.
parameters
()).
grad
.
norm
().
item
()
next
(
model
.
parameters
()).
norm
().
item
(),
next
(
model
.
parameters
()).
grad
.
norm
().
item
()
)
)
)
)
if
not
args
.
cpu
and
args
.
amp
:
# Automatically computes the FW pass in half precision
with
torch
.
cuda
.
amp
.
autocast
():
outputs
=
model
(
batch
[
"inputs"
])
outputs
=
model
(
batch
[
"inputs"
])
loss
=
loss_fn
(
outputs
,
batch
[
"label"
])
loss
=
loss_fn
(
outputs
,
batch
[
"label"
])
else
:
outputs
=
model
(
batch
[
"inputs"
])
loss
=
loss_fn
(
outputs
,
batch
[
"label"
])
loss
.
backward
()
loss
.
backward
()
if
optim_type
==
OptimType
.
oss_sharded_ddp
:
if
optim_type
==
OptimType
.
oss_sharded_ddp
:
...
@@ -244,7 +250,8 @@ if __name__ == "__main__":
...
@@ -244,7 +250,8 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--torchvision_model"
,
type
=
str
,
help
=
"Any torchvision model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--torchvision_model"
,
type
=
str
,
help
=
"Any torchvision model name (str)"
,
default
=
"resnet101"
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
)
parser
.
add_argument
(
"--debug"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Display additional debug information"
)
parser
.
add_argument
(
"--amp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Activate torch AMP"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment