Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8c8eb8e8
Unverified
Commit
8c8eb8e8
authored
Aug 28, 2020
by
Min Xu
Committed by
GitHub
Aug 28, 2020
Browse files
[fix] fix eval for oss_ddp (#55)
- added train(mode) method to be aware of eval mode
parent
fb49b515
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
4 deletions
+44
-4
fairscale/nn/data_parallel/oss_ddp.py
fairscale/nn/data_parallel/oss_ddp.py
+15
-4
tests/nn/data_parallel/test_oss_ddp.py
tests/nn/data_parallel/test_oss_ddp.py
+29
-0
No files found.
fairscale/nn/data_parallel/oss_ddp.py
View file @
8c8eb8e8
...
...
@@ -99,6 +99,15 @@ class OssDdp(nn.Module):
attrs
=
copy
.
copy
(
self
.
__dict__
)
return
attrs
def train(self, mode: bool = True) -> "OssDdp":
    """Switch the wrapped module between train and eval mode.

    Mirrors ``nn.Module.train`` but guards the transition so that
    unreduced gradients are never silently carried into eval mode.

    Args:
        mode: True for training mode, False for eval (same as nn.Module).

    Returns:
        self, for chaining.
    """
    was_training = self.module.training
    self.module.train(mode)
    if not self.module.training:
        # Entering eval with a reduction still pending would drop gradients.
        assert not self.need_reduction, "try to enter eval with grads unreduced"
    else:
        # A pending reduction is only legal if we were already training.
        assert not self.need_reduction or was_training, "incorrect state transition"
    return self
@
contextmanager
def
no_sync
(
self
)
->
Generator
:
"""A context manager to disable gradient synchronization."""
...
...
@@ -108,10 +117,11 @@ class OssDdp(nn.Module):
self
.
accumulate_grads
=
old_accumulate_grads
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
    """Run the wrapped module's forward pass.

    In training mode this enforces the explicit-reduction contract:
    a second forward before ``reduce()`` is an error, and — unless
    gradients are being accumulated — this call marks a reduction as
    pending.  In eval mode the bookkeeping is skipped entirely, so
    repeated inference passes never trip the check.

    Args:
        *inputs: positional arguments forwarded to the wrapped module.
        **kwargs: keyword arguments forwarded to the wrapped module.

    Returns:
        Whatever the wrapped module returns (typically a Tensor).

    Raises:
        RuntimeError: if called twice in training mode without an
            intervening ``reduce()``.
    """
    # The raw diff text carried an unguarded copy of this check, which
    # would make any second eval-mode forward raise; the bookkeeping
    # must apply only while training.
    if self.module.training:
        if self.need_reduction:
            raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
        if not self.accumulate_grads:
            self.need_reduction = True
    return self.module(*inputs, **kwargs)
def
reduce
(
self
)
->
None
:
...
...
@@ -119,6 +129,7 @@ class OssDdp(nn.Module):
This function must be called explicitly after backward to reduce
gradients. There is no automatic hook like c10d.
"""
assert
self
.
module
.
training
,
"Cannot call reduce in eval"
def
reduce_params
(
params
:
List
[
Parameter
],
params_rank
:
int
)
->
None
:
""" Helper to reduce a list of params that should fix in the buffer. """
...
...
tests/nn/data_parallel/test_oss_ddp.py
View file @
8c8eb8e8
...
...
@@ -54,3 +54,32 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
def run_test(backend, device, world_size=2):
    """Spawn ``world_size`` workers that each execute ``run_one_step``.

    Args:
        backend: torch.distributed backend name.
        device: device the workers should run on.
        world_size: number of processes to spawn (default 2).
    """
    # mkstemp returns (fd, path); only the path is used as the rendezvous file.
    sync_file = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, sync_file),
        nprocs=world_size,
        join=True,
    )
def run_eval_mode(_unused):
    """Check eval-mode behavior of OssDdp in a single-rank process group.

    Eval-mode forwards must run repeatedly without tripping any assert,
    while a second train-mode forward without an intervening reduce must
    raise RuntimeError.
    """
    dist.init_process_group(
        init_method=f"file://{tempfile.mkstemp()[1]}",
        backend=dist.Backend.GLOO,
        rank=0,
        world_size=1,
    )
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, 1)

    # Eval: repeated forwards are allowed — no grad-reduction bookkeeping.
    ddp.eval()
    for _ in range(5):
        output = ddp(torch.rand((64, 2)))

    # Train: the second forward without a reduce() must raise.
    ddp.train()
    try:
        for _ in range(5):
            output = ddp(torch.rand((64, 2)))
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes on training mode should not pass"
def test_eval_mode():
    # Run the eval-mode scenario in one spawned worker; run_eval_mode
    # performs its own asserts, so join=True propagates any failure here.
    mp.spawn(run_eval_mode, args=(), join=True)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment