Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
42e44149
Unverified
Commit
42e44149
authored
Feb 04, 2021
by
msbaines
Committed by
GitHub
Feb 04, 2021
Browse files
[refactor] multiprocess_pipe: remove pipelined_backward (#362)
parent
7fdd7ecf
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
17 additions
and
35 deletions
+17
-35
benchmarks/experimental_ampnet.py
benchmarks/experimental_ampnet.py
+0
-1
benchmarks/pipe.py
benchmarks/pipe.py
+0
-2
fairscale/nn/pipe/async_pipe.py
fairscale/nn/pipe/async_pipe.py
+4
-0
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+5
-15
tests/nn/model_parallel/test_layers.py
tests/nn/model_parallel/test_layers.py
+0
-1
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+8
-16
No files found.
benchmarks/experimental_ampnet.py
View file @
42e44149
...
@@ -423,7 +423,6 @@ def run_mp_worker(args, available_workers):
...
@@ -423,7 +423,6 @@ def run_mp_worker(args, available_workers):
chunks
=
args
.
chunks
,
chunks
=
args
.
chunks
,
worker_map
=
get_worker_map
(),
worker_map
=
get_worker_map
(),
input_device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
),
input_device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
),
pipelined_backward
=
False
,
checkpoint
=
args
.
checkpoint
,
checkpoint
=
args
.
checkpoint
,
)
)
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
...
...
benchmarks/pipe.py
View file @
42e44149
...
@@ -523,7 +523,6 @@ def run_mp_worker(args, available_workers):
...
@@ -523,7 +523,6 @@ def run_mp_worker(args, available_workers):
chunks
=
args
.
chunks
,
chunks
=
args
.
chunks
,
worker_map
=
get_worker_map
(),
worker_map
=
get_worker_map
(),
input_device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
),
input_device
=
torch
.
device
(
"cuda"
)
if
torch
.
cuda
.
is_available
()
else
torch
.
device
(
"cpu"
),
pipelined_backward
=
args
.
pipelined_backward
,
checkpoint
=
args
.
checkpoint
,
checkpoint
=
args
.
checkpoint
,
# TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
# TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
)
)
...
@@ -592,7 +591,6 @@ parser.add_argument(
...
@@ -592,7 +591,6 @@ parser.add_argument(
parser
.
add_argument
(
parser
.
add_argument
(
"--checkpoint"
,
default
=
"never"
,
choices
=
[
"always"
,
"except_last"
,
"never"
],
help
=
"Checkpointing strategy for pipe"
"--checkpoint"
,
default
=
"never"
,
choices
=
[
"always"
,
"except_last"
,
"never"
],
help
=
"Checkpointing strategy for pipe"
)
)
parser
.
add_argument
(
"--pipelined-backward"
,
action
=
"store_true"
,
help
=
"Pipelined backward pass"
)
parser
.
add_argument
(
"--use_synthetic_data"
,
action
=
"store_true"
,
help
=
"Uses synthetic data for running benchmarks."
)
parser
.
add_argument
(
"--use_synthetic_data"
,
action
=
"store_true"
,
help
=
"Uses synthetic data for running benchmarks."
)
parser
.
add_argument
(
"--dry_run"
,
action
=
"store_true"
,
help
=
"Run a sample training run without regression testing."
)
parser
.
add_argument
(
"--dry_run"
,
action
=
"store_true"
,
help
=
"Run a sample training run without regression testing."
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
fairscale/nn/pipe/async_pipe.py
View file @
42e44149
...
@@ -39,6 +39,10 @@ class PartitionInfo:
...
@@ -39,6 +39,10 @@ class PartitionInfo:
class
AsyncPipe
(
MultiProcessPipe
):
class
AsyncPipe
(
MultiProcessPipe
):
def
__init__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
pipelined_backward
=
False
def
create_pipeline
(
self
)
->
None
:
def
create_pipeline
(
self
)
->
None
:
# The micro-batch index where the checkpointing stops.
# The micro-batch index where the checkpointing stops.
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
42e44149
...
@@ -118,13 +118,6 @@ class MultiProcessPipe(Module):
...
@@ -118,13 +118,6 @@ class MultiProcessPipe(Module):
whether to use deferred BatchNorm moving statistics (default:
whether to use deferred BatchNorm moving statistics (default:
:data:`False`, see :class:`DeferredBatchNorm` for more
:data:`False`, see :class:`DeferredBatchNorm` for more
details)
details)
pipelined_backward (bool, optional):
if True, call torch.autograd.backward once per microbatch on the
backward pass (instead of once for the whole batch). This works
around a potential deadlock in pytorch when using tensor parallelism
at the same time. Defaults to `True` if
`get_model_parallel_world_size() > 1`
(default: `None`)
Raises:
Raises:
TypeError:
TypeError:
...
@@ -174,7 +167,6 @@ class MultiProcessPipe(Module):
...
@@ -174,7 +167,6 @@ class MultiProcessPipe(Module):
chunks
:
int
=
chunks
,
chunks
:
int
=
chunks
,
checkpoint
:
str
=
checkpoint
,
checkpoint
:
str
=
checkpoint
,
deferred_batch_norm
:
bool
=
False
,
deferred_batch_norm
:
bool
=
False
,
pipelined_backward
:
bool
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -183,13 +175,17 @@ class MultiProcessPipe(Module):
...
@@ -183,13 +175,17 @@ class MultiProcessPipe(Module):
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
raise
ValueError
(
"checkpoint is not one of 'always', 'except_last', or 'never'"
)
raise
ValueError
(
"checkpoint is not one of 'always', 'except_last', or 'never'"
)
if
get_model_parallel_world_size
()
>
1
:
self
.
pipelined_backward
=
True
else
:
self
.
pipelined_backward
=
False
self
.
balance
=
list
(
balance
)
self
.
balance
=
list
(
balance
)
verify_module
(
module
)
verify_module
(
module
)
check_balance
(
module
,
self
.
balance
)
check_balance
(
module
,
self
.
balance
)
self
.
chunks
=
chunks
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
checkpoint
=
checkpoint
self
.
pipelined_backward
=
pipelined_backward
self
.
pipeline
:
Optional
[
MultiProcessPipeline
]
self
.
pipeline
:
Optional
[
MultiProcessPipeline
]
self
.
lock
=
threading
.
Lock
()
self
.
lock
=
threading
.
Lock
()
...
@@ -227,12 +223,6 @@ class MultiProcessPipe(Module):
...
@@ -227,12 +223,6 @@ class MultiProcessPipe(Module):
del
module
del
module
if
self
.
pipelined_backward
is
None
:
if
get_model_parallel_world_size
()
>
1
:
self
.
pipelined_backward
=
True
else
:
self
.
pipelined_backward
=
False
def
create_pipeline
(
self
)
->
None
:
def
create_pipeline
(
self
)
->
None
:
# The micro-batch index where the checkpointing stops.
# The micro-batch index where the checkpointing stops.
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
...
...
tests/nn/model_parallel/test_layers.py
View file @
42e44149
...
@@ -443,7 +443,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
...
@@ -443,7 +443,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
worker_map
=
worker_map
,
worker_map
=
worker_map
,
input_device
=
torch
.
cuda
.
current_device
(),
input_device
=
torch
.
cuda
.
current_device
(),
chunks
=
chunk_size
,
chunks
=
chunk_size
,
pipelined_backward
=
True
,
).
cuda
()
).
cuda
()
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
pipe_rank
=
torch
.
distributed
.
get_rank
(
group
=
mpu
.
get_pipeline_parallel_group
())
pipe_rank
=
torch
.
distributed
.
get_rank
(
group
=
mpu
.
get_pipeline_parallel_group
())
...
...
tests/nn/pipe_process/test_pipe.py
View file @
42e44149
...
@@ -259,15 +259,9 @@ def checkpoint_mode(pipe_class):
...
@@ -259,15 +259,9 @@ def checkpoint_mode(pipe_class):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
input
=
torch
.
rand
(
2
,
1
)
input
=
torch
.
rand
(
2
,
1
)
always
=
pipe_class
(
always
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"always"
,)
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
except_last
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"except_last"
,)
)
never
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"never"
,)
except_last
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"except_last"
,
pipelined_backward
=
False
,
)
never
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
"never"
,
pipelined_backward
=
False
,
)
always_output
=
always
(
input
)
always_output
=
always
(
input
)
except_last_output
=
except_last
(
input
)
except_last_output
=
except_last
(
input
)
...
@@ -306,7 +300,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
...
@@ -306,7 +300,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
def
checkpoint_eval
(
pipe_class
):
def
checkpoint_eval
(
pipe_class
):
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
pipelined_backward
=
False
,
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,)
input
=
torch
.
rand
(
2
,
1
)
input
=
torch
.
rand
(
2
,
1
)
def
find_grad_fn
(
grad_fn
,
name
):
def
find_grad_fn
(
grad_fn
,
name
):
...
@@ -343,9 +337,7 @@ def checkpoint_non_float_input(pipe_class):
...
@@ -343,9 +337,7 @@ def checkpoint_non_float_input(pipe_class):
return
input
[
0
]
*
2
return
input
[
0
]
*
2
model
=
nn
.
Sequential
(
ForkNonFloat
(),
JoinNonFloat
())
model
=
nn
.
Sequential
(
ForkNonFloat
(),
JoinNonFloat
())
model
=
pipe_class
(
model
=
pipe_class
(
model
,
balance
=
[
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
,)
model
,
balance
=
[
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
1
,
checkpoint
=
"always"
,
pipelined_backward
=
False
,
)
input
=
torch
.
rand
(
1
,
requires_grad
=
True
)
input
=
torch
.
rand
(
1
,
requires_grad
=
True
)
output
=
model
(
input
)
output
=
model
(
input
)
...
@@ -456,7 +448,7 @@ def input_pair(pipe_class):
...
@@ -456,7 +448,7 @@ def input_pair(pipe_class):
return
(
self
.
fc_a
(
a
),
self
.
fc_b
(
b
))
return
(
self
.
fc_a
(
a
),
self
.
fc_b
(
b
))
model
=
nn
.
Sequential
(
Two
())
model
=
nn
.
Sequential
(
Two
())
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
pipelined_backward
=
False
,
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,)
a
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
a
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
b
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
b
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
...
@@ -482,7 +474,7 @@ def input_singleton(pipe_class):
...
@@ -482,7 +474,7 @@ def input_singleton(pipe_class):
return
(
self
.
fc
(
a
),)
return
(
self
.
fc
(
a
),)
model
=
nn
.
Sequential
(
One
())
model
=
nn
.
Sequential
(
One
())
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
pipelined_backward
=
False
,
)
model
=
pipe_class
(
model
,
balance
=
[
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,)
a
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
a
=
torch
.
rand
(
10
,
1
,
requires_grad
=
True
)
...
@@ -766,7 +758,7 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
...
@@ -766,7 +758,7 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
@
torch_spawn
([
4
])
@
torch_spawn
([
4
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
])
def
pipelined_backward
(
pipe_class
):
def
pipelined_backward
(
pipe_class
):
model
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
ReLU
())
model
=
nn
.
Sequential
(
nn
.
ReLU
(),
nn
.
ReLU
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment