OpenDAS / fairscale / Commits / f21b5ffc

Unverified commit f21b5ffc, authored Feb 03, 2021 by msbaines, committed by GitHub on Feb 03, 2021

[refactor] pipe: simplify balance and module checks (#346)
parent cd186441

Showing 3 changed files with 32 additions and 291 deletions (+32 -291):

    fairscale/nn/pipe/async_pipe.py          +2   -8
    fairscale/nn/pipe/multiprocess_pipe.py   +27  -105
    tests/nn/pipe_process/test_pipe.py       +3   -178
fairscale/nn/pipe/async_pipe.py

@@ -13,7 +13,7 @@ from torch import Tensor, nn

 from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
-from .multiprocess_pipe import MultiProcessPipe, check_balance
+from .multiprocess_pipe import MultiProcessPipe
 from .skip.skippable import Skippable
 from .types import LazyModule

@@ -54,14 +54,8 @@ class AsyncPipe(MultiProcessPipe):
     )

     def instantiate_partition(
-        self, module: Union[nn.Sequential, List[LazyModule]], balance: Iterable[int], group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()

         def maybe_realize(layer: Any) -> nn.Module:
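Note on the deletion: the dropped call passed filter_unique=True, which (per the old check_balance removed from multiprocess_pipe.py below) compared sum(balance) against the number of unique module objects rather than raw list entries, so a reused wrapper counted once. After this change, instantiate_partition assumes the constructor has already validated the balance. A minimal illustration of the distinction, reusing the layer pattern from the test suite:

    import torch.nn as nn

    reused = nn.Linear(10, 10)
    model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]

    assert len(model) == 7                # what a plain length check sees
    assert len(set(map(id, model))) == 5  # what filter_unique=True compared against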
fairscale/nn/pipe/multiprocess_pipe.py

@@ -53,85 +53,22 @@ else:
     NamedModules = OrderedDict


-def recommend_auto_balance(message: str) -> str:
-    """Expands a message with recommendation to :mod:`torchpipe.balance`."""
-    return f"""{message}
-
-If your model is still under development, its optimal balance would change
-frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
-naive automatic balancing:
-
-  from fairscale.nn import Pipe
-  from fairscale.nn.pipe.balance import balance_by_time
-
-  partitions = torch.cuda.device_count()
-  sample = torch.empty(...)
-  balance = balance_by_time(partitions, model, sample)
-
-  model = MultiProcessPipe(model, balance, ...)
-"""
-
-
-# FIXME(tom) make this a valid way to call
-def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
-    for layer in module:
-        if isinstance(layer, nn.Module):
-            pass
-        elif isinstance(layer, LazyModule):
-            pass
-        else:
-            raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
-
-
 def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
-    if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
-        verify_list_of_callable(module)
-    else:
-        if not isinstance(module, nn.Sequential):
-            raise TypeError("module must be nn.Sequential to be partitioned")
-
-        named_children = list(module.named_children())
-        if len(named_children) != len(module):
-            raise ValueError("module with duplicate children is not supported")
+    if len(set(map(id, module))) != len(module):
+        raise ValueError("module with duplicate children is not supported")
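The old and new duplicate checks agree on nn.Sequential, since named_children() deduplicates by object identity exactly as set(map(id, ...)) does, and the one-liner also covers a plain list of modules. A small sanity sketch (hypothetical layers):

    import torch.nn as nn

    shared = nn.Linear(8, 8)
    dup = nn.Sequential(shared, nn.ReLU(), shared)  # same child object added twice

    assert len(dup) == 3
    assert len(list(dup.named_children())) == 2  # old comparison
    assert len(set(map(id, dup))) == 2           # new comparison
    # either way the mismatch with len(dup) makes verify_module raise ValueError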
-def verify_splitting(
-    module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],
-) -> None:
-    num_parameters = len(list(module.parameters()))
-    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
-    if num_parameters == num_child_parameters:
-        return
-
-    for i in range(len(partitions)):
-        for j in range(i + 1, len(partitions)):
-            parti = partitions[i]
-            partj = partitions[j]
-            for p in parti.parameters():
-                for q in partj.parameters():
-                    if p is q:
-                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
-
-
-class BalanceError(ValueError):
-    pass
-
-
-def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
-    if filter_unique:
-        module_len = len(set(map(id, module)))
-    else:
-        module_len = len(module)
-
-    if module_len != sum(balance):
-        raise BalanceError(
+def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
+    if len(module) != sum(balance):
+        raise ValueError(
             f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
         )

     if any(x <= 0 for x in balance):
-        raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
+        raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")
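The rewrite keeps both failure modes but raises plain ValueError, which is what lets the BalanceError class and the recommend_auto_balance wrapper above disappear. A quick sketch of the two error paths (hypothetical layers; import path as in this diff):

    import torch.nn as nn
    from fairscale.nn.pipe.multiprocess_pipe import check_balance

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

    check_balance(model, [2, 1])  # ok: 2 + 1 == len(model) == 3

    try:
        check_balance(model, [2, 2])  # 4 != 3
    except ValueError as exc:
        print(exc)  # module and sum of balance have different length ...

    try:
        check_balance(model, [3, 0])  # sums to 3, but 0 is not positive
    except ValueError as exc:
        print(exc)  # all balance numbers must be positive integer ...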
-def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
+def split_module(module: nn.Sequential, balance: List[int]) -> List[nn.Sequential]:
     """Splits a module into multiple partitions.

     Returns:

@@ -148,10 +85,6 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
         the number of devices is fewer than the number of partitions.
     """
-    balance = list(balance)
-
-    check_balance(module, balance)
-
     j = 0
     partitions = []
     layers: NamedModules = OrderedDict()
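With validation hoisted into the constructor, split_module is left with the mechanical partitioning its docstring describes. A toy mimic of that core loop (split_sketch is a hypothetical stand-in, not the fairscale implementation) shows what a balance means:

    from collections import OrderedDict
    from typing import List

    import torch.nn as nn

    def split_sketch(module: nn.Sequential, balance: List[int]) -> List[nn.Sequential]:
        # consume balance[i] consecutive layers into partition i
        layers = list(module.named_children())
        partitions = []
        for count in balance:
            chunk, layers = layers[:count], layers[count:]
            partitions.append(nn.Sequential(OrderedDict(chunk)))
        return partitions

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    parts = split_sketch(model, [2, 1])
    assert [len(p) for p in parts] == [2, 1]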
@@ -274,7 +207,7 @@ class MultiProcessPipe(Module):
     def __init__(
         self,
         module: Union[nn.Sequential, List[LazyModule]],
-        balance: Optional[Iterable[int]] = None,
+        balance: Iterable[int],
         *,
         group: Optional[torch.distributed.ProcessGroup] = None,
         worker_map: Optional[Dict[int, str]] = None,
@@ -290,14 +223,14 @@ class MultiProcessPipe(Module):
         chunks = int(chunks)
         checkpoint = str(checkpoint)

-        if balance is None:
-            raise ValueError(recommend_auto_balance("balance is required"))
         if chunks <= 0:
             raise ValueError("number of chunks must be positive integer")
         if checkpoint not in ["always", "except_last", "never"]:
             raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

+        self.balance = list(balance)
         verify_module(module)
+        check_balance(module, self.balance)

         # Verify if the underlying skippable modules satisfy integrity. The
         # integrity can be verified before forward() because it is static.
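Because self.balance, verify_module, and check_balance now run before any process-group lookup, a bad balance fails fast in the constructor. A minimal sketch of the new behaviour (it mirrors the updated test below; no distributed initialization is needed to reach this error):

    import torch.nn as nn
    from fairscale.nn.pipe import MultiProcessPipe

    try:
        MultiProcessPipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
    except ValueError as exc:
        print(exc)  # module and sum of balance have different length (module: 2, sum of balance: 1)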
@@ -320,34 +253,29 @@ class MultiProcessPipe(Module):
         else:
             self.group = group

-        self.balance = list(balance)
-
         if self.group.size() < len(self.balance):
             raise IndexError(
                 f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                 f" {len(self.balance)})"
             )

-        try:
         rank = self.group.rank()
         if rank >= len(self.balance):
             warnings.warn("More ranks than partitions, some ranks unused")
             self.partitions: List[ModuleWrapper] = []
         else:
-            self.partitions = self.instantiate_partition(module, balance, self.group)
+            self.partitions = self.instantiate_partition(module, self.balance, self.group)
             if deferred_batch_norm:
                 for part in self.partitions:
                     part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
             for name, part in enumerate(self.partitions):
                 self.add_module(str(name), part.module)
         if isinstance(module, nn.Sequential):
-            local_partitions = split_module(module, balance)
+            local_partitions = split_module(module, self.balance)
             self._skip_layout = inspect_skip_layout(local_partitions)
         else:
             self._skip_layout = SkipLayout(len(module), {})
-        # FIXME(tom)
-        except BalanceError as exc:
-            raise ValueError(recommend_auto_balance(str(exc)))

         rank = self.group.rank()
         if rank >= len(self.balance):
             self.pipeline = None
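The surviving control flow maps ranks to partitions directly: a group smaller than the balance is an IndexError, and surplus ranks simply idle with a warning. Illustrated with made-up numbers:

    balance = [3, 1, 1]  # three pipeline stages
    group_size = 4       # ranks available in the process group

    assert not group_size < len(balance)  # otherwise: too few ranks to hold given partitions

    for rank in range(group_size):
        if rank >= len(balance):
            print(f"rank {rank}: unused (more ranks than partitions)")
        else:
            print(f"rank {rank}: hosts partition {rank}")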
@@ -378,14 +306,8 @@ class MultiProcessPipe(Module):
     )

     def instantiate_partition(
-        self, module: Union[nn.Sequential, List[LazyModule]], balance: Iterable[int], group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()

         def maybe_realize(layer: Any) -> nn.Module:
tests/nn/pipe_process/test_pipe.py

@@ -32,7 +32,7 @@ from fairscale.nn.model_parallel.initialize import (
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
-from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
+from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version


 @torch_spawn([2])
@@ -706,15 +706,11 @@ def named_children(pipe_class):
 @torch_spawn([1])
 @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 def recommend_auto_balance(pipe_class):
-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
-        # balance is required
-        pipe_class(nn.Sequential())
-
-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
+    with pytest.raises(ValueError):
         # module and sum of balance have different length (module: 0, sum of balance: 1)
         pipe_class(nn.Sequential(), [1])

-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
+    with pytest.raises(ValueError):
         # module and sum of balance have different length (module: 2, sum of balance: 1)
         pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
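The deleted first case follows from the signature change above: with balance now a required positional parameter, omitting it is a TypeError raised by Python itself rather than a ValueError from the library. A sketch:

    import pytest
    import torch.nn as nn
    from fairscale.nn.pipe import MultiProcessPipe

    with pytest.raises(TypeError):
        MultiProcessPipe(nn.Sequential())  # missing required argument: 'balance'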
@@ -805,174 +801,3 @@ def async_event_loop():
         if pipe.final_stage:
             loss = output.mean()
             loss.backward()
-
-
-@torch_spawn([4])
-def reuse_lazy():
-    if False:  # speed
-        reused = LazyModule(lambda: nn.Linear(10, 10))
-        model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
-        pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map())
-        pipe.eval()
-        output = pipe(torch.rand(10))
-
-        print(f"output on {pipe.group.rank()}, {output}")
-        torch.distributed.barrier()
-
-    set_random_seed(1234)
-    # test both forward
-    reused = nn.Linear(10, 10)
-    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    model = nn.Sequential(*layers)
-    model.eval()
-
-    set_random_seed(1234)
-    # ensure identical weights but no sharing between model and pipe
-    reused = nn.Linear(10, 10)
-    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map())
-    pipe.eval()
-    model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
-    pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
-    inputs = torch.rand(10)
-    if False:  # speed
-        model_out = model(inputs)
-        pipe_out = pipe(inputs)
-
-        torch.distributed.barrier()
-
-        if pipe.final_stage:
-            assert torch.equal(model_out, pipe_out)
-
-    model.train()
-    pipe.train()
-    model_out = model(inputs)
-    pipe_out = pipe(inputs)
-    if pipe.final_stage:
-        pipe_loss = pipe_out.mean()
-        pipe_loss.backward()
-    model_loss = model_out.mean()
-    model_loss.backward()
-    model_optimizer.step()
-    if pipe_optimizer:
-        pipe_optimizer.step()
-
-    model.eval()
-    pipe.eval()
-    model_out = model(inputs)
-    pipe_out = pipe(inputs)
-
-    print(f"before barrier on {torch.distributed.get_rank()}")
-    torch.distributed.barrier()
-    print(f"after barrier on {torch.distributed.get_rank()}")
-
-    if pipe.final_stage:
-        assert torch.equal(model_out, pipe_out)
-
-
-@torch_spawn([1])
-def instantiate_partition():
-    from fairscale.nn.pipe.async_schedule import Location
-
-    model = nn.Sequential(nn.Linear(1, 1))
-    pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)
-
-    class FakeGroup:
-        def __init__(self, rank, size):
-            self._rank = rank
-            self._size = size
-
-        def rank(self):
-            return self._rank
-
-        def size(self):
-            return self._size
-
-    def check_partitions(model, balance, expected_order, expected_ranks):
-        """Check the instantiated model matches expectation of order and rank
-
-        model: a list of modules or an nn.Sequential
-        balance: the balance argument to MultiProcessPipe
-        expected_order: the index of modules in `model` in the order they will
-            be executed, grouped by nn.Sequential
-        expected_rank: the rank that each module will be executed on
-        """
-        invocations = []
-        invocation_wrapper = dict()
-
-        # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
-        # instantiated model
-        for rank in range(len(balance)):
-            instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
-            for part in instantiated:
-                assert isinstance(part.module, nn.Sequential)
-                for inv in part.invocations:
-                    invocations.append(inv)
-                    invocation_wrapper[inv] = part
-
-        modules = []
-        prev = None
-        current = Location(0, 0)
-        ranks = []
-
-        for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)):
-            # Check integrity of Location chain
-            assert inv.order == order
-            assert inv.source == prev
-            assert inv.this == current
-            prev = inv.this
-            current = inv.dest
-            modules.append(list(invocation_wrapper[inv].module.children()))
-            ranks.append(inv.this.stage)
-
-        # assert len(modules) == len(expected_order)
-        for left, right in zip(modules, expected_order):
-            assert len(left) == len(right), f"{right}"
-            assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}"
-
-        assert ranks == expected_ranks
-
-    reused = nn.Linear(20, 20)
-    model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    balance = [3, 1, 1]
-    check_partitions(
-        model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2]
-    )
-
-    reused2 = nn.Linear(5, 5)
-    model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()]
-    balance = [4, 1, 1]
-    check_partitions(
-        model,
-        balance,
-        expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]],
-        expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2],
-    )
-
-    reused2 = nn.Linear(5, 5)
-    model = [
-        nn.Linear(10, 10),
-        reused,
-        nn.Linear(10, 10),
-        nn.ReLU(),
-        reused,
-        reused2,
-        nn.ReLU(),
-        reused,
-        reused2,
-        nn.ReLU(),
-    ]
-    # 0 1 2 3 1 5 6 1 5 9
-    balance = [4, 2, 1]
-    check_partitions(
-        model,
-        balance,
-        expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]],
-        expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2],
-    )
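For reference, in the deleted instantiate_partition test each entry of expected_order lists indices into model, grouped by the nn.Sequential a ModuleWrapper holds, and expected_ranks names the stage that executes each group. Decoding the first case as a sketch:

    import torch.nn as nn

    reused = nn.Linear(20, 20)
    model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]

    # balance [3, 1, 1]: rank 0 holds the three unique modules at indices 0-2,
    # rank 1 the ReLU at index 4, rank 2 the ReLU at index 6; execution returns
    # to rank 0 each time `reused` recurs.
    expected_order = [[0], [1, 2], [0], [4], [0], [6]]
    expected_ranks = [0, 0, 0, 1, 0, 2]
    for group, rank in zip(expected_order, expected_ranks):
        print([type(model[i]).__name__ for i in group], "on rank", rank)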