Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b1b9e0f8
Unverified
Commit
b1b9e0f8
authored
Feb 08, 2021
by
msbaines
Committed by
GitHub
Feb 08, 2021
Browse files
[refactor] remove multiprocess dependency on async (#373)
parent
08c10993
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
15 additions
and
24 deletions
+15
-24
fairscale/nn/pipe/async_pipe.py
fairscale/nn/pipe/async_pipe.py
+4
-0
fairscale/nn/pipe/async_schedule.py
fairscale/nn/pipe/async_schedule.py
+1
-4
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+1
-5
fairscale/nn/pipe/multiprocess_pipeline.py
fairscale/nn/pipe/multiprocess_pipeline.py
+4
-8
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+3
-5
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_transparency.py
+2
-2
No files found.
fairscale/nn/pipe/async_pipe.py
View file @
b1b9e0f8
...
@@ -192,6 +192,8 @@ class AsyncPipe(Module):
...
@@ -192,6 +192,8 @@ class AsyncPipe(Module):
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
self
.
pipeline
=
None
self
.
pipeline
=
None
# TODO(msb) remove this hack
self
.
partition
=
None
else
:
else
:
self
.
partitions
=
self
.
instantiate_partition
(
module
,
self
.
balance
,
self
.
group
)
self
.
partitions
=
self
.
instantiate_partition
(
module
,
self
.
balance
,
self
.
group
)
if
deferred_batch_norm
:
if
deferred_batch_norm
:
...
@@ -200,6 +202,8 @@ class AsyncPipe(Module):
...
@@ -200,6 +202,8 @@ class AsyncPipe(Module):
for
name
,
part
in
enumerate
(
self
.
partitions
):
for
name
,
part
in
enumerate
(
self
.
partitions
):
self
.
add_module
(
str
(
name
),
part
.
module
)
self
.
add_module
(
str
(
name
),
part
.
module
)
self
.
create_pipeline
()
self
.
create_pipeline
()
# TODO(msb) remove this hack
self
.
partition
=
self
.
partitions
[
0
].
module
del
module
del
module
...
...
fairscale/nn/pipe/async_schedule.py
View file @
b1b9e0f8
...
@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
...
@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
from
.messages
import
Transport
from
.messages
import
Transport
from
.microbatch
import
Batch
from
.microbatch
import
Batch
from
.multiprocess_pipeline
import
create_task
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.skip.tracker
import
SkipTrackerThroughPotals
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
from
.types
import
EVENT_LOOP_QUEUE
,
PipeMessage
,
Tensors
...
@@ -191,10 +192,6 @@ class AsyncEventLoop:
...
@@ -191,10 +192,6 @@ class AsyncEventLoop:
"""Actually run the forward pass for a given module, and send the result
"""Actually run the forward pass for a given module, and send the result
to the next stage in the pipeline if needed."""
to the next stage in the pipeline if needed."""
# We import here to avoid a cyclic dependency.
# TODO(msb) Break the cyclic dependency.
from
.multiprocess_pipeline
import
create_task
task
=
create_task
(
task
=
create_task
(
self
.
checkpoint_stop
,
batch
.
index
,
self
.
group
.
rank
(),
batch
,
partition
.
module
,
skip_trackers
,
self
.
checkpoint_stop
,
batch
.
index
,
self
.
group
.
rank
(),
batch
,
partition
.
module
,
skip_trackers
,
)
)
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
b1b9e0f8
...
@@ -31,7 +31,6 @@ import torch.cuda
...
@@ -31,7 +31,6 @@ import torch.cuda
from
fairscale.nn.model_parallel
import
get_model_parallel_world_size
,
get_pipeline_parallel_group
from
fairscale.nn.model_parallel
import
get_model_parallel_world_size
,
get_pipeline_parallel_group
from
.
import
microbatch
from
.
import
microbatch
from
.async_schedule
import
Location
,
ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
from
.phony
import
get_phony
...
@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
...
@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
self
.
add_module
(
str
(
0
),
self
.
partition
)
self
.
add_module
(
str
(
0
),
self
.
partition
)
self
.
create_pipeline
()
self
.
create_pipeline
()
# TODO(msb) Remove this hack at some point.
self
.
partitions
=
[
ModuleWrapper
(
self
.
partition
,
Location
(
self
.
group
.
rank
(),
0
))]
del
module
del
module
def
create_pipeline
(
self
)
->
None
:
def
create_pipeline
(
self
)
->
None
:
...
@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
...
@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
checkpoint_stop
=
{
"always"
:
self
.
chunks
,
"except_last"
:
self
.
chunks
-
1
,
"never"
:
0
}[
self
.
checkpoint
]
self
.
pipeline
=
MultiProcessPipeline
(
self
.
pipeline
=
MultiProcessPipeline
(
[
ModuleWrapper
(
self
.
partition
,
Location
(
self
.
group
.
rank
(),
0
))]
,
self
.
partition
,
self
.
_skip_layout
,
self
.
_skip_layout
,
checkpoint_stop
,
checkpoint_stop
,
group
=
self
.
group
,
group
=
self
.
group
,
...
...
fairscale/nn/pipe/multiprocess_pipeline.py
View file @
b1b9e0f8
...
@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
...
@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_ranks
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_ranks
from
.async_schedule
import
ModuleWrapper
from
.checkpoint
import
Checkpointing
from
.checkpoint
import
Checkpointing
from
.messages
import
MakeTransport
,
Transport
from
.messages
import
MakeTransport
,
Transport
from
.microbatch
import
Batch
from
.microbatch
import
Batch
...
@@ -162,7 +161,7 @@ class MultiProcessPipeline:
...
@@ -162,7 +161,7 @@ class MultiProcessPipeline:
def
__init__
(
def
__init__
(
self
,
self
,
partition
s
:
List
[
ModuleWrapper
]
,
partition
:
nn
.
Sequential
,
skip_layout
:
SkipLayout
,
skip_layout
:
SkipLayout
,
checkpoint_stop
:
int
,
checkpoint_stop
:
int
,
group
:
torch
.
distributed
.
ProcessGroup
,
group
:
torch
.
distributed
.
ProcessGroup
,
...
@@ -171,7 +170,7 @@ class MultiProcessPipeline:
...
@@ -171,7 +170,7 @@ class MultiProcessPipeline:
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
input_device
:
Union
[
None
,
int
,
str
,
torch
.
device
]
=
None
,
final_stage
:
bool
=
False
,
final_stage
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
partition
s
=
partition
s
self
.
partition
=
partition
self
.
skip_layout
=
skip_layout
self
.
skip_layout
=
skip_layout
self
.
__checkpoint_stop
=
checkpoint_stop
self
.
__checkpoint_stop
=
checkpoint_stop
self
.
group
=
group
self
.
group
=
group
...
@@ -187,7 +186,7 @@ class MultiProcessPipeline:
...
@@ -187,7 +186,7 @@ class MultiProcessPipeline:
@
property
@
property
def
checkpoint_stop
(
self
)
->
int
:
def
checkpoint_stop
(
self
)
->
int
:
# Disable checkpointing if in eval mode.
# Disable checkpointing if in eval mode.
training
=
self
.
partition
s
[
0
].
module
.
training
training
=
self
.
partition
.
training
if
not
training
:
if
not
training
:
return
0
return
0
return
self
.
__checkpoint_stop
return
self
.
__checkpoint_stop
...
@@ -208,15 +207,12 @@ class MultiProcessPipeline:
...
@@ -208,15 +207,12 @@ class MultiProcessPipeline:
schedule
=
[(
i
,
self
.
group
.
rank
())
for
i
in
range
(
m
)]
schedule
=
[(
i
,
self
.
group
.
rank
())
for
i
in
range
(
m
)]
for
i
,
j
in
schedule
:
for
i
,
j
in
schedule
:
assert
len
(
self
.
partitions
)
==
1
partition
=
self
.
partitions
[
0
]
if
self
.
group
.
rank
()
!=
0
:
if
self
.
group
.
rank
()
!=
0
:
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
batch
=
self
.
get_batch_from_previous_stage
(
i
,
skip_trackers
,
batches
)
else
:
else
:
batch
=
batches
[
i
]
batch
=
batches
[
i
]
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
partition
.
module
,
skip_trackers
)
task
=
create_task
(
self
.
checkpoint_stop
,
i
,
j
,
batch
,
self
.
partition
,
skip_trackers
)
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
batches
[
i
]
=
self
.
execute_task
(
task
,
i
,
skip_trackers
)
...
...
tests/nn/pipe_process/test_pipe.py
View file @
b1b9e0f8
...
@@ -366,8 +366,8 @@ def no_grad(pipe_class):
...
@@ -366,8 +366,8 @@ def no_grad(pipe_class):
nonlocal
latent
nonlocal
latent
latent
=
output
latent
=
output
partition
=
model
.
partition
s
[
0
]
partition
=
model
.
partition
partition
.
module
.
register_forward_hook
(
hook
)
partition
.
register_forward_hook
(
hook
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
model
(
input
)
model
(
input
)
...
@@ -616,9 +616,7 @@ def partitions(pipe_class):
...
@@ -616,9 +616,7 @@ def partitions(pipe_class):
model
=
nn
.
Sequential
(
a
,
b
)
model
=
nn
.
Sequential
(
a
,
b
)
model
=
pipe_class
(
model
,
[
1
,
1
],
worker_map
=
get_worker_map
())
model
=
pipe_class
(
model
,
[
1
,
1
],
worker_map
=
get_worker_map
())
assert
isinstance
(
model
.
partitions
,
list
)
assert
isinstance
(
model
.
partition
,
nn
.
Sequential
)
assert
len
(
model
)
==
1
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
if
model
.
group
.
rank
()
==
0
:
if
model
.
group
.
rank
()
==
0
:
assert
model
[
0
].
weight
==
a
.
weight
assert
model
[
0
].
weight
==
a
.
weight
...
...
tests/nn/pipe_process/test_transparency.py
View file @
b1b9e0f8
...
@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
...
@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
if
model
.
group
.
rank
()
==
1
:
if
model
.
group
.
rank
()
==
1
:
loss
=
outputs
.
mean
()
loss
=
outputs
.
mean
()
loss
.
backward
()
loss
.
backward
()
grad_with_pipe
=
sum_grad
(
model
.
p
ipeline
.
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
p
artition
.
parameters
())
# Both grads should be identical.
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
1
])
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
1
])
else
:
else
:
model
.
back_helper
(
outputs
)
model
.
back_helper
(
outputs
)
grad_with_pipe
=
sum_grad
(
model
.
p
ipeline
.
partitions
[
0
].
module
.
parameters
())
grad_with_pipe
=
sum_grad
(
model
.
p
artition
.
parameters
())
# Both grads should be identical.
# Both grads should be identical.
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
0
])
assert
torch
.
allclose
(
grad_with_pipe
,
grad_without_pipe
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment