Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
8c8eb8e8
Unverified
Commit
8c8eb8e8
authored
Aug 28, 2020
by
Min Xu
Committed by
GitHub
Aug 28, 2020
Browse files
[fix] fix eval for oss_ddp (#55)
- added train(mode) method to be aware of eval mode
parent
fb49b515
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
4 deletions
+44
-4
fairscale/nn/data_parallel/oss_ddp.py
fairscale/nn/data_parallel/oss_ddp.py
+15
-4
tests/nn/data_parallel/test_oss_ddp.py
tests/nn/data_parallel/test_oss_ddp.py
+29
-0
No files found.
fairscale/nn/data_parallel/oss_ddp.py
View file @
8c8eb8e8
...
@@ -99,6 +99,15 @@ class OssDdp(nn.Module):
...
@@ -99,6 +99,15 @@ class OssDdp(nn.Module):
attrs
=
copy
.
copy
(
self
.
__dict__
)
attrs
=
copy
.
copy
(
self
.
__dict__
)
return
attrs
return
attrs
def train(self, mode: bool = True) -> "OssDdp":
    """Switch the wrapped module between train and eval mode.

    Mirrors ``nn.Module.train`` but additionally validates the gradient
    state: leaving train mode (or re-entering it from eval) while gradients
    are still waiting to be reduced indicates a caller bug.

    Args:
        mode: ``True`` to enable training mode, ``False`` for eval mode.

    Returns:
        ``self``, to allow chained calls like ``ddp.train().foo()``.
    """
    was_training = self.module.training
    self.module.train(mode)
    if not self.module.training:
        # Entering eval: all gradients must already have been reduced.
        assert not self.need_reduction, "try to enter eval with grads unreduced"
    else:
        # Entering train: pending reduction is only legal if we were
        # already in train mode (i.e. this is a no-op transition).
        assert not self.need_reduction or was_training, "incorrect state transition"
    return self
@
contextmanager
@
contextmanager
def
no_sync
(
self
)
->
Generator
:
def
no_sync
(
self
)
->
Generator
:
"""A context manager to disable gradient synchronization."""
"""A context manager to disable gradient synchronization."""
...
@@ -108,6 +117,7 @@ class OssDdp(nn.Module):
...
@@ -108,6 +117,7 @@ class OssDdp(nn.Module):
self
.
accumulate_grads
=
old_accumulate_grads
self
.
accumulate_grads
=
old_accumulate_grads
def
forward
(
self
,
*
inputs
:
Any
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
inputs
:
Any
,
**
kwargs
:
Any
)
->
Tensor
:
if
self
.
module
.
training
:
if
self
.
need_reduction
:
if
self
.
need_reduction
:
raise
RuntimeError
(
"OssDdp requires explicit reduction, must call OssDdp.reduce"
)
raise
RuntimeError
(
"OssDdp requires explicit reduction, must call OssDdp.reduce"
)
if
not
self
.
accumulate_grads
:
if
not
self
.
accumulate_grads
:
...
@@ -119,6 +129,7 @@ class OssDdp(nn.Module):
...
@@ -119,6 +129,7 @@ class OssDdp(nn.Module):
This function must be called explicitly after backward to reduce
This function must be called explicitly after backward to reduce
gradients. There is no automatic hook like c10d.
gradients. There is no automatic hook like c10d.
"""
"""
assert
self
.
module
.
training
,
"Cannot call reduce in eval"
def
reduce_params
(
params
:
List
[
Parameter
],
params_rank
:
int
)
->
None
:
def
reduce_params
(
params
:
List
[
Parameter
],
params_rank
:
int
)
->
None
:
""" Helper to reduce a list of params that should fix in the buffer. """
""" Helper to reduce a list of params that should fix in the buffer. """
...
...
tests/nn/data_parallel/test_oss_ddp.py
View file @
8c8eb8e8
...
@@ -54,3 +54,32 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
...
@@ -54,3 +54,32 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
def run_test(backend, device, world_size=2):
    """Spawn ``world_size`` workers running ``run_one_step`` on the given backend/device.

    A throwaway temp file is used as the rendezvous point for process-group
    initialization inside each worker.
    """
    rendezvous_file = tempfile.mkstemp()[1]
    worker_args = (world_size, backend, device, rendezvous_file)
    mp.spawn(run_one_step, args=worker_args, nprocs=world_size, join=True)
def run_eval_mode(_unused):
    """Check that eval-mode forward passes run without tripping reduction asserts.

    Runs several forwards with the wrapper in eval mode (must all succeed),
    then switches back to train mode, where a second forward without an
    intervening reduce is expected to raise ``RuntimeError``.
    """
    dist.init_process_group(
        init_method=f"file://{tempfile.mkstemp()[1]}",
        backend=dist.Backend.GLOO,
        rank=0,
        world_size=1,
    )
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, 1)

    # Eval mode: repeated forwards must not assert.
    ddp.eval()
    for _ in range(5):
        ddp(torch.rand((64, 2)))

    # Train mode: the second forward (grads never reduced) must raise.
    ddp.train()
    try:
        for _ in range(5):
            ddp(torch.rand((64, 2)))
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes on training mode should not pass"
def test_eval_mode():
    """Entry point: run the eval-mode check in a single spawned process."""
    worker = run_eval_mode
    mp.spawn(worker, args=(), join=True)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment