OpenDAS / fairscale / Commits / cae9b638

Unverified commit cae9b638, authored Jan 26, 2021 by msbaines, committed by GitHub on Jan 26, 2021
[refactor] pipe: separate out Single and MultiProcess pipe (#326)
parent eab1551a
Showing 2 changed files with 106 additions and 100 deletions (+106 -100)
tests/nn/pipe_process/test_pipe.py          +99 -93
tests/nn/pipe_process/test_transparency.py  +7 -7
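The change is mechanical: the multiprocess pipeline class and its style constants are now referenced as MultiProcessPipe instead of Pipe, and the partition list it exposes is `partitions` rather than `mp_partitions`. A minimal sketch of the pattern repeated throughout the diff below (the process-group setup that the torch_spawn test fixture normally provides is assumed here, so this is illustrative rather than standalone):

    import torch
    from torch import nn
    from fairscale.nn.pipe import MultiProcessPipe  # was: from fairscale.nn.pipe import Pipe
    from fairscale.utils.testing import get_worker_map

    model = nn.Sequential(nn.Linear(1, 1))
    # was: Pipe(model, balance=[1], style=Pipe.MultiProcess, ...)
    pipe = MultiProcessPipe(
        model, balance=[1], style=MultiProcessPipe.MultiProcess, worker_map=get_worker_map(), chunks=1,
    )
    pipe(torch.rand(2, 1))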
tests/nn/pipe_process/test_pipe.py
@@ -31,15 +31,15 @@ from fairscale.nn.model_parallel.initialize import (
     get_pipeline_parallel_group,
     initialize_model_parallel,
 )
-from fairscale.nn.pipe import LazyModule, Pipe
+from fairscale.nn.pipe import LazyModule, MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def parameters(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    pipe = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
     if torch.distributed.get_rank() == 0:
         assert list(pipe.parameters()) != []
     else:

@@ -107,7 +107,7 @@ def mpi():
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def public_attrs(pipeline_style):
     class MyString:
         def __init__(self, value):

@@ -118,7 +118,7 @@ def public_attrs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=(1,),
         style=pipeline_style,

@@ -127,9 +127,7 @@ def public_attrs(pipeline_style):
         checkpoint=MyString("always"),
     )
 
-    print(f"balance = {pipe.devices}")
     assert pipe.balance == [1]
-    assert pipe.devices is None
     assert pipe.chunks == 42
     assert isinstance(pipe.chunks, int)
     assert pipe.checkpoint == "always"

@@ -138,13 +136,13 @@ def public_attrs(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def sequential_like(balance, pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
 
     model = nn.Sequential(a, b)
-    model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance, style=pipeline_style, worker_map=get_worker_map())
 
     if balance == [2]:
         if torch.distributed.get_rank() == 0:

@@ -177,7 +175,7 @@ def sequential_like(balance, pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_wrong_length(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -185,14 +183,14 @@ def balance_wrong_length(pipeline_style):
     model = nn.Sequential(a, b)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def balance_less_than_1(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -200,39 +198,39 @@ def balance_less_than_1(pipeline_style):
     model = nn.Sequential(a, b)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map())
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def chunks_less_than_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0)
 
     with pytest.raises(ValueError):
-        Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
+        MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1)
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def too_few_devices(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
 
     with pytest.raises(IndexError):
         # len(balance) > len(group.size())
-        model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+        model = MultiProcessPipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_indivisible(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
 
     with pytest.warns(None) as record:
         model(torch.rand(7, 1))

@@ -242,10 +240,10 @@ def batch_size_indivisible(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def batch_size_small(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
 
     with pytest.warns(None) as record:
         model(torch.rand(2, 1))

@@ -255,7 +253,7 @@ def batch_size_small(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode(pipeline_style):
     def count_grad_fn(grad_fn, name, visited=set()):
         if grad_fn in visited:

@@ -275,7 +273,7 @@ def checkpoint_mode(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
     input = torch.rand(2, 1)
 
-    always = Pipe(
+    always = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -284,7 +282,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="always",
         pipelined_backward=False,
     )
-    except_last = Pipe(
+    except_last = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -293,7 +291,7 @@ def checkpoint_mode(pipeline_style):
         checkpoint="except_last",
         pipelined_backward=False,
     )
-    never = Pipe(
+    never = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -313,12 +311,12 @@ def checkpoint_mode(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_invalid(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"):
-        Pipe(
+        MultiProcessPipe(
             model,
             balance=[1],
             style=pipeline_style,

@@ -329,23 +327,27 @@ def checkpoint_mode_invalid(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_mode_when_chunks_1(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
 
     # All checkpoint modes are fine.
-    Pipe(
+    MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last",
     )
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always")
-    Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never")
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always"
+    )
+    MultiProcessPipe(
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never"
+    )
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_eval(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )
     input = torch.rand(2, 1)

@@ -373,7 +375,7 @@ def checkpoint_eval(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def checkpoint_non_float_input(pipeline_style):
     class ForkNonFloat(nn.Module):
         def forward(self, input):

@@ -384,7 +386,7 @@ def checkpoint_non_float_input(pipeline_style):
             return input[0] * 2
 
     model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = Pipe(
+    model = MultiProcessPipe(
         model,
         balance=[1, 1],
         style=pipeline_style,

@@ -399,17 +401,17 @@ def checkpoint_non_float_input(pipeline_style):
     if model.group.rank() == 1:
         # with torch.autograd.detect_anomaly():
         output.backward()
-    elif pipeline_style == Pipe.MultiProcess:
+    elif pipeline_style == MultiProcessPipe.MultiProcess:
         model.back_helper(output)
 
     torch.distributed.barrier()
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def no_grad(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)
     input = torch.rand(2, 1)
     latent = None

@@ -421,7 +423,7 @@ def no_grad(pipeline_style):
         nonlocal latent
         latent = output
 
-    partition = model.mp_partitions[0]
+    partition = model.partitions[0]
     partition.module.register_forward_hook(hook)
 
     with torch.no_grad():

@@ -431,7 +433,7 @@ def no_grad(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception(pipeline_style):
     class ExpectedException(Exception):
         pass

@@ -441,7 +443,7 @@ def exception(pipeline_style):
             raise ExpectedException()
 
     model = nn.Sequential(Raise())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1)
     with pytest.raises(ExpectedException):
         model(torch.rand(1))

@@ -451,7 +453,7 @@ def exception(pipeline_style):
 @torch_spawn([4])
 @pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def exception_early_stop_asap(pipeline_style):
     """Even the first partitions have finished to process, the partition before
     the failed partition hould be killed as soon as possible.

@@ -480,7 +482,7 @@ def exception_early_stop_asap(pipeline_style):
             raise ExpectedException()
 
     model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
-    model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
+    model = MultiProcessPipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)
 
     with pytest.raises(ExpectedException):
         model(torch.rand(3))

@@ -490,7 +492,7 @@ def exception_early_stop_asap(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_pair(pipeline_style):
     class Two(nn.Module):
         def __init__(self):

@@ -503,7 +505,7 @@ def input_pair(pipeline_style):
             return (self.fc_a(a), self.fc_b(b))
 
     model = nn.Sequential(Two())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )

@@ -519,7 +521,7 @@ def input_pair(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_singleton(pipeline_style):
     class One(nn.Module):
         def __init__(self):

@@ -531,7 +533,7 @@ def input_singleton(pipeline_style):
             return (self.fc(a),)
 
     model = nn.Sequential(One())
-    model = Pipe(
+    model = MultiProcessPipe(
         model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
     )

@@ -546,10 +548,10 @@ def input_singleton(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def input_varargs(pipeline_style):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
 
     a = torch.rand(1)
     b = torch.rand(1)

@@ -560,14 +562,14 @@ def input_varargs(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor(pipeline_style):
     class NonTensor(nn.Module):
         def forward(self, _):
             return "hello"
 
     model = nn.Sequential(NonTensor())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)
 
     # TypeError: expected Tensor as element 0 in argument 0, but got str

@@ -580,14 +582,14 @@ def non_tensor(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def non_tensor_tuple(pipeline_style):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
 
     model = nn.Sequential(NonTensorTuple())
-    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
     x = torch.rand(1)
 
     # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1

@@ -602,7 +604,7 @@ def non_tensor_tuple(pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)

@@ -611,7 +613,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -632,7 +634,7 @@ def deferred_batch_norm(checkpoint, lazy, pipeline_style):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)

@@ -641,7 +643,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
         model = [LazyModule(pipe_fn)]
     else:
         model = nn.Sequential(pipe_bn)
-    pipe = Pipe(
+    pipe = MultiProcessPipe(
         model,
         balance=[1],
         style=pipeline_style,

@@ -663,7 +665,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipeline_style):
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def devices(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)

@@ -671,7 +673,7 @@ def devices(pipeline_style):
     # There are extra two ranks.
     model = nn.Sequential(a, b, c)
-    model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     # Extra devices must be discarded.
     if model.group.rank() == 3:

@@ -679,17 +681,17 @@ def devices(pipeline_style):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def partitions(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
 
     model = nn.Sequential(a, b)
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
-    assert isinstance(model.mp_partitions, list)
+    assert isinstance(model.partitions, list)
     assert len(model) == 1
-    assert isinstance(model.mp_partitions[0].module, nn.Sequential)
+    assert isinstance(model.partitions[0].module, nn.Sequential)
 
     if model.group.rank() == 0:
         assert "0.0.weight" in model.state_dict()

@@ -699,13 +701,13 @@ def partitions(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def deny_moving(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
 
     model = nn.Sequential(a, b)
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     model.cuda()
     model.cpu()

@@ -723,29 +725,29 @@ def deny_moving(pipeline_style):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def empty_module(pipeline_style):
     # Empty sequential module is not illegal.
     model = nn.Sequential()
-    model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [], style=pipeline_style, worker_map=get_worker_map())
 
     assert model(torch.tensor([42])) == torch.tensor([42])
     assert model((torch.tensor([42]),)) == (torch.tensor([42]),)
 
-    # But only tensor or tensors is legal in Pipe.
+    # But only tensor or tensors is legal in MultiProcessPipe.
     with pytest.raises(TypeError):
         model(42)
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def named_children(pipeline_style):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
     model = nn.Sequential(OrderedDict([("a", a), ("b", b)]))
-    model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    model = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     names = set(n for n, _ in model.named_modules())
     if model.group.rank() == 0:

@@ -753,30 +755,30 @@ def named_children(pipeline_style):
     else:
         assert "0.b" in names
 
-    # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires
+    # MultiProcessPipe doesn't support __getattr__. Unlike nn.Sequential, MultiProcessPipe requires
     # several methods in its namespace.
     with pytest.raises(AttributeError):
         model.a
 
 
 @torch_spawn([1])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def recommend_auto_balance(pipeline_style):
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # balance is required
-        Pipe(nn.Sequential())
+        MultiProcessPipe(nn.Sequential())
 
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have differen length (module: 0, sum of balance: 1)
-        Pipe(nn.Sequential(), [1])
+        MultiProcessPipe(nn.Sequential(), [1])
 
     with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
         # module and sum of balance have different length (module: 2, sum of balance: 1)
-        Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
+        MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
 
 
 @torch_spawn([2])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def lazy_construction(pipeline_style):
     init_count = 0

@@ -796,7 +798,7 @@ def lazy_construction(pipeline_style):
         LazyModule(lambda: Custom()),
     ]
 
-    pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map())
     assert isinstance(pipe[0], Custom)
     assert isinstance(pipe[1], Custom)

@@ -806,17 +808,17 @@ def lazy_construction(pipeline_style):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def missing_worker_map(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
 
     with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"):
-        Pipe(model, [1, 1], style=pipeline_style)
+        MultiProcessPipe(model, [1, 1], style=pipeline_style)
 
 
 @torch_spawn([2])
 @pytest.mark.skip(reason="currently broken")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     class Surrogate(nn.Module):
         def __init__(self, module):

@@ -828,23 +830,23 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style):
     # FIXME(tom) can't have duplicate params with separate processes
     with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"):
-        Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+        MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
 
 @torch_spawn([4])
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def pipelined_backward(pipeline_style):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
 
     destroy_model_parallel()
     initialize_model_parallel(1, 4)
-    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     assert pipe.pipelined_backward is False
 
     destroy_model_parallel()
     initialize_model_parallel(2, 2)
-    pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map())
 
     assert pipe.pipelined_backward is True

@@ -853,7 +855,9 @@ def pipelined_backward(pipeline_style):
 def async_event_loop():
 
     model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU())
-    pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10)
+    pipe = MultiProcessPipe(
+        model, [1, 1, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10
+    )
 
     inputs = torch.rand(100, 10)

@@ -869,7 +873,7 @@ def reuse_lazy():
     reused = LazyModule(lambda: nn.Linear(10, 10))
     model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
     # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
-    pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(model, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
     pipe.eval()
     output = pipe(torch.rand(10))

@@ -887,7 +891,7 @@ def reuse_lazy():
     # ensure identical weights but no sharing between model and pipe
     reused = nn.Linear(10, 10)
     layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map())
+    pipe = MultiProcessPipe(layers, [3, 1, 1], style=MultiProcessPipe.AsyncSchedule, worker_map=get_worker_map())
     pipe.eval()
     model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
     pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None

@@ -931,7 +935,7 @@ def reuse_lazy():
 def test_instantiate_partition():
     from fairscale.nn.pipe.async_schedule import Location
-    from fairscale.nn.pipe.pipe import instantiate_partition
+    from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition
 
     class FakeGroup:
         def __init__(self, rank, size):

@@ -947,7 +951,7 @@ def test_instantiate_partition():
     def check_partitions(model, balance, expected_order, expected_ranks):
         """Check the instantiated model matches expectation of order and rank
         model: a list of modules or an nn.Sequential
-        balance: the balance argument to Pipe
+        balance: the balance argument to MultiProcessPipe
         expected_order: the index of modules in `model` in the order they will
             be executed, grouped by nn.Sequential
         expected_rank: the rank that each module will be executed on

@@ -959,7 +963,9 @@ def test_instantiate_partition():
         # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
         # instantiated model
         for rank in range(len(balance)):
-            instantiated = instantiate_partition(model, balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule)
+            instantiated = instantiate_partition(
+                model, balance, FakeGroup(rank, len(balance)), MultiProcessPipe.AsyncSchedule
+            )
             for part in instantiated:
                 assert isinstance(part.module, nn.Sequential)
                 for inv in part.invocations:
tests/nn/pipe_process/test_transparency.py
@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn
 
-from fairscale.nn import Pipe
+from fairscale.nn.pipe import MultiProcessPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
 
 
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+@pytest.mark.parametrize("pipeline_style", [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
 def simple_linears(pipeline_style):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])

@@ -40,7 +40,7 @@ def simple_linears(pipeline_style):
     inputs = torch.rand(8, 1)
     model = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 4), nn.Linear(4, 2), nn.Linear(2, 1),)
 
-    # Without Pipe
+    # Without MultiProcessPipe
     outputs = model(inputs)
     loss = outputs.mean()
     loss.backward()

@@ -54,20 +54,20 @@ def simple_linears(pipeline_style):
         zero_grad(model.parameters())
 
-    # With Pipe
-    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
+    # With MultiProcessPipe
+    model = MultiProcessPipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
     outputs = model(inputs)
     if model.group.rank() == 1:
         loss = outputs.mean()
         loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
 
         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
     else:
         model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
 
         # Both grads should be identical.
         assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
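Beyond the class rename, the hunks above also track an attribute rename on the pipeline object: what the old Pipe exposed as `mp_partitions` is plain `partitions` on MultiProcessPipe (see the `partitions` test in test_pipe.py). A minimal sketch of the updated access pattern, under the same assumed spawned-process setup as the tests:

    from torch import nn
    from fairscale.nn.pipe import MultiProcessPipe
    from fairscale.utils.testing import get_worker_map

    # Two layers on a single rank; the partition list holds the local stage.
    model = MultiProcessPipe(
        nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), balance=[2],
        style=MultiProcessPipe.MultiProcess, worker_map=get_worker_map(),
    )
    assert isinstance(model.partitions, list)  # was: model.mp_partitions
    assert isinstance(model.partitions[0].module, nn.Sequential)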