OpenDAS / fairscale, commit fee979d9 (unverified)
Authored Apr 30, 2021 by msbaines; committed by GitHub on Apr 30, 2021
[test] nn.Pipe: add a parity test that also tests with amp (#645)
parent 3b7373e2

Showing 2 changed files with 102 additions and 0 deletions (+102, -0):

tests/ci_test_list_1.txt (+1, -0)
tests/nn/pipe/test_parity.py (+101, -0)
tests/ci_test_list_1.txt (View file @ fee979d9)

@@ -4,3 +4,4 @@ tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
 tests/nn/data_parallel/test_fsdp_freezing_weights.py
 tests/nn/data_parallel/test_fsdp.py
 tests/nn/misc/test_flatten_params_wrapper.py
+tests/nn/pipe/test_parity.py

tests/nn/pipe/test_parity.py 0 → 100644 (View file @ fee979d9)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing Pipe Module Parity
"""

import contextlib
import copy

import numpy as np
import pytest
import torch

from fairscale.nn import Pipe
from fairscale.utils.testing import skip_if_single_gpu


def _get_model(num_inputs=2, num_hidden=20, num_outputs=2):
    # One Linear layer per GPU: input layer + (device_count - 2) hidden layers + output layer.
    num_layers = torch.cuda.device_count() - 2
    model = torch.nn.Sequential(
        torch.nn.Linear(num_inputs, num_hidden),
        *([torch.nn.Linear(num_hidden, num_hidden) for _ in range(num_layers)]),
        torch.nn.Linear(num_hidden, num_outputs),
    )
    return model


def _check_parity(rmodel, pmodel, ropt, popt, rloss, ploss):
    # Model parameters, optimizer parameter groups, and buffers should match
    # between the Pipe model and the vanilla model within tolerance.
    for pparams, rparams in zip(pmodel.parameters(), rmodel.parameters()):
        assert torch.allclose(pparams.cuda(), rparams, atol=1e-2), f"Model params are different {pparams} {rparams}"

    for p_pg, reg_pg in zip(popt.param_groups, ropt.param_groups):
        for p_param, reg_param in zip(p_pg["params"], reg_pg["params"]):
            assert torch.allclose(
                p_param.cuda(), reg_param, atol=1e-2
            ), f"Model parameters differ between Pipe and Vanilla {p_param} {reg_param}"

    for p_buf, reg_buf in zip(pmodel.buffers(), rmodel.buffers()):
        assert torch.allclose(p_buf.cuda(), reg_buf, atol=1e-2), "Model buffers differ between Pipe and Vanilla."


def _get_fp16_context(use_fp16=False):
    # Return an autocast context for mixed precision, or a no-op context otherwise.
    if use_fp16:
        return torch.cuda.amp.autocast()
    else:
        return contextlib.nullcontext()


def _train(model, optimizer, use_fp16):
    # Run a single forward/backward/step on constant data.
    inputs = torch.ones(32, 2).cuda()
    labels = torch.ones(32, 2)
    loss_fn = torch.nn.MSELoss(reduction="sum")
    model.train()
    with _get_fp16_context(use_fp16):
        pred = model(inputs)
        loss = loss_fn(pred, labels.to(pred.device))
    loss.backward()
    optimizer.step()
    return model, optimizer, loss


def _train_reg_model(model, use_fp16=False):
    model = copy.deepcopy(model)
    model = model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    return _train(model, optimizer, use_fp16)


def _train_pipe_model(model, use_fp16=False, checkpoint="never", chunks=1):
    # Wrap the model in Pipe with one layer per available CUDA device.
    model = copy.deepcopy(model)
    model = Pipe(
        model,
        balance=[1] * torch.cuda.device_count(),
        devices=list(range(torch.cuda.device_count())),
        chunks=chunks,
        checkpoint=checkpoint,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    return _train(model, optimizer, use_fp16)


@skip_if_single_gpu
@pytest.mark.parametrize("use_fp16", [True, False])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
@pytest.mark.parametrize("chunks", [1, 4])
def test_correctness(use_fp16, checkpoint, chunks):
    torch.manual_seed(0)
    np.random.seed(0)

    # Older torch versions do not ship the AMP APIs needed for the fp16 path.
    if use_fp16 and not hasattr(torch.cuda.amp, "custom_fwd"):
        pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}")

    model = _get_model()
    rmodel, ropt, rloss = _train_reg_model(model)
    pmodel, popt, ploss = _train_pipe_model(
        model, use_fp16=use_fp16, checkpoint=checkpoint, chunks=chunks,
    )
    _check_parity(rmodel, pmodel, ropt, popt, rloss, ploss)
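As a usage note rather than part of the commit: on a host with at least two CUDA devices, the new module can be run on its own through pytest's Python API. The sketch below assumes it is executed from the fairscale repository root; in CI the file is instead picked up via the tests/ci_test_list_1.txt entry added above.

import sys

import pytest

if __name__ == "__main__":
    # Runs every (use_fp16, checkpoint, chunks) combination of test_correctness;
    # the skip_if_single_gpu decorator skips it on single-GPU machines.
    sys.exit(pytest.main(["-v", "tests/nn/pipe/test_parity.py"]))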