OpenDAS / fairscale · Commits · a8dd9254

[refactor] pipe: move async-specific code out of MultiProcessPipe (#344)

Unverified commit a8dd9254, authored Jan 30, 2021 by msbaines, committed by GitHub on Jan 30, 2021.
Parent: e348806b
Showing 4 changed files, with 235 additions and 182 deletions (+235 −182):

- fairscale/nn/pipe/async_pipe.py: +147 −4
- fairscale/nn/pipe/multiprocess_pipe.py: +79 −169
- fairscale/nn/pipe/multiprocess_pipeline.py: +3 −3
- tests/nn/pipe_process/test_pipe.py: +6 −6
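Taken together, the diffs below follow a template-method refactor: behavior that used to be selected by a `style` flag threaded through free functions becomes a pair of overridable methods (`create_pipeline`, `instantiate_partition`) on `MultiProcessPipe`, and `AsyncPipe` becomes a subclass that overrides them. A minimal, self-contained sketch of that shape (stand-in classes and fields, not fairscale's actual API):

```python
class BasePipe:
    """Stand-in for MultiProcessPipe: the constructor drives the steps,
    subclasses override the steps themselves."""

    def __init__(self, layers):
        self.partitions = self.instantiate_partition(layers)
        self.create_pipeline()

    def instantiate_partition(self, layers):
        # Default (synchronous) partitioning: one partition with everything.
        return [layers]

    def create_pipeline(self):
        self.schedule = "multiprocess"


class AsyncPipe(BasePipe):
    """Stand-in for the new AsyncPipe: same construction flow, async steps."""

    def create_pipeline(self):
        self.schedule = "async"


print(AsyncPipe([1, 2, 3]).schedule)  # "async"
```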
fairscale/nn/pipe/async_pipe.py (view file @ a8dd9254)

```diff
@@ -3,10 +3,153 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .multiprocess_pipe import MultiProcessPipe
-from .types import PipelineStyle
+from collections import OrderedDict
+from dataclasses import dataclass, field
+import itertools
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from .async_schedule import Invocation, Location, ModuleWrapper
+from .multiprocess_pipe import MultiProcessPipe, check_balance
+from .multiprocess_pipeline import MultiProcessPipeline
+from .skip.skippable import Skippable
+from .types import LazyModule, PipelineStyle
+
+if TYPE_CHECKING:
+    Module = nn.Module[TensorOrTensors]
+    NamedModules = OrderedDict[str, Module]
+else:
+    Module = nn.Module
+    NamedModules = OrderedDict
+
+Tensors = Tuple[Tensor, ...]
+TensorOrTensors = Union[Tensor, Tensors]
+
+
+@dataclass
+class PartitionInfo:
+    location: Location
+    modules: "OrderedDict[str, nn.Module]"
+    invocations: List[Invocation] = field(default_factory=list)
+
+    def __len__(self) -> int:
+        return len(self.modules)
```
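Before the diff continues: `PartitionInfo` declares its list default with `field(default_factory=list)` rather than `= []`, because a mutable default would be shared across every instance. A quick standalone illustration of the distinction:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Info:
    invocations: List[int] = field(default_factory=list)  # fresh list per instance


a, b = Info(), Info()
a.invocations.append(1)
print(b.invocations)  # [] -- not shared with `a`
```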
```diff (async_pipe.py, continued)
 
 class AsyncPipe(MultiProcessPipe):
-    def __init__(self, *args, **kwargs) -> None:  # type: ignore
-        super().__init__(*args, style=PipelineStyle.AsyncSchedule, **kwargs)
+    def create_pipeline(self) -> None:
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
+
+        self.pipeline = MultiProcessPipeline(
+            self.partitions,
+            self._skip_layout,
+            checkpoint_stop,
+            style=PipelineStyle.AsyncSchedule,
+            group=self.group,
+            worker_map=self.worker_map,
+            input_device=self.input_device,
+            final_stage=self.final_stage,
+        )
+
+    def instantiate_partition(
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: Iterable[int],
+        group: torch.distributed.ProcessGroup,
+    ) -> List[ModuleWrapper]:
+        balance = list(balance)
+        check_balance(module, balance, True)
+
+        layers: NamedModules = OrderedDict()
+
+        def maybe_realize(layer: Any) -> nn.Module:
+            if isinstance(layer, nn.Module):
+                return layer
+            elif callable(layer):
+                return layer()
+            else:
+                raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
+
+        def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
+            if isinstance(module, nn.Sequential):
+                yield from module.named_children()
+            else:
+                yield from ((str(k), v) for k, v in enumerate(module))
+
+        module_ids = list(map(id, module))
+        index_of_first_use = [module_ids.index(x) for x in module_ids]
+        locations: List[Location] = []
+        module_iter = enumerate(iterate_module(module))
+
+        partitions: List[List[PartitionInfo]] = []
+        for bi, b in enumerate(balance):
+            modules_for_rank: List[PartitionInfo] = []
+            current_module: OrderedDict[str, nn.Module] = OrderedDict()
+
+            def current_location() -> Location:
+                return Location(bi, len(modules_for_rank))
+
+            def append_module(mod: "OrderedDict[str, nn.Module]") -> None:
+                modules_for_rank.append(PartitionInfo(current_location(), mod))
+
+            while sum(map(len, modules_for_rank)) + len(current_module) < b:
+                module_index, (name, layer) = next(module_iter)
+
+                if index_of_first_use[module_index] != module_index:
+                    # Subsequent reuse of a module
+                    locations.append(locations[index_of_first_use[module_index]])
+                    continue
+
+                is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1
+
+                if is_reused and len(current_module) > 0:
+                    append_module(current_module)
+                    current_module = OrderedDict()
+
+                current_module[str(name)] = layer
+                locations.append(current_location())
+
+                if is_reused:
+                    append_module(current_module)
+                    current_module = OrderedDict()
+
+            if len(current_module) > 0:
+                append_module(current_module)
+
+            partitions.append(modules_for_rank)
+
+        filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)]
+        filtered_locations.append(None)
+
+        for i in range(len(filtered_locations) - 1):
+            loc = filtered_locations[i]
+            assert loc
+            if i == 0:
+                inv = Invocation(i, loc, None, filtered_locations[i + 1])
+            else:
+                inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1])
+            partitions[loc.stage][loc.index].invocations.append(inv)
+
+        invocations = enumerate(iterate_module(module))
+
+        partition = partitions[group.rank()]
+        result: List[ModuleWrapper] = []
+
+        for partition_info in partition:
+            wrapper = ModuleWrapper(
+                nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())),
+                partition_info.location,
+                partition_info.invocations,
+            )
+
+            if not isinstance(module, nn.Sequential):
+                for layer in wrapper.module:
+                    if isinstance(layer, Skippable):
+                        raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
+
+            result.append(wrapper)
+
+        return result
```
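Two idioms carry this partitioner: reused modules are detected by object identity (`id`), and consecutive duplicate locations are collapsed with `itertools.groupby`. Both in isolation, with plain stand-in objects instead of `nn.Module`s and `Location`s:

```python
import itertools

a, b = object(), object()
module_list = [a, b, a]  # `a` occurs twice: a reused module

module_ids = list(map(id, module_list))
index_of_first_use = [module_ids.index(x) for x in module_ids]
print(index_of_first_use)  # [0, 1, 0] -- position 2 is a reuse of position 0

# groupby collapses only *consecutive* repeats, which is exactly what the
# `filtered_locations` computation above relies on:
locations = ["p0", "p0", "p1", "p1", "p0"]
print([loc for loc, _ in itertools.groupby(locations)])  # ['p0', 'p1', 'p0']
```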
fairscale/nn/pipe/multiprocess_pipe.py (view file @ a8dd9254)

```diff
@@ -19,10 +19,8 @@
 """The MultiProcessPipe interface."""
 
 from collections import OrderedDict
-from dataclasses import dataclass, field
-import itertools
 import threading
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
 import warnings
 
 import torch
@@ -33,7 +31,7 @@ import torch.cuda
 from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group
 
 from . import microbatch
-from .async_schedule import Invocation, Location, ModuleWrapper
+from .async_schedule import Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
@@ -47,8 +45,6 @@ __all__ = ["MultiProcessPipe", "LazyModule"]
 
 Tensors = Tuple[Tensor, ...]
 TensorOrTensors = Union[Tensor, Tensors]
-ListOfLazyModules = List[LazyModule]
-
 
 if TYPE_CHECKING:
     Module = nn.Module[TensorOrTensors]
     NamedModules = OrderedDict[str, Module]
@@ -87,7 +83,7 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
             raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
 
 
-def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None:
+def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
     if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
         verify_list_of_callable(module)
     else:
```
```diff (multiprocess_pipe.py, continued)
@@ -135,145 +131,11 @@ def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = Fal
         raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
 
 
-@dataclass
-class PartitionInfo:
-    location: Location
-    modules: "OrderedDict[str, nn.Module]"
-    invocations: List[Invocation] = field(default_factory=list)
-
-    def __len__(self) -> int:
-        return len(self.modules)
-
-
-def instantiate_partition(
-    module: Union[nn.Sequential, ListOfLazyModules],
-    balance: Iterable[int],
-    group: torch.distributed.ProcessGroup,
-    style: PipelineStyle,
-) -> List[ModuleWrapper]:
-    balance = list(balance)
-    check_balance(module, balance, True)
-
-    layers: NamedModules = OrderedDict()
-
-    def maybe_realize(layer: Any) -> nn.Module:
-        # ... (unchanged helper, moved verbatim into the new methods) ...
-
-    def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
-        # ... (unchanged helper, moved verbatim into the new methods) ...
-
-    if style == PipelineStyle.AsyncSchedule:
-        # ... (AsyncSchedule partitioning body, moved verbatim to
-        #      AsyncPipe.instantiate_partition in async_pipe.py above) ...
-
-    j = 0
-    for name, layer in iterate_module(module):
-        # ... (MultiProcess partitioning body, moved verbatim to
-        #      MultiProcessPipe.instantiate_partition below) ...
-
-    raise ValueError("Souldn't get here, more ranks than partitions")
-
-
-def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[nn.Sequential], List[int]]:
+def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
     """Splits a module into multiple partitions.
 
     Returns:
-        A tuple of (partitions, balance).
+        partitions
 
         Partitions are represented as a :class:`~torch.nn.ModuleList` whose
         item is a partition. All layers in a partition are placed in the
@@ -307,8 +169,7 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> Tuple[List[n
             layers.clear()
             j += 1
 
-    partitions = cast(List[nn.Sequential], nn.ModuleList(partitions))
-    return partitions, balance
+    return partitions
 
 
 MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")
```
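`split_module` now returns just the partition list; the balance it used to echo back is already known to the caller. An illustrative re-implementation of the balance-driven split (not fairscale's actual function, just the grouping idea):

```python
from collections import OrderedDict

import torch.nn as nn


def split_by_balance(module: nn.Sequential, balance):
    # Group consecutive children into partitions of the requested sizes.
    children = list(module.named_children())
    partitions, i = [], 0
    for count in balance:
        partitions.append(nn.Sequential(OrderedDict(children[i : i + count])))
        i += count
    return partitions


model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
print([len(p) for p in split_by_balance(model, [2, 1])])  # [2, 1]
```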
```diff (multiprocess_pipe.py, continued)
@@ -415,10 +276,9 @@ class MultiProcessPipe(Module):
     def __init__(
         self,
-        module: Union[nn.Sequential, ListOfLazyModules],
+        module: Union[nn.Sequential, List[LazyModule]],
         balance: Optional[Iterable[int]] = None,
         *,
-        style: PipelineStyle = PipelineStyle.MultiProcess,
         group: Optional[torch.distributed.ProcessGroup] = None,
         worker_map: Optional[Dict[int, str]] = None,
         input_device: Union[None, int, str, torch.device] = None,
@@ -427,7 +287,6 @@ class MultiProcessPipe(Module):
         deferred_batch_norm: bool = False,
         pipelined_backward: bool = None,
         retain_graph: bool = False,
-        loss_fn: Optional[nn.Module] = None,
     ) -> None:
         super().__init__()
@@ -453,19 +312,16 @@ class MultiProcessPipe(Module):
         self.pipelined_backward = pipelined_backward
         self.retain_graph = retain_graph
         self.pipeline: Optional[MultiProcessPipeline]
-        self.loss_fn = loss_fn
         self.lock = threading.Lock()
 
+        self.group = group
         self.worker_map = worker_map
         self.input_device = input_device
 
-        # The micro-batch index where the checkpointing stops.
-        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-
-        if group is None:
+        self.group: torch.distributed.ProcessGroup
+        if self.group is None:
             self.group = get_pipeline_parallel_group()
-        else:
-            self.group = group
+        assert self.group
 
         self.balance = list(balance)
```
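The `checkpoint_stop` computation removed from `__init__` here reappears in the new `create_pipeline` methods. Its dict lookup maps a checkpoint policy to the first micro-batch index that is *not* checkpointed; the same logic as a standalone function:

```python
def checkpoint_stop(chunks: int, checkpoint: str) -> int:
    # Micro-batch index where checkpointing stops, per policy.
    return {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]


assert checkpoint_stop(4, "always") == 4       # checkpoint every micro-batch
assert checkpoint_stop(4, "except_last") == 3  # skip the last one
assert checkpoint_stop(4, "never") == 0        # no checkpointing
```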
```diff (multiprocess_pipe.py, continued)
@@ -480,14 +336,14 @@ class MultiProcessPipe(Module):
             warnings.warn("More ranks than partitions, some ranks unused")
             self.partitions: List[ModuleWrapper] = []
         else:
-            self.partitions = instantiate_partition(module, balance, self.group, style)
+            self.partitions = self.instantiate_partition(module, balance, self.group)
             if deferred_batch_norm:
                 for part in self.partitions:
                     part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
             for name, part in enumerate(self.partitions):
                 self.add_module(str(name), part.module)
             if isinstance(module, nn.Sequential):
-                local_partitions, _ = split_module(module, balance)
+                local_partitions = split_module(module, balance)
                 self._skip_layout = inspect_skip_layout(local_partitions)
             else:
                 self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
```
```diff (multiprocess_pipe.py, continued)
@@ -501,24 +357,78 @@ class MultiProcessPipe(Module):
             self.final_stage = False
         else:
             self.final_stage = rank == len(self.balance) - 1
 
-        assert loss_fn is None or self.final_stage
-
-        # The micro-batch index where the checkpointing stops.
-        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
-
-        self.pipeline = MultiProcessPipeline(
-            cast(List[nn.Sequential], self.partitions),
-            self._skip_layout,
-            checkpoint_stop,
-            style=style,
-            group=self.group,
-            worker_map=self.worker_map,
-            input_device=self.input_device,
-            final_stage=self.final_stage,
-        )
+        self.create_pipeline()
 
         del module
 
         if self.pipelined_backward is None:
             if get_model_parallel_world_size() > 1:
                 self.pipelined_backward = True
             else:
                 self.pipelined_backward = False
 
+    def create_pipeline(self) -> None:
+        # The micro-batch index where the checkpointing stops.
+        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
+
+        self.pipeline = MultiProcessPipeline(
+            self.partitions,
+            self._skip_layout,
+            checkpoint_stop,
+            style=PipelineStyle.MultiProcess,
+            group=self.group,
+            worker_map=self.worker_map,
+            input_device=self.input_device,
+            final_stage=self.final_stage,
+        )
+
+    def instantiate_partition(
+        self,
+        module: Union[nn.Sequential, List[LazyModule]],
+        balance: Iterable[int],
+        group: torch.distributed.ProcessGroup,
+    ) -> List[ModuleWrapper]:
+        balance = list(balance)
+        check_balance(module, balance, True)
+
+        layers: NamedModules = OrderedDict()
+
+        def maybe_realize(layer: Any) -> nn.Module:
+            if isinstance(layer, nn.Module):
+                return layer
+            elif callable(layer):
+                return layer()
+            else:
+                raise TypeError(f"layer must be nn.Module or callable, is {type(layer)}")
+
+        def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn.Module]]:
+            if isinstance(module, nn.Sequential):
+                yield from module.named_children()
+            else:
+                yield from ((str(k), v) for k, v in enumerate(module))
+
+        j = 0
+        for name, layer in iterate_module(module):
+            layers[name] = layer
+
+            if len(layers) == balance[j]:
+                if j == group.rank():
+                    for key in layers:
+                        layers[key] = maybe_realize(layers[key])
+                    if not isinstance(module, nn.Sequential):
+                        for layer in layers.values():
+                            if isinstance(layer, Skippable):
+                                raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction")
+                    return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))]
+
+                # Prepare for the next partition.
+                layers.clear()
+                j += 1
+
+        raise ValueError("Souldn't get here, more ranks than partitions")
 
     def __len__(self) -> int:
         """Counts the length of the underlying sequential module."""
```
fairscale/nn/pipe/multiprocess_pipeline.py (view file @ a8dd9254)

```diff
@@ -23,7 +23,7 @@ from queue import Empty as QueueEmpty
 from queue import Queue
 from threading import Event
 from types import TracebackType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from torch import Tensor, nn
@@ -171,7 +171,7 @@ class MultiProcessPipeline:
     def __init__(
         self,
-        partitions: List[nn.Sequential],
+        partitions: List[ModuleWrapper],
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         style: PipelineStyle,
@@ -180,7 +180,7 @@ class MultiProcessPipeline:
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
     ) -> None:
-        self.partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions)
+        self.partitions = partitions
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
         self.style = style
```
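The whole change to `multiprocess_pipeline.py` is a typing cleanup: once the parameter is annotated with the type callers actually pass, the `cast` (and its import) can go. Schematically, with a stand-in class rather than the real signatures:

```python
from typing import List, cast


class Wrapper:
    """Stand-in for fairscale's ModuleWrapper."""


def before(partitions: List[object]) -> List[Wrapper]:
    # The old annotation under-specified the element type,
    # so the body had to cast() to recover it.
    return cast(List[Wrapper], partitions)


def after(partitions: List[Wrapper]) -> List[Wrapper]:
    # Annotating the true type up front makes the cast unnecessary.
    return partitions
```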
tests/nn/pipe_process/test_pipe.py (view file @ a8dd9254)

```diff
@@ -32,7 +32,6 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
-from fairscale.nn.pipe.types import PipelineStyle
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
@@ -874,9 +873,12 @@ def reuse_lazy():
     assert torch.equal(model_out, pipe_out)
 
 
-def test_instantiate_partition():
+@torch_spawn([1])
+def instantiate_partition():
     from fairscale.nn.pipe.async_schedule import Location
-    from fairscale.nn.pipe.multiprocess_pipe import instantiate_partition
+
+    model = nn.Sequential(nn.Linear(1, 1))
+    pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)
 
     class FakeGroup:
         def __init__(self, rank, size):
@@ -904,9 +906,7 @@ def test_instantiate_partition():
     # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
     # instantiated model
     for rank in range(len(balance)):
-        instantiated = instantiate_partition(
-            model, balance, FakeGroup(rank, len(balance)), PipelineStyle.AsyncSchedule
-        )
+        instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
         for part in instantiated:
             assert isinstance(part.module, nn.Sequential)
             for inv in part.invocations:
```
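The test drives `pipe.instantiate_partition` with a `FakeGroup` in place of a real process group. Assuming the partitioner only needs `rank()` and `size()` (which is all the visible code exercises), a minimal fake would look like this sketch:

```python
class FakeGroup:
    """Minimal stand-in for torch.distributed.ProcessGroup."""

    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size
```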