Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
cae9b638
Unverified
Commit
cae9b638
authored
Jan 26, 2021
by
msbaines
Committed by
GitHub
Jan 26, 2021
Browse files
[refactor] pipe: separate out Single and MultiProcess pipe (#326)
parent
eab1551a
Changes
22
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
100 deletions
+106
-100
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+99
-93
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_transparency.py
+7
-7
No files found.
tests/nn/pipe_process/test_pipe.py
View file @
cae9b638
This diff is collapsed.
Click to expand it.
tests/nn/pipe_process/test_transparency.py
View file @
cae9b638
...
...
@@ -21,13 +21,13 @@ import pytest
import
torch
from
torch
import
nn
from
fairscale.nn
import
Pipe
from
fairscale.nn
.pipe
import
MultiProcess
Pipe
from
fairscale.utils.testing
import
get_worker_map
,
set_random_seed
,
torch_spawn
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
simple_linears
(
pipeline_style
):
def
sum_grad
(
parameters
):
return
sum
([
p
.
grad
.
sum
()
for
p
in
parameters
if
p
.
grad
is
not
None
])
...
...
@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
inputs
=
torch
.
rand
(
8
,
1
)
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
2
),
nn
.
Linear
(
2
,
4
),
nn
.
Linear
(
4
,
2
),
nn
.
Linear
(
2
,
1
),)
# Without Pipe
# Without
MultiProcess
Pipe
outputs
=
model
(
inputs
)
loss
=
outputs
.
mean
()
loss
.
backward
()
...
...
@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
zero_grad
(
model
.
parameters
())
# With Pipe
model
=
Pipe
(
model
,
[
2
,
2
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
4
)
# With
MultiProcess
Pipe
model
=
MultiProcess
Pipe
(
model
,
[
2
,
2
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
(),
chunks
=
4
)
outputs
=
model
(
inputs
)
if
model
.
group
.
rank
()
==
1
:
loss
=
outputs
.
mean
()
loss
.
backward
()
grad_with_pipe
=
sum_grad
(
model
.
pipeline
.
mp_
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
pipeline
.
partitions
[
0
].
module
.
parameters
())
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
1
])
else
:
model
.
back_helper
(
outputs
)
grad_with_pipe
=
sum_grad
(
model
.
pipeline
.
mp_
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
pipeline
.
partitions
[
0
].
module
.
parameters
())
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
0
])
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment