OpenDAS / fairscale / Commits / cae9b638

Unverified commit cae9b638, authored Jan 26, 2021 by msbaines, committed by GitHub on Jan 26, 2021.

[refactor] pipe: separate out Single and MultiProcess pipe (#326)

parent eab1551a
Changes: 22. This page shows 2 changed files, with 106 additions and 100 deletions:

tests/nn/pipe_process/test_pipe.py: +99, -93
tests/nn/pipe_process/test_transparency.py: +7, -7
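The renames in the diff are mechanical: the multi-process pipeline class previously exported as Pipe is now exported as MultiProcessPipe, its style constants (MultiProcess, AsyncSchedule) move with it, and the implementation module moves to fairscale/nn/pipe/multiprocess_pipe.py. A minimal migration sketch, using only the constructor arguments that appear in this diff (as in the tests, it assumes an initialized torch.distributed process group, which the suite sets up via the torch_spawn helper):

    import torch.nn as nn
    from fairscale.nn.pipe import MultiProcessPipe  # was: from fairscale.nn.pipe import Pipe
    from fairscale.utils.testing import get_worker_map  # test helper: rank -> worker mapping

    model = nn.Sequential(nn.Linear(1, 1))

    # Before this commit:
    #   pipe = Pipe(model, balance=[1], style=Pipe.MultiProcess,
    #               worker_map=get_worker_map(), chunks=1)
    # After this commit:
    pipe = MultiProcessPipe(
        model,
        balance=[1],                          # layers assigned to each pipeline stage
        style=MultiProcessPipe.MultiProcess,  # or MultiProcessPipe.AsyncSchedule
        worker_map=get_worker_map(),
        chunks=1,                             # micro-batches per input batch
    )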
tests/nn/pipe_process/test_pipe.py (view file @ cae9b638)

@@ -31,15 +31,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_pipeline_parallel_group,
     initialize_model_parallel,
 )
-from fairscale.nn.pipe import LazyModule, Pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def parameters(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    pipe = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
     if torch.distributed.get_rank() == 0:
         assert list(pipe.parameters()) != []
     else:

@@ -107,7 +107,7 @@ def mpi():
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def public_attrs(pipeline_style):
     class MyString:
         def __init__(self, value):

@@ -118,7 +118,7 @@ def public_attrs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=(1,),
         style=pipeline_style,
@@ -127,9 +127,7 @@ def public_attrs(pipeline_style):
         checkpoint=MyString("always"),
     )
 
-    print(f"balance = {pipe.devices}")
     assert pipe.balance == [1]
-    assert pipe.devices is None
     assert pipe.chunks == 42
     assert isinstance(pipe.chunks, int)
     assert pipe.checkpoint == "always"
@@ -138,13 +136,13 @@ def public_attrs(pipeline_style):
 
 
 @torch_spawn([2])
 @pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def sequential_like(balance, pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
 
     model = nn.Sequential(a, b)
-    model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
 
     if balance == [2]:
         if torch.distributed.get_rank() == 0:

@@ -177,7 +175,7 @@ def sequential_like(balance, pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_wrong_length(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -185,14 +183,14 @@ def balance_wrong_length(pipeline_style):
     model = nn.Sequential(a, b)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_less_than_1(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -200,39 +198,39 @@ def balance_less_than_1(pipeline_style):
     model = nn.Sequential(a, b)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def chunks_less_than_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def too_few_devices(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
 
     with pytest.raises(IndexError):
         # len(balance) > len(group.size())
-        model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+        model = MultiProcessPipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_indivisible(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
 
     with pytest.warns(None) as record:
         model(torch.rand(7, 1))

@@ -242,10 +240,10 @@ def batch_size_indivisible(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_small(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
 
     with pytest.warns(None) as record:
         model(torch.rand(2, 1))

@@ -255,7 +253,7 @@ def batch_size_small(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode(pipeline_style):
     def count_grad_fn(grad_fn, name, visited=set()):
         if grad_fn in visited:

@@ -275,7 +273,7 @@ def checkpoint_mode(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
     input = torch.rand(2, 1)
 
-    always = Pipe(
+    always = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -284,7 +282,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="always",
         pipelined_backward=False,
     )
-    except_last = Pipe(
+    except_last = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -293,7 +291,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="except_last",
         pipelined_backward=False,
     )
-    never = Pipe(
+    never = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -313,12 +311,12 @@ def checkpoint_mode(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_invalid(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
-        Pipe(
+        MultiProcessPipe(
             model,
             balance=[1],
             style=pipeline_style,

@@ -329,23 +327,27 @@ def checkpoint_mode_invalid(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_when_chunks_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     # All checkpoint modes are fine.
-    Pipe(
+    MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
         worker_map=get_worker_map(),
         chunks=1,
         checkpoint="except_last",
     )
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always"
+    )
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never"
+    )
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_eval(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,
         worker_map=get_worker_map(),
         chunks=2,
         pipelined_backward=False,
     )
     input = torch.rand(2, 1)

@@ -373,7 +375,7 @@ def checkpoint_eval(pipeline_style):
 
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_non_float_input(pipeline_style):
     class ForkNonFloat(nn.Module):
         def forward(self, input):

@@ -384,7 +386,7 @@ def checkpoint_non_float_input(pipeline_style):
             return input[0] * 2
 
     model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[1, 1],
         style=pipeline_style,

@@ -399,17 +401,17 @@ def checkpoint_non_float_input(pipeline_style):
     if model.group.rank() == 1:
         # with torch.autograd.detect_anomaly():
         output.backward()
-    elif pipeline_style == Pipe.MultiProcess:
+    elif pipeline_style == MultiProcessPipe.MultiProcess:
         model.back_helper(output)
 
     torch.distributed.barrier()
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def no_grad(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
     input = torch.rand(2, 1)
 
     latent = None
@@ -421,7 +423,7 @@ def no_grad(pipeline_style):
         nonlocal latent
         latent = output
 
-    partition = model.mp_partitions[0]
+    partition = model.partitions[0]
     partition.module.register_forward_hook(hook)
 
     with torch.no_grad():
@@ -431,7 +433,7 @@ def no_grad(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception(pipeline_style):
     class ExpectedException(Exception):
         pass

@@ -441,7 +443,7 @@ def exception(pipeline_style):
             raise ExpectedException()
 
     model = nn.Sequential(Raise())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
 
     with pytest.raises(ExpectedException):
         model(torch.rand(1))

@@ -451,7 +453,7 @@ def exception(pipeline_style):
 @torch_spawn([4])
 @pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception_early_stop_asap(pipeline_style):
     """Even the first partitions have finished to process, the partition before
     the failed partition hould be killed as soon as possible.

@@ -480,7 +482,7 @@ def exception_early_stop_asap(pipeline_style):
             raise ExpectedException()
 
     model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
-    model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = MultiProcessPipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
 
     with pytest.raises(ExpectedException):
         model(torch.rand(3))

@@ -490,7 +492,7 @@ def exception_early_stop_asap(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_pair(pipeline_style):
     class Two(nn.Module):
         def __init__(self):

@@ -503,7 +505,7 @@ def input_pair(pipeline_style):
             return (self.fc_a(a), self.fc_b(b))
 
     model = nn.Sequential(Two())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )

@@ -519,7 +521,7 @@ def input_pair(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_singleton(pipeline_style):
     class One(nn.Module):
         def __init__(self):

@@ -531,7 +533,7 @@ def input_singleton(pipeline_style):
             return (self.fc(a),)
 
     model = nn.Sequential(One())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )

@@ -546,10 +548,10 @@ def input_singleton(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_varargs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
 
     a = torch.rand(1)
     b = torch.rand(1)

@@ -560,14 +562,14 @@ def input_varargs(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor(pipeline_style):
     class NonTensor(nn.Module):
         def forward(self, _):
             return "hello"
 
     model = nn.Sequential(NonTensor())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)
 
     # TypeError: expected Tensor as element 0 in argument 0, but got str

@@ -580,14 +582,14 @@ def non_tensor(pipeline_style):
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor_tuple(pipeline_style):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
 
     model = nn.Sequential(NonTensorTuple())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)
 
     # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1

@@ -602,7 +604,7 @@ def non_tensor_tuple(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)

@@ -611,7 +613,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -632,7 +634,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)

@@ -641,7 +643,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -663,7 +665,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
 
 
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def devices(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -671,7 +673,7 @@ def devices(pipeline_style):
 
     # There are extra two ranks.
     model = nn.Sequential(a, b, c)
-    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     # Extra devices must be discarded.
     if model.group.rank() == 3:
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
partitions
(
pipeline_style
):
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
=
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
assert
isinstance
(
model
.
mp_
partitions
,
list
)
assert
isinstance
(
model
.
partitions
,
list
)
assert
len
(
model
)
==
1
assert
isinstance
(
model
.
mp_
partitions
[
0
].
module
,
nn
.
Sequential
)
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
if
model
.
group
.
rank
()
==
0
:
assert
"0.0.weight"
in
model
.
state_dict
()
...
...
@@ -699,13 +701,13 @@ def partitions(pipeline_style):
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
deny_moving
(
pipeline_style
):
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
=
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
.
cuda
()
model
.
cpu
()
...
...
@@ -723,29 +725,29 @@ def deny_moving(pipeline_style):
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
empty_module
(
pipeline_style
):
# Empty sequential module is not illegal.
model
=
nn
.
Sequential
()
model
=
Pipe
(
model
,
[],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
=
MultiProcess
Pipe
(
model
,
[],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
assert
model
(
torch
.
tensor
([
42
]))
==
torch
.
tensor
([
42
])
assert
model
((
torch
.
tensor
([
42
]),))
==
(
torch
.
tensor
([
42
]),)
# But only tensor or tensors is legal in Pipe.
# But only tensor or tensors is legal in
MultiProcess
Pipe.
with
pytest
.
raises
(
TypeError
):
model
(
42
)
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
named_children
(
pipeline_style
):
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
model
=
nn
.
Sequential
(
OrderedDict
([(
"a"
,
a
),
(
"b"
,
b
)]))
model
=
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
model
=
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
names
=
set
(
n
for
n
,
_
in
model
.
named_modules
())
if
model
.
group
.
rank
()
==
0
:
...
...
@@ -753,30 +755,30 @@ def named_children(pipeline_style):
else
:
assert
"0.b"
in
names
# Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
#
MultiProcess
Pipe doesn't support __getattr__. Unlike nn.Sequential,
MultiProcess
Pipe requires
# several methods in its namespace.
with
pytest
.
raises
(
AttributeError
):
model
.
a
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
recommend_auto_balance
(
pipeline_style
):
with
pytest
.
raises
(
ValueError
,
match
=
"fairscale.nn.pipe.balance"
):
# balance is required
Pipe
(
nn
.
Sequential
())
MultiProcess
Pipe
(
nn
.
Sequential
())
with
pytest
.
raises
(
ValueError
,
match
=
"fairscale.nn.pipe.balance"
):
# module and sum of balance have differen length (module: 0, sum of balance: 1)
Pipe
(
nn
.
Sequential
(),
[
1
])
MultiProcess
Pipe
(
nn
.
Sequential
(),
[
1
])
with
pytest
.
raises
(
ValueError
,
match
=
"fairscale.nn.pipe.balance"
):
# module and sum of balance have different length (module: 2, sum of balance: 1)
Pipe
(
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
)),
[
1
])
MultiProcess
Pipe
(
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
),
nn
.
Linear
(
1
,
1
)),
[
1
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
lazy_construction
(
pipeline_style
):
init_count
=
0
...
...
@@ -796,7 +798,7 @@ def lazy_construction(pipeline_style):
LazyModule
(
lambda
:
Custom
()),
]
pipe
=
Pipe
(
model
,
balance
=
[
2
,
2
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe
=
MultiProcess
Pipe
(
model
,
balance
=
[
2
,
2
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
assert
isinstance
(
pipe
[
0
],
Custom
)
assert
isinstance
(
pipe
[
1
],
Custom
)
...
...
@@ -806,17 +808,17 @@ def lazy_construction(pipeline_style):
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"doesn't apply to mpi"
)
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
missing_worker_map
(
pipeline_style
):
model
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
ReLU
())
with
pytest
.
raises
(
ValueError
,
match
=
"'RpcTransport' requires 'worker_map' to be set"
):
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
)
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
)
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skip
(
reason
=
"currently broken"
)
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
verify_module_duplicate_parameters_on_distinct_partitions
(
pipeline_style
):
class
Surrogate
(
nn
.
Module
):
def
__init__
(
self
,
module
):
...
...
@@ -828,23 +830,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
# FIXME(tom) can't have duplicate params with separate processes
with
pytest
.
raises
(
ValueError
,
match
=
"module with duplicate parameters on distinct devices is not supported"
):
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
@
torch_spawn
([
4
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
Pipe
.
MultiProcess
,
Pipe
.
AsyncSchedule
])
@
pytest
.
mark
.
parametrize
(
"pipeline_style"
,
[
MultiProcess
Pipe
.
MultiProcess
,
MultiProcess
Pipe
.
AsyncSchedule
])
def
pipelined_backward
(
pipeline_style
):
model
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
ReLU
())
destroy_model_parallel
()
initialize_model_parallel
(
1
,
4
)
pipe
=
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe
=
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
assert
pipe
.
pipelined_backward
is
False
destroy_model_parallel
()
initialize_model_parallel
(
2
,
2
)
pipe
=
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
pipe
=
MultiProcess
Pipe
(
model
,
[
1
,
1
],
style
=
pipeline_style
,
worker_map
=
get_worker_map
())
assert
pipe
.
pipelined_backward
is
True
...
...
@@ -853,7 +855,9 @@ def pipelined_backward(pipeline_style):
def
async_event_loop
():
model
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(),
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
())
pipe
=
Pipe
(
model
,
[
1
,
1
,
1
,
1
],
style
=
Pipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
(),
chunks
=
10
)
pipe
=
MultiProcessPipe
(
model
,
[
1
,
1
,
1
,
1
],
style
=
MultiProcessPipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
(),
chunks
=
10
)
inputs
=
torch
.
rand
(
100
,
10
)
...
...
@@ -869,7 +873,7 @@ def reuse_lazy():
reused
=
LazyModule
(
lambda
:
nn
.
Linear
(
10
,
10
))
model
=
[
reused
,
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(),
reused
,
nn
.
ReLU
(),
reused
,
nn
.
ReLU
()]
# model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
pipe
=
Pipe
(
model
,
[
3
,
1
,
1
],
style
=
Pipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
())
pipe
=
MultiProcess
Pipe
(
model
,
[
3
,
1
,
1
],
style
=
MultiProcess
Pipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
())
pipe
.
eval
()
output
=
pipe
(
torch
.
rand
(
10
))
...
...
@@ -887,7 +891,7 @@ def reuse_lazy():
# ensure identical weights but no sharing between model and pipe
reused
=
nn
.
Linear
(
10
,
10
)
layers
=
[
reused
,
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(),
reused
,
nn
.
ReLU
(),
reused
,
nn
.
ReLU
()]
pipe
=
Pipe
(
layers
,
[
3
,
1
,
1
],
style
=
Pipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
())
pipe
=
MultiProcess
Pipe
(
layers
,
[
3
,
1
,
1
],
style
=
MultiProcess
Pipe
.
AsyncSchedule
,
worker_map
=
get_worker_map
())
pipe
.
eval
()
model_optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.01
,
momentum
=
0.9
)
pipe_optimizer
=
torch
.
optim
.
SGD
(
pipe
.
parameters
(),
lr
=
0.01
,
momentum
=
0.9
)
if
len
(
list
(
pipe
.
parameters
()))
else
None
...
...
@@ -931,7 +935,7 @@ def reuse_lazy():
def
test_instantiate_partition
():
from
fairscale.nn.pipe.async_schedule
import
Location
from
fairscale.nn.pipe.pipe
import
instantiate_partition
from
fairscale.nn.pipe.
multiprocess_
pipe
import
instantiate_partition
class
FakeGroup
:
def
__init__
(
self
,
rank
,
size
):
...
...
@@ -947,7 +951,7 @@ def test_instantiate_partition():
def
check_partitions
(
model
,
balance
,
expected_order
,
expected_ranks
):
"""Check the instantiated model matches expectation of order and rank
model: a list of modules or an nn.Sequential
balance: the balance argument to Pipe
balance: the balance argument to
MultiProcess
Pipe
expected_order: the index of modules in `model` in the order they will
be executed, grouped by nn.Sequential
expected_rank: the rank that each module will be executed on
...
...
@@ -959,7 +963,9 @@ def test_instantiate_partition():
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for
rank
in
range
(
len
(
balance
)):
instantiated
=
instantiate_partition
(
model
,
balance
,
FakeGroup
(
rank
,
len
(
balance
)),
Pipe
.
AsyncSchedule
)
instantiated
=
instantiate_partition
(
model
,
balance
,
FakeGroup
(
rank
,
len
(
balance
)),
MultiProcessPipe
.
AsyncSchedule
)
for
part
in
instantiated
:
assert
isinstance
(
part
.
module
,
nn
.
Sequential
)
for
inv
in
part
.
invocations
:
...
...
tests/nn/pipe_process/test_transparency.py (view file @ cae9b638)

@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn
 
-from fairscale.nn import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
 
 
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def simple_linears(pipeline_style):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])

@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
     inputs = torch.rand(8, 1)
     model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
 
-    # Without Pipe
+    # Without MultiProcessPipe
     outputs = model(inputs)
 
     loss = outputs.mean()
     loss.backward()

@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
     zero_grad(model.parameters())
 
-    # With Pipe
-    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    # With MultiProcessPipe
+    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
 
     outputs = model(inputs)
     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
 
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
 
         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
 
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
 
         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
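For orientation, the control flow that test_transparency.py exercises can be condensed as the sketch below. This is a paraphrase of the test above under the same two-rank setup (sum_grad, model, and inputs are the test's own definitions), not additional API:

    outputs = model(inputs)  # model is a MultiProcessPipe split over two ranks
    if model.group.rank() == 1:
        # Last pipeline stage: compute the loss and start the backward pass.
        loss = outputs.mean()
        loss.backward()
    else:
        # Earlier stage: hand the outputs back so this partition also runs backward.
        model.back_helper(outputs)

    # Each rank then checks its local partition's gradients against the
    # gradients computed without MultiProcessPipe.
    grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())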