OpenDAS / fairscale · Commits · cae9b638

Commit cae9b638 (unverified), authored Jan 26, 2021 by msbaines, committed by GitHub Jan 26, 2021

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
Changes: 22

Showing 2 changed files with 106 additions and 100 deletions (+106 −100):

tests/nn/pipe_process/test_pipe.py            +99 −93
tests/nn/pipe_process/test_transparency.py     +7  −7
tests/nn/pipe_process/test_pipe.py @ cae9b638
(This diff is collapsed.)
tests/nn/pipe_process/test_transparency.py @ cae9b638
@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def simple_linears(pipeline_style):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])
@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
     inputs = torch.rand(8, 1)
     model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)

-    # Without Pipe
+    # Without MultiProcessPipe
     outputs = model(inputs)
     loss = outputs.mean()
     loss.backward()
@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
     zero_grad(model.parameters())

-    # With Pipe
-    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    # With MultiProcessPipe
+    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
     outputs = model(inputs)

     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())

         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
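Taken together, the hunks above amount to a rename at the call sites: the multi-process pipe is now imported from fairscale.nn.pipe as MultiProcessPipe instead of fairscale.nn.Pipe, and its pipeline exposes partitions instead of mp_partitions. Below is a minimal sketch of the migrated usage pattern, mirroring only the calls exercised in this diff and assuming the constructor keeps the arguments shown there (balance list, style, worker_map, chunks); like the tests, it would run inside a torch_spawn-launched worker group, not as a standalone script.

    # Sketch only: mirrors the call pattern in the updated tests above.
    import torch
    from torch import nn

    from fairscale.nn.pipe import MultiProcessPipe   # was: from fairscale.nn import Pipe
    from fairscale.utils.testing import get_worker_map

    model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1))

    # Two partitions of two layers each, split across pipeline workers.
    pipe = MultiProcessPipe(
        model,
        [2, 2],                       # layer balance per partition
        style=MultiProcessPipe.MultiProcess,
        worker_map=get_worker_map(),  # maps pipeline ranks to worker names
        chunks=4,                     # micro-batches per input batch
    )

    outputs = pipe(torch.rand(8, 1))

    # The local partition's parameters are reached via pipeline.partitions
    # (previously pipeline.mp_partitions).
    local_params = pipe.pipeline.partitions[0].module.parameters()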