OpenDAS / fairscale · Commits

Commit 42e44149 (unverified)
Authored Feb 04, 2021 by msbaines; committed via GitHub on Feb 04, 2021
[refactor] multiprocess_pipe: remove pipelined_backward (#362)
Parent: 7fdd7ecf

Showing 6 changed files with 17 additions and 35 deletions (+17 −35)
benchmarks/experimental_ampnet.py       +0  −1
benchmarks/pipe.py                      +0  −2
fairscale/nn/pipe/async_pipe.py         +4  −0
fairscale/nn/pipe/multiprocess_pipe.py  +5  −15
tests/nn/model_parallel/test_layers.py  +0  −1
tests/nn/pipe_process/test_pipe.py      +8  −16
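The option being removed was documented (see the fairscale/nn/pipe/multiprocess_pipe.py hunk below) as choosing between calling torch.autograd.backward once per microbatch and once for the whole batch, as a workaround for a potential PyTorch deadlock when tensor parallelism is used at the same time. A minimal standalone sketch of the two strategies (illustrative names and data, not fairscale code):

    import torch

    def whole_batch_backward(outputs, targets, loss_fn):
        # pipelined_backward=False: one autograd pass over the summed loss
        loss = sum(loss_fn(o, t) for o, t in zip(outputs, targets))
        loss.backward()

    def per_microbatch_backward(outputs, targets, loss_fn):
        # pipelined_backward=True: one backward call per microbatch
        for o, t in zip(outputs, targets):
            loss_fn(o, t).backward()

    model = torch.nn.Linear(4, 1)
    loss_fn = torch.nn.MSELoss()
    microbatches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(3)]
    outputs = [model(x) for x, _ in microbatches]
    per_microbatch_backward(outputs, [t for _, t in microbatches], loss_fn)

Gradients accumulate into the shared parameters either way; only the number of autograd passes differs.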
benchmarks/experimental_ampnet.py

@@ -423,7 +423,6 @@ def run_mp_worker(args, available_workers):
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
-        pipelined_backward=False,
         checkpoint=args.checkpoint,
     )
     if torch.cuda.is_available():
benchmarks/pipe.py

@@ -523,7 +523,6 @@ def run_mp_worker(args, available_workers):
         chunks=args.chunks,
         worker_map=get_worker_map(),
         input_device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
-        pipelined_backward=args.pipelined_backward,
         checkpoint=args.checkpoint,
         # TODO(anj-s): Do we need to comment this out? loss_fn=benchmark_config["criterion"],
     )
@@ -592,7 +591,6 @@ parser.add_argument(
 parser.add_argument(
     "--checkpoint", default="never", choices=["always", "except_last", "never"], help="Checkpointing strategy for pipe"
 )
-parser.add_argument("--pipelined-backward", action="store_true", help="Pipelined backward pass")
 parser.add_argument("--use_synthetic_data", action="store_true", help="Uses synthetic data for running benchmarks.")
 parser.add_argument("--dry_run", action="store_true", help="Run a sample training run without regression testing.")
 parser.add_argument(
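Since the --pipelined-backward flag is removed from the benchmark's argument parser, invocations that passed it must drop it. A hypothetical before/after (the other flags are illustrative):

    python benchmarks/pipe.py --checkpoint never --pipelined-backward   # before this commit
    python benchmarks/pipe.py --checkpoint never                        # after: argparse would reject the flag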
fairscale/nn/pipe/async_pipe.py

@@ -39,6 +39,10 @@ class PartitionInfo:
 class AsyncPipe(MultiProcessPipe):
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.pipelined_backward = False
+
     def create_pipeline(self) -> None:
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
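The added __init__ pins the flag after the base constructor has set its own default, so AsyncPipe never pipelines the backward pass regardless of the model-parallel world size. A toy sketch of this override pattern (stand-in names, not fairscale code):

    class Base:
        def __init__(self, world_size: int):
            # derived default, as MultiProcessPipe.__init__ does after this commit
            self.pipelined_backward = world_size > 1

    class Child(Base):
        def __init__(self, world_size: int):
            super().__init__(world_size)
            self.pipelined_backward = False  # unconditional override

    assert Child(4).pipelined_backward is False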
fairscale/nn/pipe/multiprocess_pipe.py

@@ -118,13 +118,6 @@ class MultiProcessPipe(Module):
             whether to use deferred BatchNorm moving statistics (default:
             :data:`False`, see :class:`DeferredBatchNorm` for more
             details)
-        pipelined_backward (bool, optional):
-            if True, call torch.autograd.backward once per microbatch on the
-            backward pass (instead of once for the whole batch). This works
-            around a potential deadlock in pytorch when using tensor parallelism
-            at the same time. Defaults to `True` if
-            `get_model_parallel_world_size() > 1`
-            (default: `None`)

     Raises:
         TypeError:
@@ -174,7 +167,6 @@ class MultiProcessPipe(Module):
         chunks: int = chunks,
         checkpoint: str = checkpoint,
         deferred_batch_norm: bool = False,
-        pipelined_backward: bool = None,
     ) -> None:
         super().__init__()
@@ -183,13 +175,17 @@ class MultiProcessPipe(Module):
         if checkpoint not in ["always", "except_last", "never"]:
             raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

+        if get_model_parallel_world_size() > 1:
+            self.pipelined_backward = True
+        else:
+            self.pipelined_backward = False
+
         self.balance = list(balance)
         verify_module(module)
         check_balance(module, self.balance)

         self.chunks = chunks
         self.checkpoint = checkpoint
-        self.pipelined_backward = pipelined_backward
         self.pipeline: Optional[MultiProcessPipeline]
         self.lock = threading.Lock()
@@ -227,12 +223,6 @@ class MultiProcessPipe(Module):
         del module

-        if self.pipelined_backward is None:
-            if get_model_parallel_world_size() > 1:
-                self.pipelined_backward = True
-            else:
-                self.pipelined_backward = False
-
     def create_pipeline(self) -> None:
         # The micro-batch index where the checkpointing stops.
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]
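Net effect of the hunks above: pipelined_backward is no longer a constructor argument. MultiProcessPipe now derives it unconditionally in __init__ from get_model_parallel_world_size(), and the lazy resolution of a None default later in construction is gone. Separately, the checkpoint_stop mapping shown in context (unchanged by this commit) picks the micro-batch index at which activation checkpointing stops; a quick worked sketch with an assumed chunks=4:

    chunks = 4
    for checkpoint in ("always", "except_last", "never"):
        stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[checkpoint]
        print(checkpoint, "->", stop)
    # always -> 4       (checkpoint every micro-batch)
    # except_last -> 3  (skip only the last micro-batch)
    # never -> 0        (no checkpointing)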
tests/nn/model_parallel/test_layers.py

@@ -443,7 +443,6 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         worker_map=worker_map,
         input_device=torch.cuda.current_device(),
         chunks=chunk_size,
-        pipelined_backward=True,
     ).cuda()
     torch.distributed.barrier()
     pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
tests/nn/pipe_process/test_pipe.py

@@ -259,15 +259,9 @@ def checkpoint_mode(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
     input = torch.rand(2, 1)

-    always = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always", pipelined_backward=False,
-    )
-    except_last = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last", pipelined_backward=False,
-    )
-    never = pipe_class(
-        model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never", pipelined_backward=False,
-    )
+    always = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="always",)
+    except_last = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="except_last",)
+    never = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, checkpoint="never",)

     always_output = always(input)
     except_last_output = except_last(input)
@@ -306,7 +300,7 @@ def checkpoint_mode_when_chunks_1(pipe_class):
 @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 def checkpoint_eval(pipe_class):
     model = nn.Sequential(nn.Linear(1, 1))
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
     input = torch.rand(2, 1)

     def find_grad_fn(grad_fn, name):
@@ -343,9 +337,7 @@ def checkpoint_non_float_input(pipe_class):
         return input[0] * 2

     model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
-    model = pipe_class(
-        model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False,
-    )
+    model = pipe_class(model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always",)
     input = torch.rand(1, requires_grad=True)
     output = model(input)
@@ -456,7 +448,7 @@ def input_pair(pipe_class):
             return (self.fc_a(a), self.fc_b(b))

     model = nn.Sequential(Two())
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
     a = torch.rand(10, 1, requires_grad=True)
     b = torch.rand(10, 1, requires_grad=True)
@@ -482,7 +474,7 @@ def input_singleton(pipe_class):
             return (self.fc(a),)

     model = nn.Sequential(One())
-    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2, pipelined_backward=False,)
+    model = pipe_class(model, balance=[1], worker_map=get_worker_map(), chunks=2,)
     a = torch.rand(10, 1, requires_grad=True)
@@ -766,7 +758,7 @@ def verify_module_duplicate_parameters_on_distinct_partitions(pipe_class):
 @torch_spawn([4])
-@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
+@pytest.mark.parametrize("pipe_class", [MultiProcessPipe])
 def pipelined_backward(pipe_class):
     model = nn.Sequential(nn.ReLU(), nn.ReLU())
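With the parameter gone, any caller that still passes pipelined_backward gets a TypeError from the constructor, which is why every call site above simply drops the keyword; the pipelined_backward test also stops parametrizing over AsyncPipe, since AsyncPipe now hard-codes the flag to False. A toy sketch of the call-site failure mode (stand-in class, not the fairscale API):

    class Pipe:
        # stand-in mimicking the post-commit signature: no pipelined_backward kwarg
        def __init__(self, module, balance, chunks=1, checkpoint="except_last"):
            self.module, self.balance = module, list(balance)
            self.chunks, self.checkpoint = chunks, checkpoint

    Pipe(object(), balance=[1], chunks=2)  # fine
    try:
        Pipe(object(), balance=[1], chunks=2, pipelined_backward=False)
    except TypeError as err:
        print("rejected:", err)  # unexpected keyword argument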