OpenDAS / fairscale / Commits / 2d3d5a7b

Unverified commit 2d3d5a7b, authored Apr 01, 2021 by msbaines, committed by GitHub on Apr 01, 2021.
Parent: e141a93e

[feat] remove old MultiProcessPipe (#563)

Showing 12 changed files with 63 additions and 1139 deletions (+63, -1139).
.circleci/config.yml                          +0    -8
benchmarks/golden_configs/lm_wikitext2.py     +6    -13
benchmarks/pipe.py                            +9    -75
examples/tutorial_pipe_multiprocess.py        +0    -63
fairscale/nn/pipe/__init__.py                 +0    -1
fairscale/nn/pipe/multiprocess_pipe.py        +0    -349
fairscale/nn/pipe/multiprocess_pipeline.py    +0    -288
tests/nn/model_parallel/test_layers.py        +1    -272
tests/nn/pipe_process/test_bugs.py            +5    -5
tests/nn/pipe_process/test_inplace.py         +4    -4
tests/nn/pipe_process/test_pipe.py            +36   -59
tests/nn/pipe_process/test_transparency.py    +2    -2
.circleci/config.yml

@@ -197,12 +197,6 @@ run_pipe_benchmark: &run_pipe_benchmark
       command: |
         python benchmarks/pipe.py

-run_mp_pipe_benchmark: &run_mp_pipe_benchmark
-  - run:
-      name: Run Multiprocess Pipe Benchmark
-      command: |
-        python benchmarks/pipe.py --multiprocess --lazy-construction
-
 run_oss_benchmark: &run_oss_benchmark
   - run:
       name: Run OSS Benchmark
@@ -578,8 +572,6 @@ jobs:
       - <<: *run_pipe_benchmark
-      - <<: *run_mp_pipe_benchmark
       - <<: *run_oss_amp
       - <<: *run_oss_for_each
benchmarks/golden_configs/lm_wikitext2.py

@@ -80,19 +80,12 @@ class Pipe:
         "criterion": nn.CrossEntropyLoss(),
     }


-def get_golden_real_stats(multiprocess=False):
-    if not multiprocess:
-        return {
-            "avg_wps": 703.778,
-            "std_dev_wps": 5.732,
-            "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
-        }
-    else:
-        return {
-            "avg_wps": 647.404,
-            "std_dev_wps": 14.51,
-            "peak_mem_usage": [3305007616, 2578692608, 3304524288, 2578692608],
-        }
+def get_golden_real_stats():
+    return {
+        "avg_wps": 703.778,
+        "std_dev_wps": 5.732,
+        "peak_mem_usage": [2320996352, 1396742144, 1396742144, 2340010496],
+    }


 def get_golden_synthetic_stats():
     # TODO(anj-s): Add support for synthetic regression benchmarks
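For context, the golden-stats helper is now parameterless. Below is a minimal sketch of how a caller might consume it, mirroring the "within 3 standard deviations" check described in benchmarks/pipe.py; the measured value here is made up for illustration.

    from benchmarks.golden_configs import lm_wikitext2

    golden = lm_wikitext2.get_golden_real_stats()
    measured_wps = 700.0  # hypothetical throughput from a benchmark run
    # Accept the run if throughput is within 3 standard deviations of the golden average.
    assert abs(measured_wps - golden["avg_wps"]) < 3 * golden["std_dev_wps"]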
benchmarks/pipe.py

@@ -16,16 +16,13 @@ import numpy as np
 import torch
 import torch.distributed as dist
 from torch.distributed import rpc
-import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Adam

 from benchmarks.golden_configs.lm_wikitext2 import Pipe as lm_wikitext2
 from fairscale.nn import Pipe
 from fairscale.nn.model_parallel import initialize_model_parallel
-from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
-from fairscale.nn.pipe import LazyModule, MultiProcessPipe
-from fairscale.utils.testing import dist_init, get_worker_map
+from fairscale.utils.testing import dist_init

 MPI_PORT = 29500
 RPC_PORT = 29501
@@ -211,7 +208,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
         if i % log_interval == 0 and i > 0:
             cur_loss = total_loss / log_interval
             elapsed = time.time() - start_time
-            if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+            if dist.get_rank() == dist.get_world_size() - 1:
                 logging.debug(
                     "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                         i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
@@ -227,7 +224,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
         raise RuntimeError(
             "Unable to benchmark on a single batch. Increase the size " " of the dataset and rerun the benchmark."
         )
-    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+    if dist.get_rank() == dist.get_world_size() - 1:
         return wps, loss.item()
     else:
         return 0.0, 0.0
@@ -276,8 +273,7 @@ def verify_peak_memory(rank, golden_config, std_dev):
 def verify_lm_run(wps, golden_config, args):
     """Verify that words per second for a given benchmark run matches the golden data."""

-    # Verify wps only on the last rank in multiprocess pipe
-    if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
+    if dist.get_rank() == dist.get_world_size() - 1:
         # Assert that words per second is within 3 standard deviations of the average
         # of five golden runs
         logging.info("Throughput(wps) is {:.2f}.".format(wps))
@@ -289,11 +285,8 @@ def verify_lm_run(wps, golden_config, args):
             )
         )

-    if args.multiprocess:
-        verify_peak_memory(dist.get_rank(), golden_config, 1.5)
-    else:
-        for i in range(4):
-            verify_peak_memory(i, golden_config, 1.1)
+    for i in range(4):
+        verify_peak_memory(i, golden_config, 1.1)


 def benchmark_language_model(model_config, model, benchmark_config, model_specs, args):
@@ -400,7 +393,7 @@ def get_golden_config(model_name, args):
     """Return a dict with the golden data for throughput and memory usage."""
     if model_name == "lm":
-        return lm_wikitext2.get_golden_real_stats(args.multiprocess)
+        return lm_wikitext2.get_golden_real_stats()
     else:
         raise RuntimeError("Unrecognized args.model_mame " % args.model_name)
@@ -431,32 +424,6 @@ def benchmark_single_process(args):
     benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)


-def run_mp_worker(args, available_workers):
-    benchmark_config = create_benchmark_config(args.model_name)
-    model_specs = get_model_specs(args.model_name)
-    model_config = create_model_config(args, benchmark_config=benchmark_config, model_specs=model_specs)
-    model = model_config["model"]
-
-    balance = generate_balance(get_pipeline_parallel_group().size(), len(model))
-    pipe_model = MultiProcessPipe(
-        model,
-        balance,
-        chunks=args.chunks,
-        worker_map=get_worker_map(),
-        input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
-        checkpoint=args.checkpoint,
-        # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
-    )
-    if torch.cuda.is_available():
-        pipe_model = pipe_model.cuda()
-
-    if args.dry_run:
-        train(model_config, pipe_model, benchmark_config, model_specs, args)
-    else:
-        benchmark_language_model(model_config, pipe_model, benchmark_config, model_specs, args)
-
-
 def run_worker(rank, world_size, args):
     if args.world_size != 0:
         world_size = args.world_size
@@ -469,35 +436,7 @@ def run_worker(rank, world_size, args):
     torch.distributed.destroy_process_group()


-def benchmark_multiprocess(rank, world_size, args):
-
-    init_method_pgroup = "tcp://localhost:{}".format(MPI_PORT)
-    # TODO(anj-s): Add regression benchmarks for nccl as well.
-    torch.distributed.init_process_group(
-        backend="gloo", rank=rank, world_size=world_size, init_method=init_method_pgroup
-    )
-
-    torch.cuda.set_device(rank % torch.cuda.device_count())
-    # TODO(anj-s): Move to TensorPipeRpcBackendOptions.
-    rpc.init_rpc(
-        f"Test{rank}",
-        rank=rank,
-        world_size=world_size,
-        backend=rpc.BackendType.PROCESS_GROUP,
-        rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
-            rpc_timeout=20, init_method="tcp://localhost:{}".format(RPC_PORT)
-        ),
-    )
-    initialize_model_parallel(1, world_size)
-    init_random_seed(0)
-
-    run_mp_worker(args, world_size)
-
-    rpc.shutdown()
-    torch.distributed.destroy_process_group()
-
-
 parser = argparse.ArgumentParser(description="benchmark")
-parser.add_argument("--multiprocess", action="store_true", help="Runs single process benchmarks.")
 parser.add_argument("--host", "-o", type=str, default="localhost", help="hostname")
 parser.add_argument("--chunks", type=int, default=1, help="number of microbatches per batch")
 parser.add_argument("--batch-size", type=int, default=8, help="size of a batch")
@@ -522,10 +461,5 @@ if __name__ == "__main__":
     args = parser.parse_args()

     logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
-    if not args.multiprocess:
-        logging.info(f"Running single process benchmark with args: {args}")
-        benchmark_single_process(args)
-    else:
-        world_size = max(torch.cuda.device_count(), 1)
-        logging.info(f"Running multiprocess benchmark with args: {args}")
-        mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)
+    logging.info(f"Running single process benchmark with args: {args}")
+    benchmark_single_process(args)
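With the multiprocess launcher gone, the benchmark has a single entry point. A minimal sketch of invoking it the way the CI job above does, assuming the fairscale repository root as the working directory; only flags visible in this diff are used and the values are illustrative.

    import subprocess

    # Equivalent of the remaining CI step: python benchmarks/pipe.py
    subprocess.run(
        ["python", "benchmarks/pipe.py", "--chunks", "4", "--batch-size", "8", "--host", "localhost"],
        check=True,
    )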
examples/tutorial_pipe_multiprocess.py  (deleted, 100644 → 0; previous contents shown below)

import os

from helpers import dist_init, get_data, get_loss_fun, get_model

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.optim as optim

from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import MultiProcessPipe

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    dist_init(rank, world_size)
    os.environ["MASTER_PORT"] = "10639"
    dist.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    initialize_model_parallel(1, world_size)

    model = get_model()
    data, target = get_data()[0]
    loss_fn = get_loss_fun()

    device = torch.device("cuda", rank) if DEVICE == "cuda" else torch.device("cpu")

    model = MultiProcessPipe(
        model,
        balance=[2, 1],
        worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")

    dist.rpc.shutdown()

    del model


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
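With this multiprocess tutorial removed, the single-process Pipe remains. A rough sketch of the equivalent usage, following the Pipe(model, balance=..., chunks=...) pattern shown in the MultiProcessPipe docstring further below; the layers and data are placeholders, and it assumes Pipe can place the two partitions (for example, two visible CUDA devices).

    import torch
    from torch import nn
    from fairscale.nn import Pipe

    # Placeholder two-stage model: `balance` gives two layers to each partition.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4), nn.ReLU())
    pipe = Pipe(model, balance=[2, 2], chunks=4)

    x = torch.randn(16, 8)
    output = pipe(x)  # the output lives on the device of the last partition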
fairscale/nn/pipe/__init__.py

@@ -20,7 +20,6 @@
 """A Pipe implementation in PyTorch."""

 from .async_pipe import AsyncPipe
 from .checkpoint import is_checkpointing, is_recomputing
-from .multiprocess_pipe import LazyModule, MultiProcessPipe
 from .pipe import Pipe
 from .rpc import PipeRPCWrapper
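After this change the package no longer re-exports the multiprocess classes. A quick illustration of what still imports cleanly versus what is removed (LazyModule itself survives in fairscale.nn.pipe.types, which is the import path the updated tests below switch to):

    from fairscale.nn.pipe import AsyncPipe, Pipe, PipeRPCWrapper   # still exported
    from fairscale.nn.pipe import is_checkpointing, is_recomputing  # still exported
    from fairscale.nn.pipe.types import LazyModule                  # import path now used by the tests
    # from fairscale.nn.pipe import MultiProcessPipe, LazyModule    # removed by this commit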
fairscale/nn/pipe/multiprocess_pipe.py  (deleted, 100644 → 0; previous contents shown below)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The MultiProcessPipe interface."""

from collections import OrderedDict
import threading
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
import warnings

import torch
from torch import Tensor, nn
import torch.autograd
import torch.cuda

from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

from . import microbatch
from .batchnorm import DeferredBatchNorm
from .multiprocess_pipeline import MultiProcessPipeline
from .phony import get_phony
from .skip.layout import SkipLayout
from .types import LazyModule

__all__ = ["MultiProcessPipe", "LazyModule"]


Tensors = Tuple[Tensor, ...]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
    Module = nn.Module[TensorOrTensors]
    NamedModules = OrderedDict[str, Module]
else:
    Module = nn.Module
    NamedModules = OrderedDict


def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
    if len(set(map(id, module))) != len(module):
        raise ValueError("module with duplicate children is not supported")


def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
    if len(module) != sum(balance):
        raise ValueError(
            f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
        )

    if any(x <= 0 for x in balance):
        raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")


MOVING_DENIED = TypeError("denied to move parameters and buffers, because Pipe should manage device placement")


class MultiProcessPipe(Module):
    """Wraps an arbitrary :class:`nn.Sequential <torch.nn.Sequential>` module
    to train on Pipe_. If the module requires lots of memory, Pipe will be
    very efficient.
    ::

        model = nn.Sequential(a, b, c, d)
        model = Pipe(model, balance=[1, 1, 1, 1], chunks=8)
        output = model(input)

    .. _Pipe: https://arxiv.org/abs/1811.06965

    Pipe combines pipeline parallelism with checkpointing to reduce peak
    memory required to train while minimizing device under-utilization.

    You should determine the balance when defining a :class:`Pipe` module, as
    balancing will not be done automatically. The module will be partitioned
    into multiple devices according to the given balance. You may rely on
    heuristics to find your own optimal configuration.

    Args:
        module (torch.nn.Sequential):
            sequential module to be parallelized
        balance (ints):
            list of number of layers in each partition

    Keyword Args:
        group (ProcessGroup):
            the process group that all
            pipeline stages are a member of. Defaults to
            `get_pipeline_parallel_group()`
        worker_map (Dict[int, str]):
            a map from worker name (the first argument to
            `torch.distributed.rpc.init_rpc`) to global rank (i.e.
            `torch.distributed.get_rank()`) needed in order for pipeline stages
            to communicate with each other
        input_device (device):
            the device on which tensors should be located before being passed to
            the first module in a given pipeline stage
        chunks (int):
            number of micro-batches (default: ``1``)
        checkpoint (str):
            when to enable checkpointing, one of ``'always'``,
            ``'except_last'``, or ``'never'`` (default: ``'except_last'``)
        deferred_batch_norm (bool):
            whether to use deferred BatchNorm moving statistics (default:
            :data:`False`, see :class:`DeferredBatchNorm` for more
            details)

    Raises:
        TypeError:
            the module is not a :class:`nn.Sequential <torch.nn.Sequential>`.
        ValueError:
            invalid arguments, or wrong balance
        IndexError:
            the number of devices is fewer than the number of partitions.
    """

    #: The number of layers in each partition.
    balance: List[int] = []
    #                    ^^
    # The default value [] required for Sphinx's autoattribute.

    #: The devices mapped to each partition.
    #:
    #: ``devices[-1]`` refers to the device of the last partition, which means
    #: it is the output device. Probably, you need to use it to transfer the
    #: target to calculate the loss without a device mismatch
    #: :exc:`RuntimeError`. For example::
    #:
    #:     out_device = pipe.devices[-1]
    #:
    #:     for input, target in loader:
    #:         target = target.to(out_device, non_blocking=True)
    #:         output = pipe(input)
    #:         loss = F.cross_entropy(output, target)
    #:

    #: The number of micro-batches.
    chunks: int = 1

    #: The checkpoint mode to determine when to enable checkpointing. It is one
    #: of ``'always'``, ``'except_last'``, or ``'never'``.
    checkpoint: str = "except_last"

    def __init__(
        self,
        module: Union[nn.Sequential, List[LazyModule]],
        balance: Iterable[int],
        *,
        group: Optional[torch.distributed.ProcessGroup] = None,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

        if get_model_parallel_world_size() > 1:
            self.pipelined_backward = True
        else:
            self.pipelined_backward = False

        self.balance = list(balance)
        verify_module(module)
        check_balance(module, self.balance)

        self.chunks = chunks
        self.checkpoint = checkpoint
        self.pipeline: Optional[MultiProcessPipeline]
        self.lock = threading.Lock()

        self.worker_map = worker_map
        self.input_device = input_device

        self.group: torch.distributed.ProcessGroup
        if group is None:
            self.group = get_pipeline_parallel_group()
        else:
            self.group = group

        if self.group.size() < len(self.balance):
            raise IndexError(
                f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                f" {len(self.balance)})"
            )

        self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)

        rank = self.group.rank()
        self.final_stage = rank == len(self.balance) - 1
        if rank >= len(self.balance):
            warnings.warn("More ranks than partitions, some ranks unused")
            self.partition = nn.Sequential()
            self.pipeline = None
        else:
            self.partition = self.instantiate_partition(module, self.balance, self.group)
            if deferred_batch_norm:
                self.partitition = DeferredBatchNorm.convert_deferred_batch_norm(self.partition, chunks)
            self.add_module(str(0), self.partition)
            self.create_pipeline()

        del module

    def create_pipeline(self) -> None:
        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

        self.pipeline = MultiProcessPipeline(
            self.partition,
            self._skip_layout,
            checkpoint_stop,
            group=self.group,
            worker_map=self.worker_map,
            input_device=self.input_device,
            final_stage=self.final_stage,
        )

    def instantiate_partition(
        self,
        module: Union[nn.Sequential, List[LazyModule]],
        balance: List[int],
        group: torch.distributed.ProcessGroup,
    ) -> nn.Sequential:
        rank = group.rank()
        first_layer = sum(balance[:rank])
        num_layers = balance[rank]
        layers = module[first_layer : first_layer + num_layers]
        instantiated_layers = [l if isinstance(l, nn.Module) else l() for l in layers]
        return nn.Sequential(*instantiated_layers)

    def __len__(self) -> int:
        """Counts the length of the underlying sequential module."""
        return self.partition.__len__()

    def __getitem__(self, index: int) -> nn.Module:
        """Gets a layer in the underlying sequential module."""
        return self.partition.__getitem__(index)

    def __iter__(self) -> Iterable[nn.Module]:
        """Iterates over children of the underlying sequential module."""
        return self.partition.__iter__()

    def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """:class:`MultiProcessPipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
        applied at partition boundaries too.

        Args:
            input (torch.Tensor or tensors): input mini-batch

        Returns:
            tensor or tensors: output mini-batch

        Raises:
            TypeError: input is not a tensor or tensors.
        """
        microbatch.check(input)

        if not self.pipeline:
            # No pipeline is not illegal, more ranks than partitions
            return input

        # Divide a mini-batch into micro-batches.
        batches = microbatch.scatter(input, self.chunks)

        # Run pipeline parallelism.
        with self.lock:
            self.pipeline.run(self.training, batches)

            if self.final_stage:
                # Merge the micro-batches into one mini-batch.
                if self.pipelined_backward:
                    with torch.no_grad():
                        output = microbatch.gather(batches)

                    phony = get_phony(
                        torch.device(torch.cuda.current_device() if torch.cuda.is_available() else "cpu"),
                        requires_grad=True,
                    )
                    output = PipelinedBackwardPass.apply(output, batches, phony)
                else:
                    output = microbatch.gather(batches)
            else:
                # Don't merge micro-batches to avoid unnecessary edges in autograd
                # graph
                # FIXME(tom) should figure out a proper type here
                output = batches  # type: ignore

            return output

    def back_helper(self, output: List[microbatch.Batch]) -> None:
        if self.final_stage:
            raise ValueError("back_helper should only be called on non-final stages")

        if self.pipeline:
            self.pipeline.back_helper(output)


class PipelinedBackwardPass(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(ctx, input: TensorOrTensors, batches, phony) -> TensorOrTensors:
        ctx.batches = batches
        return input

    @staticmethod
    # type: ignore
    def backward(ctx, *grads) -> Tuple:
        with torch.no_grad():
            grad_batches = microbatch.scatter(grads, len(ctx.batches))

        for grad, batch in reversed(list(zip(grad_batches, ctx.batches))):
            for t in batch:
                t.retain_grad()
            torch.autograd.backward(batch.tensor_or_tensors, grad_tensors=(*grad,), retain_graph=True)

        with torch.no_grad():
            if ctx.batches[0].atomic:
                tensors = tuple(b.tensor.grad for b in ctx.batches)
                output: TensorOrTensors = torch.cat(tensors)
            else:
                rotated = [[t.grad for t in b.tensors] for b in ctx.batches]
                output_buf = []
                for tensors in zip(*rotated):
                    output_buf.append(torch.cat(tensors))
                output = tuple(output_buf)

        del ctx.batches

        return (output, None, None, None)
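As an aside, the core of the removed forward() above is the scatter/run/gather micro-batching pattern. The self-contained sketch below is not the fairscale API; it only mimics what microbatch.scatter and microbatch.gather did around a single local partition.

    import torch
    from torch import nn

    def scatter(batch: torch.Tensor, chunks: int):
        # Split a mini-batch into `chunks` micro-batches along the batch dimension.
        return list(batch.chunk(chunks, dim=0))

    def gather(micro_outputs):
        # Merge micro-batch outputs back into one mini-batch.
        return torch.cat(micro_outputs, dim=0)

    partition = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # one pipeline stage's layers
    x = torch.randn(12, 8)
    outputs = [partition(mb) for mb in scatter(x, chunks=4)]  # stand-in for pipeline.run()
    y = gather(outputs)  # only the final stage gathered in the removed implementation
    assert y.shape == (12, 8)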
fairscale/nn/pipe/multiprocess_pipeline.py  (deleted, 100644 → 0; previous contents shown below)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The multiprocess pipeline parallelism of Pipe."""

import os
from queue import Empty as QueueEmpty
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import Tensor, nn
from torch.autograd.profiler import record_function

from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

from .checkpoint import Checkpointing
from .messages import MakeTransport, Transport
from .microbatch import Batch
from .skip import Namespace
from .skip.layout import SkipLayout
from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker
from .types import ACTIVATIONS_GRADS_QUEUE, PORTAL_QUEUE, SKIP_TENSOR_QUEUE, PipeMessage, TensorOrTensors, Tensors
from .worker import Task

# Queue is generic only in stubs.
# https://mypy.readthedocs.io/en/latest/common_issues.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime
if TYPE_CHECKING:
    InQueue = Queue[Optional[Task]]
    OutQueue = Queue[Tuple[bool, Union[Tuple[Task, Batch], ExcInfo, None]]]
else:
    InQueue = Queue
    OutQueue = Queue

__all__: List[str] = []

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]


class SendOperator(torch.autograd.Function):
    """Send activations to the next pipeline stage"""

    @staticmethod
    # type: ignore
    def forward(ctx, transport: Transport, input: List[Tensor], index: int) -> Tensors:
        ranks = get_pipeline_parallel_ranks()
        src_rank = torch.distributed.get_rank()
        dst_rank = ranks[ranks.index(src_rank) + 1]

        transport.send_message(
            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)),
        )
        return ()

    @staticmethod
    # type: ignore
    def backward(ctx, *grad: Tensor,) -> Tensors:
        return tuple(grad)


class RecvOperator(torch.autograd.Function):
    """Receive activations to the previous pipeline stage"""

    @staticmethod
    # type: ignore
    def forward(ctx, tensor: Tensor, transport: Transport, index: int) -> Tensors:
        ctx.transport = transport
        ctx.index = index

        result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index)

        def maybe_requires_grad(t: Tensor) -> Tensor:
            if t.dtype.is_floating_point:
                return t.requires_grad_()
            return t

        return tuple(maybe_requires_grad(r) for r in result)

    @staticmethod
    # type: ignore
    def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]:
        ranks = get_pipeline_parallel_ranks()
        src_rank = torch.distributed.get_rank()
        dst_rank = ranks[ranks.index(src_rank) - 1]
        ctx.transport.send_message(
            PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=ctx.index, tensors=tuple(grad),),
        )
        return (None, None, None, None)


class MultiProcessPipeline:
    """The multiprocess pipeline parallelism for Pipe."""

    def __init__(
        self,
        partition: nn.Sequential,
        skip_layout: SkipLayout,
        checkpoint_stop: int,
        group: torch.distributed.ProcessGroup,
        *,
        worker_map: Optional[Dict[int, str]] = None,
        input_device: Union[None, int, str, torch.device] = None,
        final_stage: bool = False,
    ) -> None:
        self.partition = partition
        self.skip_layout = skip_layout
        self.__checkpoint_stop = checkpoint_stop
        self.group = group
        self.training: bool
        self.transport = MakeTransport(
            use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ),
            worker_map=worker_map,
            input_device=input_device,
        )
        self.input_device = input_device
        self.final_stage = final_stage

    @property
    def checkpoint_stop(self) -> int:
        # Disable checkpointing if in eval mode.
        training = self.partition.training
        if not training:
            return 0
        return self.__checkpoint_stop

    def run(self, training: bool, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.
        """
        self.training = training

        m = len(batches)
        skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(m)]

        rank = self.group.rank()

        for i in range(m):
            if rank != 0:
                batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
            else:
                batch = batches[i]

            with use_skip_tracker(skip_trackers[i]), record_function("chunk%d-part%d" % (i, rank)):
                if i < self.checkpoint_stop:
                    chk = Checkpointing(self.partition, batch)
                    batch = chk.checkpoint()
                else:
                    batch = batch.call(self.partition)

                if not self.final_stage:
                    self.send_skip_tensors(batch, i, skip_trackers)
                    SendOperator.apply(self.transport, [*batch], i)

                for portal in skip_trackers[i].portals.values():
                    portal.pipeline = self

                if i < self.checkpoint_stop:
                    chk.recompute(batch)

            batches[i] = batch

    def get_batch_from_previous_stage(
        self, i: int, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]
    ) -> Batch:
        phony = torch.empty(0, device=self.input_device, requires_grad=True)
        result = RecvOperator.apply(phony, self.transport, i)
        if len(result) == 1:
            batch = Batch(result[0], i)
        else:
            batch = Batch(result, i)

        self.recv_skip_tensors(skip_trackers, batches)

        return batch

    def send_skip_tensors(self, batch: Batch, i: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None:
        ranks = get_pipeline_parallel_ranks()
        this_rank = torch.distributed.get_rank()

        for next_j, ns, name in self.skip_layout.copy_policy_by_src(self.group.rank()):
            life = skip_trackers[i].portals[(ns, name)].tensor_life
            loaded = skip_trackers[i].load(batch, ns, name)
            if loaded is not None:
                tensors = tuple([loaded])
            else:
                tensors = tuple()

            self.transport.send_message(
                PipeMessage(
                    this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors,
                ),
                sync=True,
            )

    def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None:
        while True:
            try:
                message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True)
                (si, ns, name, life) = message.args
                value: Optional[TensorOrTensors] = message.tensors

                assert isinstance(value, tuple)

                if len(value) == 0:
                    value = None
                else:
                    assert len(value) == 1
                    value = value[0]

                skip_trackers[si].save(batches[si], ns, name, value)
                old_life = skip_trackers[si].portals[(ns, name)].tensor_life
                if life != 0:
                    skip_trackers[si].portals[(ns, name)].tensor_life = life
            except QueueEmpty:
                break

    def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: TensorOrTensors) -> None:
        dest, src = self.skip_layout.by_ns_name.get(ns_name, (-1, -1))
        if dest == src:
            return
        ranks = get_pipeline_parallel_ranks()
        dst_rank = ranks[dest]
        if dst_rank == torch.distributed.get_rank():
            return

        if isinstance(grad, Tensor):
            grad = tuple([grad])

        self.transport.send_message(
            PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad),
            sync=True,
        )

    def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor:
        message = self.transport.recv_message(PORTAL_QUEUE)

        (ns_name, index) = message.args
        grad = message.tensors

        assert len(grad) == 1
        result = grad[0]
        assert index == expected_index and ns_name == expected_ns_name
        return result

    def back_helper(self, output: List[Batch]) -> None:
        tensors: Tensors

        rank = torch.distributed.get_rank()

        for batch in reversed(output):
            found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index)
            if batch.atomic:
                tensors = tuple([batch.tensor])
            else:
                tensors = batch.tensors

            if len(found) != len(tensors):
                raise RuntimeError("different number of tensors and gradients")

            grads = []
            final_tensors = []
            for i, tensor in enumerate(tensors):
                if tensor.requires_grad or getattr(tensor, "grad_fn", None) is not None:
                    grads.append(found[i])
                    final_tensors.append(tensor)

            try:
                torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True)
            except Exception as e:
                raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e
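The Send/Recv operator pair above is the interesting part of the removed pipeline: activations flow between ranks in forward() and the matching gradients flow back in backward(). Below is a simplified sketch of that idea using torch.distributed point-to-point calls in place of the removed PipeMessage transport; it is illustrative only and will actually run only inside an initialized process group with a downstream rank.

    import torch
    import torch.distributed as dist

    class SendActivation(torch.autograd.Function):
        """Push an activation to the next stage; pull its gradient back in backward."""

        @staticmethod
        def forward(ctx, tensor, dst_rank):
            ctx.dst_rank = dst_rank
            ctx.shape, ctx.dtype = tensor.shape, tensor.dtype
            dist.send(tensor.contiguous(), dst_rank)  # forward: activation goes downstream
            # Return a tiny "phony" tensor so autograd keeps a node for backward().
            return torch.empty(0, requires_grad=True)

        @staticmethod
        def backward(ctx, *unused):
            grad = torch.empty(ctx.shape, dtype=ctx.dtype)
            dist.recv(grad, ctx.dst_rank)  # backward: gradient comes back upstream
            return grad, None

    # The removed RecvOperator mirrored this on the receiving rank: receive the
    # activation in forward(), send the incoming gradient to the previous rank
    # in backward().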
tests/nn/model_parallel/test_layers.py

@@ -20,19 +20,15 @@
 # limitations under the License.

 import os
-import tempfile

 import pytest
 import torch
-from torch import nn
-from torch.distributed import rpc
 import torch.nn.init as init
 from torch.nn.parameter import Parameter

 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel import layers
-from fairscale.nn.pipe import MultiProcessPipe
-from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
+from fairscale.utils.testing import dist_init, set_random_seed, spawn_for_all_world_sizes


 def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
@@ -302,241 +298,6 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename, filename_r
     print(" >> passed the test :-)")


-def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
-    pipe_world_size = 2
-
-    if world_size == 1:
-        return
-
-    if not skip_dist_init:
-        dist_init(rank, world_size, filename, filename_rpc)
-    else:
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "29502"
-        rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
-
-    mpu.initialize_model_parallel(world_size / pipe_world_size, pipe_world_size)
-    model_parallel_size = mpu.get_model_parallel_world_size()
-    if torch.distributed.get_rank() == 0:
-        print(
-            "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
-                model_parallel_size, pipe_world_size
-            )
-        )
-    chunk_size = 4
-
-    seed = 12345
-    set_random_seed(seed)
-    input_size_coeff = 3
-    input_size = input_size_coeff * model_parallel_size
-    output_size_coeff = 7
-    output_size = output_size_coeff * model_parallel_size
-    batch_size = 3 * chunk_size
-
-    target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
-    print(f"target = {target}")
-
-    identity = IdentityLayer2D(batch_size, input_size).cuda()
-
-    pipeline_devices = mpu.get_pipeline_parallel_group()
-
-    set_random_seed(seed)
-    model = nn.Sequential(
-        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
-        nn.ReLU(),
-        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
-    )
-    set_random_seed(seed)
-
-    reference = [
-        nn.Linear(input_size, output_size, bias=False).cuda(),
-        nn.ReLU(),
-        nn.Linear(output_size, input_size, bias=False).cuda(),
-    ]
-
-    print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
-    print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")
-
-    reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
-    reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()
-
-    reference = nn.Sequential(*reference)
-
-    def grad_graph(depth, grad):
-        result = depth * " " + str(grad)
-        if grad:
-            for x in grad.next_functions:
-                result += "\n" + grad_graph(depth + 1, x[0])
-        return result
-
-    def check_weights(x, y, key: str, index=None):
-        for i in [2, 0]:
-            if index is not None and i != index:
-                continue
-            left = x[i].get_master_weight()
-            right = y[i].weight.data
-            if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
-                print(f"check_weights {key}-{i}: left = {left},\nright = {right}")
-            if not torch.equal(left, right):
-                print(f"check_weights NOT_EQUAL {key}-{i}: left = {left},\nright = {right}")
-            assert torch.allclose(left, right, atol=1.0e-6)
-
-    def dump_opt_params(opt):
-        for i, group in enumerate(opt.param_groups):
-            for j, p in enumerate(group["params"]):
-                print(f"{torch.distributed.get_rank()}:param {(i, j)} = {p}")
-                print(f"{torch.distributed.get_rank()}:param.grad {(i, j)} = {p.grad}")
-
-    def forward_model(model_, target, step=False):
-        optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
-        optimizer.zero_grad()
-        model_.zero_grad()
-        output = model_(identity())
-        loss = nn.MSELoss()
-        model_.zero_grad()
-        if step:
-            loss(output, target).backward()
-            saved_weight_0 = model_[0].weight.data.clone()
-            saved_weight_2 = model_[2].weight.data.clone()
-            dump_opt_params(optimizer)
-            optimizer.step()
-            assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
-            assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
-        return output
-
-    output = forward_model(model, target)
-    reference_output = forward_model(reference, target)
-
-    error = reference_output.sub(output).max()
-    torch.distributed.barrier()
-    assert error < 1.0e-6
-
-    output = forward_model(model, target)
-    error = reference_output.sub(output).max()
-    torch.distributed.barrier()
-    assert error < 1.0e-6
-
-    output = forward_model(model, target)
-    error = reference_output.sub(output).max()
-    torch.distributed.barrier()
-    assert error < 1.0e-6
-
-    check_weights(model, reference, "before")
-    saved_weight_0 = model[0].weight.data.clone()
-    saved_weight_2 = model[2].weight.data.clone()
-    output = forward_model(model, target, step=True)
-    error = reference_output.sub(output).max()
-    assert error < 1.0e-6
-
-    model[0].weight.data = saved_weight_0
-    model[2].weight.data = saved_weight_2
-
-    worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}
-
-    if pipe_world_size == 2:
-        print("actually doing pipe stuff now")
-        assert torch.equal(saved_weight_0, model[0].weight.data)
-        assert torch.equal(saved_weight_2, model[2].weight.data)
-        pipe_model = MultiProcessPipe(
-            model,
-            [2, 1],
-            group=pipeline_devices,
-            worker_map=worker_map,
-            input_device=torch.cuda.current_device(),
-            chunks=chunk_size,
-        ).cuda()
-        torch.distributed.barrier()
-        pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
-        print(f"pipe rank is {pipe_rank}")
-        if pipe_rank == 0:
-            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
-        else:
-            if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
-                print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
-            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
-        optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
-        optimizer.zero_grad()
-        if pipe_rank == 0:
-            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
-            print(f"runner {rank}:\n{pipe_model[0].weight.data}")
-        else:
-            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
-            print(f"runner {rank}:\n{pipe_model[0].weight.data}")
-
-        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
-            check_weights(model, reference, "pre-pipe", index=2)
-        else:
-            check_weights(model, reference, "pre-pipe", index=0)
-
-        pipe_output = pipe_model(identity())
-        print(f"exited pipe for {rank}")
-        forward_model(reference, target, step=True)
-
-        print(f"pipe_output {rank} = {pipe_output}")
-        print(f"reference_output {rank} = {reference_output}")
-
-        torch.distributed.barrier()
-
-        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
-            error = reference_output.sub(pipe_output.cuda()).max()
-            if error >= 1.0e-6:
-                print(f"error bad {error}")
-            assert error < 1.0e-6
-
-            loss = nn.MSELoss()
-            failed = False
-            pipe_output.retain_grad()
-            with torch.autograd.profiler.profile() as prof:
-                try:
-                    loss(pipe_output, target).backward()
-                except Exception as e:
-                    failed = True
-                    print(f"got {e} while doing backward, deadlock?")
-            if failed:
-                raise RuntimeError("failed somehow")
-            dump_opt_params(optimizer)
-            optimizer.step()
-
-            print("calling check_weights on master")
-            check_weights(model, reference, "pipe", index=2)
-            print(f"waiting for barrier on master, pid={os.getpid()}")
-        else:
-            print(f"calling backwards on slave, pid={os.getpid()}")
-            failed = False
-            with torch.autograd.profiler.profile() as prof:
-                try:
-                    pipe_model.back_helper(pipe_output)
-                except Exception as e:
-                    failed = True
-                    print(f"got {e} while doing backward, deadlock?")
-            if failed:
-                raise RuntimeError("failed somehow")
-            dump_opt_params(optimizer)
-            print("calling step on slave")
-            optimizer.step()
-            print("calling check_weights on slave")
-            check_weights(model, reference, "pipe", index=0)
-            print("waiting for barrier on slave")
-
-        pipe_model.zero_grad()
-        torch.distributed.barrier()
-
-        pipe_model.eval()
-        pipe_output = pipe_model(identity())
-        updated_ref_output = forward_model(reference, target)
-        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
-            error = updated_ref_output.sub(pipe_output.cuda()).max()
-            print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
-            assert error < 1.0e-6
-        torch.distributed.barrier()
-
-        print(f"finished waiting for barrier on, pid={os.getpid()}")
-
-    print(f"really exited pipe for {rank}")
-
-    rpc.shutdown()
-    torch.distributed.destroy_process_group()
-
-
 torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
@@ -556,35 +317,3 @@ def test_column_parallel():
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
 def test_row_parallel():
     spawn_for_all_world_sizes(run_test_row_parallel_linear)
-
-
-@torch_spawn([2])
-@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="only works on mpi")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def mpi_pipe():
-    mpu.destroy_model_parallel()
-
-    _, tempfile_init = tempfile.mkstemp()
-    _, tempfile_rpc_init = tempfile.mkstemp()
-
-    run_test_pipe(
-        torch.distributed.get_rank(),
-        torch.distributed.get_world_size(),
-        tempfile_init,
-        tempfile_rpc_init,
-        skip_dist_init=True,
-    )
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def test_pipe_layer():
-    world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
-
-    spawn_for_all_world_sizes(run_test_pipe, args=[False])
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.skip(reason="potential deadlock in nccl with multiple processes using the same gpu")
-def test_eight_pipe_layer():
-    world_sizes = [x for x in get_world_sizes() if x <= torch.cuda.device_count() / 2]
-
-    spawn_for_all_world_sizes(run_test_pipe, [8])
tests/nn/pipe_process/test_bugs.py

@@ -22,13 +22,13 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def python_autograd_function(pipe_class):
     # FIXME deadlock with AsyncPipe?
     # A Python autograd function might fail with this error:
@@ -71,7 +71,7 @@ def python_autograd_function(pipe_class):

 @torch_spawn([3])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def exception_no_hang(pipe_class):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
@@ -104,7 +104,7 @@ def exception_no_hang(pipe_class):

 @torch_spawn([2])
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def tuple_wait(cuda_sleep, pipe_class):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
@@ -157,7 +157,7 @@ def tuple_wait(cuda_sleep, pipe_class):

 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def parallel_randoms(pipe_class):
     class Dropouts(nn.Module):
         def forward(self, x):
tests/nn/pipe_process/test_inplace.py

@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe
 from fairscale.utils.testing import get_worker_map, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def inplace_on_requires_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
     model = pipe_class(model, [1, 1], worker_map=get_worker_map(), checkpoint="always")
@@ -50,7 +50,7 @@ def inplace_on_requires_grad(pipe_class):

 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def inplace_on_not_requires_grad(pipe_class):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
@@ -70,7 +70,7 @@ def inplace_on_not_requires_grad(pipe_class):

 @torch_spawn([1])
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def inplace_incorrect_grad(pipe_class):
     class M(nn.Module):
         def forward(self, foo_bar):
tests/nn/pipe_process/test_pipe.py
View file @
2d3d5a7b
...
@@ -26,17 +26,14 @@ import pytest
...
@@ -26,17 +26,14 @@ import pytest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
fairscale.nn.model_parallel.initialize
import
(
from
fairscale.nn.model_parallel.initialize
import
get_pipeline_parallel_group
destroy_model_parallel
,
from
fairscale.nn.pipe
import
AsyncPipe
get_pipeline_parallel_group
,
from
fairscale.nn.pipe.types
import
LazyModule
initialize_model_parallel
,
)
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
,
torch_version
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
,
torch_version
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
AsyncPipe
])
def
parameters
(
pipe_class
):
def
parameters
(
pipe_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
pipe
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
)
pipe
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
)
...
@@ -107,7 +104,7 @@ def mpi():
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def public_attrs(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
...
@@ -122,7 +119,7 @@ def public_attrs(pipe_class):
 @torch_spawn([2])
 @pytest.mark.parametrize("balance", [[2], [1, 1]])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def sequential_like(balance, pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -161,7 +158,7 @@ def sequential_like(balance, pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def balance_wrong_length(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -176,7 +173,7 @@ def balance_wrong_length(pipe_class):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def balance_less_than_1(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -191,7 +188,7 @@ def balance_less_than_1(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def chunks_less_than_1(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
...
@@ -203,7 +200,7 @@ def chunks_less_than_1(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def too_few_devices(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1))
...
@@ -213,7 +210,7 @@ def too_few_devices(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def batch_size_indivisible(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
...
@@ -226,7 +223,7 @@ def batch_size_indivisible(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def batch_size_small(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=4)
...
@@ -239,7 +236,7 @@ def batch_size_small(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def checkpoint_mode(pipe_class):
     def count_grad_fn(grad_fn, name, visited=set()):
         if grad_fn in visited:
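The count_grad_fn helper above walks an output tensor's grad_fn graph to count checkpoint-related nodes for each checkpoint mode. A rough standalone sketch of that traversal idea in plain PyTorch (all names here are invented for the example and independent of fairscale):

# Illustrative only: tally autograd-graph nodes by type name, the way
# count_grad_fn tallies checkpoint nodes in the test above.
from collections import Counter

import torch
from torch import nn


def count_nodes(grad_fn, counts=None, visited=None):
    counts = Counter() if counts is None else counts
    visited = set() if visited is None else visited
    if grad_fn is None or grad_fn in visited:
        return counts
    visited.add(grad_fn)
    counts[type(grad_fn).__name__] += 1
    for next_fn, _ in grad_fn.next_functions:
        count_nodes(next_fn, counts, visited)
    return counts


model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
out = model(torch.rand(2, 1)).sum()
print(count_nodes(out.grad_fn))  # e.g. counts of AddmmBackward0, AccumulateGrad, ...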
...
@@ -273,7 +270,7 @@ def checkpoint_mode(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def checkpoint_mode_invalid(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
...
@@ -284,7 +281,7 @@ def checkpoint_mode_invalid(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def checkpoint_mode_when_chunks_1(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
...
@@ -297,7 +294,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def checkpoint_eval(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
...
@@ -326,7 +323,7 @@ def checkpoint_eval(pipe_class):
 @torch_spawn([2])
 @pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def checkpoint_non_float_input(pipe_class):
     class ForkNonFloat(nn.Module):
         def forward(self, input):
...
@@ -344,14 +341,12 @@ def checkpoint_non_float_input(pipe_class):
     if model.group.rank() == 1:
         # with torch.autograd.detect_anomaly():
         output.backward()
-    elif pipe_class == MultiProcessPipe:
-        model.back_helper(output)

     torch.distributed.barrier()


 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def no_grad(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2)
...
@@ -376,7 +371,7 @@ def no_grad(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def exception(pipe_class):
     class ExpectedException(Exception):
         pass
...
@@ -396,7 +391,7 @@ def exception(pipe_class):
 @torch_spawn([4])
 @pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def exception_early_stop_asap(pipe_class):
     """Even the first partitions have finished to process, the partition before
     the failed partition hould be killed as soon as possible.
...
@@ -435,7 +430,7 @@ def exception_early_stop_asap(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def input_pair(pipe_class):
     class Two(nn.Module):
         def __init__(self):
...
@@ -462,7 +457,7 @@ def input_pair(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def input_singleton(pipe_class):
     class One(nn.Module):
         def __init__(self):
...
@@ -487,7 +482,7 @@ def input_singleton(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def input_varargs(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     model = pipe_class(model, balance=[1], worker_map=get_worker_map())
...
@@ -501,7 +496,7 @@ def input_varargs(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def non_tensor(pipe_class):
     class NonTensor(nn.Module):
         def forward(self, _):
...
@@ -521,7 +516,7 @@ def non_tensor(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def non_tensor_tuple(pipe_class):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
...
@@ -543,7 +538,7 @@ def non_tensor_tuple(pipe_class):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def deferred_batch_norm(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
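deferred_batch_norm above compares a BatchNorm2d instance run through the pipe against an untouched deepcopy. A minimal, pipe-free sketch of that running-statistics comparison (batch shapes and values are made up purely for illustration):

# Illustrative only: a BatchNorm2d and its deepcopy fed identical batches
# should end up with matching running statistics; the real test performs
# the equivalent check across the pipe wrapper.
from copy import deepcopy

import torch
from torch import nn

bn = nn.BatchNorm2d(3)
bn_copy = deepcopy(bn)

x = torch.rand(4, 3, 10, 10)
bn(x)
bn_copy(x)

assert torch.allclose(bn.running_mean, bn_copy.running_mean)
assert torch.allclose(bn.running_var, bn_copy.running_var)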
...
@@ -567,7 +562,7 @@ def deferred_batch_norm(checkpoint, lazy, pipe_class):
 @torch_spawn([1])
 @pytest.mark.parametrize("checkpoint", ["never", "always"])
 @pytest.mark.parametrize("lazy", [True, False])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
     bn = nn.BatchNorm2d(3)
     pipe_bn = deepcopy(bn)
...
@@ -592,7 +587,7 @@ def deferred_batch_norm_params(checkpoint, lazy, pipe_class):
 @torch_spawn([4])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def devices(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -608,7 +603,7 @@ def devices(pipe_class):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def partitions(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -626,7 +621,7 @@ def partitions(pipe_class):
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def deny_moving(pipe_class):
     a = nn.Linear(1, 1)
     b = nn.Linear(1, 1)
...
@@ -650,7 +645,7 @@ def deny_moving(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def empty_module(pipe_class):
     # Empty sequential module is not illegal.
     model = nn.Sequential()
...
@@ -666,7 +661,7 @@ def empty_module(pipe_class):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 @pytest.mark.skip(reason="TODO(msb) handle named_children")
 def named_children(pipe_class):
     a = nn.Linear(1, 1)
...
@@ -688,7 +683,7 @@ def named_children(pipe_class):
 @torch_spawn([1])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def recommend_auto_balance(pipe_class):
     with pytest.raises(ValueError):
         # module and sum of balance have differen length (module: 0, sum of balance: 1)
...
@@ -700,7 +695,7 @@ def recommend_auto_balance(pipe_class):
 @torch_spawn([2])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def lazy_construction(pipe_class):
     init_count = 0
...
@@ -730,7 +725,7 @@ def lazy_construction(pipe_class):
 @torch_spawn([2])
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def missing_worker_map(pipe_class):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
...
@@ -740,7 +735,7 @@ def missing_worker_map(pipe_class):
 @torch_spawn([2])
 @pytest.mark.skip(reason="currently broken")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
     class Surrogate(nn.Module):
         def __init__(self, module):
...
@@ -755,24 +750,6 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
         pipe_class(model, [1, 1], worker_map=get_worker_map())


-@torch_spawn([4])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe])
-def pipelined_backward(pipe_class):
-    model = nn.Sequential(nn.ReLU(), nn.ReLU())
-
-    destroy_model_parallel()
-    initialize_model_parallel(1, 4)
-    pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
-
-    assert pipe.pipelined_backward is False
-
-    destroy_model_parallel()
-    initialize_model_parallel(2, 2)
-    pipe = pipe_class(model, [1, 1], worker_map=get_worker_map())
-
-    assert pipe.pipelined_backward is True
-
 @torch_spawn([4])
 def async_event_loop():
...
tests/nn/pipe_process/test_transparency.py
View file @ 2d3d5a7b
...
@@ -21,13 +21,13 @@ import pytest
 import torch
 from torch import nn

-from fairscale.nn.pipe import AsyncPipe, MultiProcessPipe
+from fairscale.nn.pipe import AsyncPipe
 from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn


 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [AsyncPipe])
 def simple_linears(pipe_class):
     def sum_grad(parameters):
         return sum([p.grad.sum() for p in parameters if p.grad is not None])
...
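The sum_grad helper in simple_linears collapses every parameter gradient into a single scalar so the plain and pipelined models can be compared with one numeric check. A fairscale-free sketch of the same idea applied to a model and its deepcopy (module shapes and inputs here are illustrative only):

# Illustrative only: aggregate-gradient comparison in the style of sum_grad.
from copy import deepcopy

import torch
from torch import nn


def sum_grad(parameters):
    return sum([p.grad.sum() for p in parameters if p.grad is not None])


torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
candidate = deepcopy(reference)

x = torch.rand(8, 4)
reference(x).sum().backward()
candidate(x).sum().backward()

# identical models fed identical data should accumulate identical gradients
assert torch.allclose(sum_grad(reference.parameters()), sum_grad(candidate.parameters()))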