Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a8dd9254
Unverified
Commit
a8dd9254
authored
Jan 30, 2021
by
msbaines
Committed by
GitHub
Jan 30, 2021
Browse files
[refactor] pipe: move async-specific code out of MultiProcessPipe (#344)
parent
e348806b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
235 additions
and
182 deletions
+235
-182
fairscale/nn/pipe/async_pipe.py
fairscale/nn/pipe/async_pipe.py
+147
-4
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+79
-169
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+3
-3
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+6
-6
No files found.
fairscale/nn/pipe/async_pipe.py
View file @
a8dd9254
...
...
@@ -3,10 +3,153 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from
.multiprocess_pipe
import
MultiProcessPipe
from
.types
import
PipelineStyle
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
,
field
import
itertools
from
typing
import
TYPE_CHECKING
,
Any
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
Tensor
,
nn
from
.async_schedule
import
Invocation
,
Location
,
ModuleWrapper
from
.multiprocess_pipe
import
MultiProcessPipe
,
check_balance
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.skip.skippable
import
Skippable
from
.types
import
LazyModule
,
PipelineStyle
if
TYPE_CHECKING
:
Module
=
nn
.
Module
[
TensorOrTensors
]
NamedModules
=
OrderedDict
[
str
,
Module
]
else
:
Module
=
nn
.
Module
NamedModules
=
OrderedDict
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
@
dataclass
class
PartitionInfo
:
location
:
Location
modules
:
"OrderedDict[str, nn.Module]"
invocations
:
List
[
Invocation
]
=
field
(
default_factory
=
list
)
def
__len__
(
self
)
->
int
:
return
len
(
self
.
modules
)
class
AsyncPipe
(
MultiProcessPipe
):
def
__init__
(
self
,
*
args
,
**
kwargs
)
->
None
:
# type: ignore
super
().
__init__
(
*
args
,
style
=
PipelineStyle
.
AsyncSchedule
,
**
kwargs
)
def
create_pipeline
(
self
)
->
None
:
# The micro-batch index where the checkpointing stops.
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
self
.
pipeline
=
MultiProcessPipeline
(
self
.
partitions
,
self
.
_skip_layout
,
checkpoint_stop
,
style
=
PipelineStyle
.
AsyncSchedule
,
group
=
self
.
group
,
worker_map
=
self
.
worker_map
,
input_device
=
self
.
input_device
,
final_stage
=
self
.
final_stage
,
)
def
instantiate_partition
(
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
Iterable
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
)
->
List
[
ModuleWrapper
]:
balance
=
list
(
balance
)
check_balance
(
module
,
balance
,
True
)
layers
:
NamedModules
=
OrderedDict
()
def
maybe_realize
(
layer
:
Any
)
->
nn
.
Module
:
if
isinstance
(
layer
,
nn
.
Module
):
return
layer
elif
callable
(
layer
):
return
layer
()
else
:
raise
TypeError
(
f
"layer must be nn.Module or callable, is
{
type
(
layer
)
}
"
)
def
iterate_module
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
Iterable
[
Tuple
[
Any
,
nn
.
Module
]]:
if
isinstance
(
module
,
nn
.
Sequential
):
yield
from
module
.
named_children
()
else
:
yield
from
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
module
))
module_ids
=
list
(
map
(
id
,
module
))
index_of_first_use
=
[
module_ids
.
index
(
x
)
for
x
in
module_ids
]
locations
:
List
[
Location
]
=
[]
module_iter
=
enumerate
(
iterate_module
(
module
))
partitions
:
List
[
List
[
PartitionInfo
]]
=
[]
for
bi
,
b
in
enumerate
(
balance
):
modules_for_rank
:
List
[
PartitionInfo
]
=
[]
current_module
:
OrderedDict
[
str
,
nn
.
Module
]
=
OrderedDict
()
def
current_location
()
->
Location
:
return
Location
(
bi
,
len
(
modules_for_rank
))
def
append_module
(
mod
:
"OrderedDict[str, nn.Module]"
)
->
None
:
modules_for_rank
.
append
(
PartitionInfo
(
current_location
(),
mod
))
while
sum
(
map
(
len
,
modules_for_rank
))
+
len
(
current_module
)
<
b
:
module_index
,
(
name
,
layer
)
=
next
(
module_iter
)
if
index_of_first_use
[
module_index
]
!=
module_index
:
# Subsequent reuse of a module
locations
.
append
(
locations
[
index_of_first_use
[
module_index
]])
continue
is_reused
=
index_of_first_use
.
count
(
index_of_first_use
[
module_index
])
>
1
if
is_reused
and
len
(
current_module
)
>
0
:
append_module
(
current_module
)
current_module
=
OrderedDict
()
current_module
[
str
(
name
)]
=
layer
locations
.
append
(
current_location
())
if
is_reused
:
append_module
(
current_module
)
current_module
=
OrderedDict
()
if
len
(
current_module
)
>
0
:
append_module
(
current_module
)
partitions
.
append
(
modules_for_rank
)
filtered_locations
:
List
[
Optional
[
Location
]]
=
[
loc
for
loc
,
_
in
itertools
.
groupby
(
locations
)]
filtered_locations
.
append
(
None
)
for
i
in
range
(
len
(
filtered_locations
)
-
1
):
loc
=
filtered_locations
[
i
]
assert
loc
if
i
==
0
:
inv
=
Invocation
(
i
,
loc
,
None
,
filtered_locations
[
i
+
1
])
else
:
inv
=
Invocation
(
i
,
loc
,
filtered_locations
[
i
-
1
],
filtered_locations
[
i
+
1
])
partitions
[
loc
.
stage
][
loc
.
index
].
invocations
.
append
(
inv
)
invocations
=
enumerate
(
iterate_module
(
module
))
partition
=
partitions
[
group
.
rank
()]
result
:
List
[
ModuleWrapper
]
=
[]
for
partition_info
in
partition
:
wrapper
=
ModuleWrapper
(
nn
.
Sequential
(
OrderedDict
((
k
,
maybe_realize
(
m
))
for
k
,
m
in
partition_info
.
modules
.
items
())),
partition_info
.
location
,
partition_info
.
invocations
,
)
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
wrapper
.
module
:
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
result
.
append
(
wrapper
)
return
result
fairscale/nn/pipe/multiprocess_pipe.py
View file @
a8dd9254
...
...
@@ -19,10 +19,8 @@
"""The MultiProcessPipe interface."""
from
collections
import
OrderedDict
from
dataclasses
import
dataclass
,
field
import
itertools
import
threading
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
warnings
import
torch
...
...
@@ -33,7 +31,7 @@ import torch.cuda
from
fairscale.nn.model_parallel
import
get_model_parallel_world_size
,
get_pipeline_parallel_group
from
.
import
microbatch
from
.async_schedule
import
Invocation
,
Location
,
ModuleWrapper
from
.async_schedule
import
Location
,
ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
...
...
@@ -47,8 +45,6 @@ __all__ = ["MultiProcessPipe", "LazyModule"]
Tensors
=
Tuple
[
Tensor
,
...]
TensorOrTensors
=
Union
[
Tensor
,
Tensors
]
ListOfLazyModules
=
List
[
LazyModule
]
if
TYPE_CHECKING
:
Module
=
nn
.
Module
[
TensorOrTensors
]
NamedModules
=
OrderedDict
[
str
,
Module
]
...
...
@@ -87,7 +83,7 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
raise
TypeError
(
f
"layer
{
type
(
layer
)
}
must be nn.Module or LazyModule to be partitioned"
)
def
verify_module
(
module
:
Union
[
nn
.
Sequential
,
List
Of
LazyModule
s
])
->
None
:
def
verify_module
(
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]
])
->
None
:
if
isinstance
(
module
,
Iterable
)
and
not
isinstance
(
module
,
nn
.
Sequential
):
verify_list_of_callable
(
module
)
else
:
...
...
@@ -135,145 +131,11 @@ def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = Fal
raise
BalanceError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
@
dataclass
class
PartitionInfo
:
location
:
Location
modules
:
"OrderedDict[str, nn.Module]"
invocations
:
List
[
Invocation
]
=
field
(
default_factory
=
list
)
def
__len__
(
self
)
->
int
:
return
len
(
self
.
modules
)
def
instantiate_partition
(
module
:
Union
[
nn
.
Sequential
,
ListOfLazyModules
],
balance
:
Iterable
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
style
:
PipelineStyle
,
)
->
List
[
ModuleWrapper
]:
balance
=
list
(
balance
)
check_balance
(
module
,
balance
,
True
)
layers
:
NamedModules
=
OrderedDict
()
def
maybe_realize
(
layer
:
Any
)
->
nn
.
Module
:
if
isinstance
(
layer
,
nn
.
Module
):
return
layer
elif
callable
(
layer
):
return
layer
()
else
:
raise
TypeError
(
f
"layer must be nn.Module or callable, is
{
type
(
layer
)
}
"
)
def
iterate_module
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
Iterable
[
Tuple
[
Any
,
nn
.
Module
]]:
if
isinstance
(
module
,
nn
.
Sequential
):
yield
from
module
.
named_children
()
else
:
yield
from
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
module
))
if
style
==
PipelineStyle
.
AsyncSchedule
:
module_ids
=
list
(
map
(
id
,
module
))
index_of_first_use
=
[
module_ids
.
index
(
x
)
for
x
in
module_ids
]
locations
:
List
[
Location
]
=
[]
module_iter
=
enumerate
(
iterate_module
(
module
))
partitions
:
List
[
List
[
PartitionInfo
]]
=
[]
for
bi
,
b
in
enumerate
(
balance
):
modules_for_rank
:
List
[
PartitionInfo
]
=
[]
current_module
:
OrderedDict
[
str
,
nn
.
Module
]
=
OrderedDict
()
def
current_location
()
->
Location
:
return
Location
(
bi
,
len
(
modules_for_rank
))
def
append_module
(
mod
:
"OrderedDict[str, nn.Module]"
)
->
None
:
modules_for_rank
.
append
(
PartitionInfo
(
current_location
(),
mod
))
while
sum
(
map
(
len
,
modules_for_rank
))
+
len
(
current_module
)
<
b
:
module_index
,
(
name
,
layer
)
=
next
(
module_iter
)
if
index_of_first_use
[
module_index
]
!=
module_index
:
# Subsequent reuse of a module
locations
.
append
(
locations
[
index_of_first_use
[
module_index
]])
continue
is_reused
=
index_of_first_use
.
count
(
index_of_first_use
[
module_index
])
>
1
if
is_reused
and
len
(
current_module
)
>
0
:
append_module
(
current_module
)
current_module
=
OrderedDict
()
current_module
[
str
(
name
)]
=
layer
locations
.
append
(
current_location
())
if
is_reused
:
append_module
(
current_module
)
current_module
=
OrderedDict
()
if
len
(
current_module
)
>
0
:
append_module
(
current_module
)
partitions
.
append
(
modules_for_rank
)
filtered_locations
:
List
[
Optional
[
Location
]]
=
[
loc
for
loc
,
_
in
itertools
.
groupby
(
locations
)]
filtered_locations
.
append
(
None
)
for
i
in
range
(
len
(
filtered_locations
)
-
1
):
loc
=
filtered_locations
[
i
]
assert
loc
if
i
==
0
:
inv
=
Invocation
(
i
,
loc
,
None
,
filtered_locations
[
i
+
1
])
else
:
inv
=
Invocation
(
i
,
loc
,
filtered_locations
[
i
-
1
],
filtered_locations
[
i
+
1
])
partitions
[
loc
.
stage
][
loc
.
index
].
invocations
.
append
(
inv
)
invocations
=
enumerate
(
iterate_module
(
module
))
partition
=
partitions
[
group
.
rank
()]
result
:
List
[
ModuleWrapper
]
=
[]
for
partition_info
in
partition
:
wrapper
=
ModuleWrapper
(
nn
.
Sequential
(
OrderedDict
((
k
,
maybe_realize
(
m
))
for
k
,
m
in
partition_info
.
modules
.
items
())),
partition_info
.
location
,
partition_info
.
invocations
,
)
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
wrapper
.
module
:
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
result
.
append
(
wrapper
)
return
result
j
=
0
for
name
,
layer
in
iterate_module
(
module
):
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
if
j
==
group
.
rank
():
for
key
in
layers
:
layers
[
key
]
=
maybe_realize
(
layers
[
key
])
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
layers
.
values
():
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
return
[
ModuleWrapper
(
nn
.
Sequential
(
layers
),
Location
(
j
,
0
))]
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
raise
ValueError
(
"Souldn't get here, more ranks than partitions"
)
def
split_module
(
module
:
nn
.
Sequential
,
balance
:
Iterable
[
int
],)
->
Tuple
[
List
[
nn
.
Sequential
],
List
[
int
]]:
def
split_module
(
module
:
nn
.
Sequential
,
balance
:
Iterable
[
int
],)
->
List
[
nn
.
Sequential
]:
"""Splits a module into multiple partitions.
Returns:
A tuple of (partitions, balance).
partitions
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
...
...
@@ -307,8 +169,7 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[n
layers
.
clear
()
j
+=
1
partitions
=
cast
(
List
[
nn
.
Sequential
],
nn
.
ModuleList
(
partitions
))
return
partitions
,
balance
return
partitions
MOVING_DENIED
=
TypeError
(
"denied to move parameters and buffers, because Pipe should manage device placement"
)
...
...
@@ -415,10 +276,9 @@ class MultiProcessPipe(Module):
def
__init__
(
self
,
module
:
Union
[
nn
.
Sequential
,
List
Of
LazyModule
s
],
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]
],
balance
:
Optional
[
Iterable
[
int
]]
=
None
,
*
,
style
:
PipelineStyle
=
PipelineStyle
.
MultiProcess
,
group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
worker_map
:
Optional
[
Dict
[
int
,
str
]]
=
None
,
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
...
...
@@ -427,7 +287,6 @@ class MultiProcessPipe(Module):
deferred_batch_norm
:
bool
=
False
,
pipelined_backward
:
bool
=
None
,
retain_graph
:
bool
=
False
,
loss_fn
:
Optional
[
nn
.
Module
]
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -453,19 +312,16 @@ class MultiProcessPipe(Module):
self
.
pipelined_backward
=
pipelined_backward
self
.
retain_graph
=
retain_graph
self
.
pipeline
:
Optional
[
MultiProcessPipeline
]
self
.
loss_fn
=
loss_fn
self
.
lock
=
threading
.
Lock
()
self
.
group
=
group
self
.
worker_map
=
worker_map
self
.
input_device
=
input_device
# The micro-batch index where the checkpointing stops.
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
if
self
.
group
is
None
:
self
.
group
:
torch
.
distributed
.
ProcessGroup
if
group
is
None
:
self
.
group
=
get_pipeline_parallel_group
()
assert
self
.
group
else
:
self
.
group
=
group
self
.
balance
=
list
(
balance
)
...
...
@@ -480,14 +336,14 @@ class MultiProcessPipe(Module):
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
else
:
self
.
partitions
=
instantiate_partition
(
module
,
balance
,
self
.
group
,
style
)
self
.
partitions
=
self
.
instantiate_partition
(
module
,
balance
,
self
.
group
)
if
deferred_batch_norm
:
for
part
in
self
.
partitions
:
part
.
module
=
DeferredBatchNorm
.
convert_deferred_batch_norm
(
part
.
module
,
chunks
)
for
name
,
part
in
enumerate
(
self
.
partitions
):
self
.
add_module
(
str
(
name
),
part
.
module
)
if
isinstance
(
module
,
nn
.
Sequential
):
local_partitions
,
_
=
split_module
(
module
,
balance
)
local_partitions
=
split_module
(
module
,
balance
)
self
.
_skip_layout
=
inspect_skip_layout
(
local_partitions
)
else
:
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
...
...
@@ -501,18 +357,8 @@ class MultiProcessPipe(Module):
self
.
final_stage
=
False
else
:
self
.
final_stage
=
rank
==
len
(
self
.
balance
)
-
1
assert
loss_fn
is
None
or
self
.
final_stage
self
.
pipeline
=
MultiProcessPipeline
(
cast
(
List
[
nn
.
Sequential
],
self
.
partitions
),
self
.
_skip_layout
,
checkpoint_stop
,
style
=
style
,
group
=
self
.
group
,
worker_map
=
self
.
worker_map
,
input_device
=
self
.
input_device
,
final_stage
=
self
.
final_stage
,
)
self
.
create_pipeline
()
del
module
if
self
.
pipelined_backward
is
None
:
if
get_model_parallel_world_size
()
>
1
:
...
...
@@ -520,6 +366,70 @@ class MultiProcessPipe(Module):
else
:
self
.
pipelined_backward
=
False
def
create_pipeline
(
self
)
->
None
:
# The micro-batch index where the checkpointing stops.
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
self
.
pipeline
=
MultiProcessPipeline
(
self
.
partitions
,
self
.
_skip_layout
,
checkpoint_stop
,
style
=
PipelineStyle
.
MultiProcess
,
group
=
self
.
group
,
worker_map
=
self
.
worker_map
,
input_device
=
self
.
input_device
,
final_stage
=
self
.
final_stage
,
)
def
instantiate_partition
(
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
Iterable
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
)
->
List
[
ModuleWrapper
]:
balance
=
list
(
balance
)
check_balance
(
module
,
balance
,
True
)
layers
:
NamedModules
=
OrderedDict
()
def
maybe_realize
(
layer
:
Any
)
->
nn
.
Module
:
if
isinstance
(
layer
,
nn
.
Module
):
return
layer
elif
callable
(
layer
):
return
layer
()
else
:
raise
TypeError
(
f
"layer must be nn.Module or callable, is
{
type
(
layer
)
}
"
)
def
iterate_module
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
Iterable
[
Tuple
[
Any
,
nn
.
Module
]]:
if
isinstance
(
module
,
nn
.
Sequential
):
yield
from
module
.
named_children
()
else
:
yield
from
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
module
))
j
=
0
for
name
,
layer
in
iterate_module
(
module
):
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
if
j
==
group
.
rank
():
for
key
in
layers
:
layers
[
key
]
=
maybe_realize
(
layers
[
key
])
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
layers
.
values
():
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
return
[
ModuleWrapper
(
nn
.
Sequential
(
layers
),
Location
(
j
,
0
))]
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
raise
ValueError
(
"Souldn't get here, more ranks than partitions"
)
def
__len__
(
self
)
->
int
:
"""Counts the length of the underlying sequential module."""
return
sum
(
len
(
p
)
for
p
in
self
.
partitions
)
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
a8dd9254
...
...
@@ -23,7 +23,7 @@ from queue import Empty as QueueEmpty
from
queue
import
Queue
from
threading
import
Event
from
types
import
TracebackType
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
,
cast
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
import
torch
from
torch
import
Tensor
,
nn
...
...
@@ -171,7 +171,7 @@ class MultiProcessPipeline:
def
__init__
(
self
,
partitions
:
List
[
nn
.
Sequential
],
partitions
:
List
[
ModuleWrapper
],
skip_layout
:
SkipLayout
,
checkpoint_stop
:
int
,
style
:
PipelineStyle
,
...
...
@@ -180,7 +180,7 @@ class MultiProcessPipeline:
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
final_stage
:
bool
=
False
,
)
->
None
:
self
.
partitions
:
List
[
ModuleWrapper
]
=
cast
(
List
[
ModuleWrapper
],
partitions
)
self
.
partitions
=
partitions
self
.
skip_layout
=
skip_layout
self
.
__checkpoint_stop
=
checkpoint_stop
self
.
style
=
style
...
...
tests/nn/pipe_process/test_pipe.py
View file @
a8dd9254
...
...
@@ -32,7 +32,6 @@ from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel
,
)
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe.types
import
PipelineStyle
from
fairscale.utils.testing
import
get_worker_map
,
set_random_seed
,
torch_spawn
,
torch_version
...
...
@@ -874,9 +873,12 @@ def reuse_lazy():
assert
torch
.
equal
(
model_out
,
pipe_out
)
def
test_instantiate_partition
():
@
torch_spawn
([
1
])
def
instantiate_partition
():
from
fairscale.nn.pipe.async_schedule
import
Location
from
fairscale.nn.pipe.multiprocess_pipe
import
instantiate_partition
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
pipe
=
AsyncPipe
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
)
class
FakeGroup
:
def
__init__
(
self
,
rank
,
size
):
...
...
@@ -904,9 +906,7 @@ def test_instantiate_partition():
# Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
# instantiated model
for
rank
in
range
(
len
(
balance
)):
instantiated
=
instantiate_partition
(
model
,
balance
,
FakeGroup
(
rank
,
len
(
balance
)),
PipelineStyle
.
AsyncSchedule
)
instantiated
=
pipe
.
instantiate_partition
(
model
,
balance
,
FakeGroup
(
rank
,
len
(
balance
)))
for
part
in
instantiated
:
assert
isinstance
(
part
.
module
,
nn
.
Sequential
)
for
inv
in
part
.
invocations
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment