OpenDAS / fairscale · Commits

Unverified commit b1b9e0f8, authored Feb 08, 2021 by msbaines, committed by GitHub on Feb 08, 2021
[refactor] remove multiprocess dependency on async (#373)
parent 08c10993
Showing 6 changed files with 15 additions and 24 deletions (+15 -24)

fairscale/nn/pipe/async_pipe.py             +4 -0
fairscale/nn/pipe/async_schedule.py         +1 -4
fairscale/nn/pipe/multiprocess_pipe.py      +1 -5
fairscale/nn/pipe/multiprocess_pipeline.py  +4 -8
tests/nn/pipe_process/test_pipe.py          +3 -5
tests/nn/pipe_process/test_transparency.py  +2 -2
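Taken together, these changes remove the multiprocess pipeline's dependency on the async scheduling types: MultiProcessPipeline now receives the local stage as a plain nn.Sequential (partition) rather than a List[ModuleWrapper] (partitions), and callers and tests are updated to match. The sketch below is a plain-PyTorch analogue of that interface change, not fairscale code; both pipeline classes and the WrappedStage helper are invented for illustration.

# Illustrative only: a stand-in for the interface change in this commit, written
# against plain PyTorch. "WrappedStage" mimics the role fairscale's ModuleWrapper
# plays; neither pipeline class here is the real fairscale implementation.
from typing import List

import torch
from torch import nn


class WrappedStage:
    """Old-style carrier: the stage module plus scheduling metadata."""

    def __init__(self, module: nn.Sequential, rank: int) -> None:
        self.module = module
        self.rank = rank


class PipelineBefore:
    """Old signature: a list of wrappers, even though only one entry is ever used."""

    def __init__(self, partitions: List[WrappedStage]) -> None:
        assert len(partitions) == 1
        self.partitions = partitions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.partitions[0].module(x)


class PipelineAfter:
    """New signature: the stage's nn.Sequential is passed in directly."""

    def __init__(self, partition: nn.Sequential) -> None:
        self.partition = partition

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.partition(x)


if __name__ == "__main__":
    stage = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    x = torch.randn(3, 4)
    before = PipelineBefore([WrappedStage(stage, rank=0)])
    after = PipelineAfter(stage)
    assert torch.equal(before.forward(x), after.forward(x))

The per-file diffs below show the same change from each call site's perspective.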
fairscale/nn/pipe/async_pipe.py

@@ -192,6 +192,8 @@ class AsyncPipe(Module):
             warnings.warn("More ranks than partitions, some ranks unused")
             self.partitions: List[ModuleWrapper] = []
             self.pipeline = None
+            # TODO(msb) remove this hack
+            self.partition = None
         else:
             self.partitions = self.instantiate_partition(module, self.balance, self.group)
             if deferred_batch_norm:
@@ -200,6 +202,8 @@ class AsyncPipe(Module):
             for name, part in enumerate(self.partitions):
                 self.add_module(str(name), part.module)
             self.create_pipeline()
+            # TODO(msb) remove this hack
+            self.partition = self.partitions[0].module

         del module
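The two added lines give AsyncPipe a `partition` attribute that points at the local stage's unwrapped module (or None when this rank holds no partition), so callers that no longer know about ModuleWrapper can read a single attribute. A minimal sketch of that aliasing pattern in plain PyTorch; every name here other than `partition`/`partitions` is invented for the sketch.

# Minimal sketch, not fairscale: expose the unwrapped stage module under a
# single attribute alongside the existing wrapper list.
from typing import List, Optional

import torch
from torch import nn


class Wrapper:
    def __init__(self, module: nn.Sequential) -> None:
        self.module = module


class AsyncPipeDemo(nn.Module):
    def __init__(self, module: Optional[nn.Sequential]) -> None:
        super().__init__()
        if module is None:
            # This rank holds no partition.
            self.partitions: List[Wrapper] = []
            self.partition: Optional[nn.Sequential] = None
        else:
            self.partitions = [Wrapper(module)]
            # Alias the unwrapped module so callers can skip the wrapper.
            self.partition = self.partitions[0].module


if __name__ == "__main__":
    pipe = AsyncPipeDemo(nn.Sequential(nn.Linear(4, 4)))
    assert pipe.partition is pipe.partitions[0].module
    print(pipe.partition(torch.randn(2, 4)).shape)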
fairscale/nn/pipe/async_schedule.py

@@ -17,6 +17,7 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_ranks
 from .messages import Transport
 from .microbatch import Batch
+from .multiprocess_pipeline import create_task
 from .skip.tracker import SkipTrackerThroughPotals
 from .types import EVENT_LOOP_QUEUE, PipeMessage, Tensors
@@ -191,10 +192,6 @@ class AsyncEventLoop:
         """Actually run the forward pass for a given module, and send the result
        to the next stage in the pipeline if needed."""
-        # We import here to avoid a cyclic dependency.
-        # TODO(msb) Break the cyclic dependency.
-        from .multiprocess_pipeline import create_task
-
         task = create_task(
             self.checkpoint_stop, batch.index, self.group.rank(), batch, partition.module, skip_trackers,
         )
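With multiprocess_pipeline no longer importing from async_schedule, the import cycle that forced the function-local `from .multiprocess_pipeline import create_task` is gone, and the import can move to the top of the module. A self-contained sketch of that before/after, using two throwaway modules written to a temporary directory; the module and function names are invented, not fairscale's.

# Illustrative sketch: why a function-local import dodges a circular dependency,
# and why a top-level import works once the cycle is removed. The two modules
# below are hypothetical stand-ins written to a temporary directory.
import importlib
import sys
import tempfile
from pathlib import Path

SCHEDULE_PY = """\
# With no reverse import from pipeline_demo back to this module, the
# top-level import is safe (mirroring the post-refactor async_schedule).
from pipeline_demo import create_task

def run_stage(batch):
    return create_task(batch)
"""

PIPELINE_PY = """\
# Post-refactor: this module does not import schedule_demo, so no cycle exists.
def create_task(batch):
    return ("task", batch)
"""

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "schedule_demo.py").write_text(SCHEDULE_PY)
    Path(tmp, "pipeline_demo.py").write_text(PIPELINE_PY)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    schedule_demo = importlib.import_module("schedule_demo")
    print(schedule_demo.run_stage(42))  # -> ('task', 42)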
fairscale/nn/pipe/multiprocess_pipe.py

@@ -31,7 +31,6 @@ import torch.cuda
 from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group

 from . import microbatch
-from .async_schedule import Location, ModuleWrapper
 from .batchnorm import DeferredBatchNorm
 from .multiprocess_pipeline import MultiProcessPipeline
 from .phony import get_phony
@@ -219,9 +218,6 @@ class MultiProcessPipe(Module):
         self.add_module(str(0), self.partition)
         self.create_pipeline()
-        # TODO(msb) Remove this hack at some point.
-        self.partitions = [ModuleWrapper(self.partition, Location(self.group.rank(), 0))]
-
         del module

     def create_pipeline(self) -> None:
@@ -229,7 +225,7 @@ class MultiProcessPipe(Module):
         checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint]

         self.pipeline = MultiProcessPipeline(
-            [ModuleWrapper(self.partition, Location(self.group.rank(), 0))],
+            self.partition,
             self._skip_layout,
             checkpoint_stop,
             group=self.group,
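The surrounding context also shows how the `checkpoint` policy string is turned into a numeric cutoff: "always" checkpoints every microbatch, "except_last" all but the last, "never" none. A small standalone illustration of that lookup follows; the helper function is hypothetical, not fairscale API.

# Standalone illustration of the checkpoint-policy lookup used above: the policy
# name picks how many of the `chunks` microbatches run under activation
# checkpointing. Not fairscale API, just the idiom.
def checkpoint_stop(policy: str, chunks: int) -> int:
    return {"always": chunks, "except_last": chunks - 1, "never": 0}[policy]


if __name__ == "__main__":
    assert checkpoint_stop("always", 8) == 8        # checkpoint every microbatch
    assert checkpoint_stop("except_last", 8) == 7   # skip the last microbatch
    assert checkpoint_stop("never", 8) == 0         # no checkpointing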
fairscale/nn/pipe/multiprocess_pipeline.py

@@ -30,7 +30,6 @@ from torch.autograd.profiler import record_function
 from fairscale.nn.model_parallel import get_pipeline_parallel_ranks

-from .async_schedule import ModuleWrapper
 from .checkpoint import Checkpointing
 from .messages import MakeTransport, Transport
 from .microbatch import Batch
@@ -162,7 +161,7 @@ class MultiProcessPipeline:
     def __init__(
         self,
-        partitions: List[ModuleWrapper],
+        partition: nn.Sequential,
         skip_layout: SkipLayout,
         checkpoint_stop: int,
         group: torch.distributed.ProcessGroup,
@@ -171,7 +170,7 @@ class MultiProcessPipeline:
         input_device: Union[None, int, str, torch.device] = None,
         final_stage: bool = False,
     ) -> None:
-        self.partitions = partitions
+        self.partition = partition
         self.skip_layout = skip_layout
         self.__checkpoint_stop = checkpoint_stop
         self.group = group
@@ -187,7 +186,7 @@ class MultiProcessPipeline:
     @property
     def checkpoint_stop(self) -> int:
         # Disable checkpointing if in eval mode.
-        training = self.partitions[0].module.training
+        training = self.partition.training
         if not training:
             return 0
         return self.__checkpoint_stop
@@ -208,15 +207,12 @@ class MultiProcessPipeline:
         schedule = [(i, self.group.rank()) for i in range(m)]

         for i, j in schedule:
-            assert len(self.partitions) == 1
-            partition = self.partitions[0]
-
             if self.group.rank() != 0:
                 batch = self.get_batch_from_previous_stage(i, skip_trackers, batches)
             else:
                 batch = batches[i]

-            task = create_task(self.checkpoint_stop, i, j, batch, partition.module, skip_trackers)
+            task = create_task(self.checkpoint_stop, i, j, batch, self.partition, skip_trackers)

             batches[i] = self.execute_task(task, i, skip_trackers)
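One behavioural detail worth noting: the `checkpoint_stop` property returns 0 whenever the partition is in eval mode, and after this change it reads `training` straight off the nn.Sequential instead of digging through `partitions[0].module`. A minimal sketch of that eval-mode gate; the class below is illustrative, not the fairscale class.

# Minimal sketch of the eval-mode gate in the checkpoint_stop property:
# activation checkpointing is pointless without gradients, so report 0 in eval mode.
from torch import nn


class PipelineDemo:
    def __init__(self, partition: nn.Sequential, checkpoint_stop: int) -> None:
        self.partition = partition
        self.__checkpoint_stop = checkpoint_stop

    @property
    def checkpoint_stop(self) -> int:
        # After the refactor, `training` is read directly from the nn.Sequential.
        if not self.partition.training:
            return 0
        return self.__checkpoint_stop


if __name__ == "__main__":
    stage = nn.Sequential(nn.Linear(4, 4))
    pipe = PipelineDemo(stage, checkpoint_stop=8)
    assert pipe.checkpoint_stop == 8   # modules default to training mode
    stage.eval()
    assert pipe.checkpoint_stop == 0   # checkpointing disabled in eval mode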
tests/nn/pipe_process/test_pipe.py

@@ -366,8 +366,8 @@ def no_grad(pipe_class):
        nonlocal latent
        latent = output

-    partition = model.partitions[0]
-    partition.module.register_forward_hook(hook)
+    partition = model.partition
+    partition.register_forward_hook(hook)

    with torch.no_grad():
        model(input)
@@ -616,9 +616,7 @@ def partitions(pipe_class):
    model = nn.Sequential(a, b)
    model = pipe_class(model, [1, 1], worker_map=get_worker_map())

-    assert isinstance(model.partitions, list)
-    assert len(model) == 1
-    assert isinstance(model.partitions[0].module, nn.Sequential)
+    assert isinstance(model.partition, nn.Sequential)

    if model.group.rank() == 0:
        assert model[0].weight == a.weight
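The first hunk updates the hook test: it grabs the local stage via `model.partition` and attaches a forward hook to capture the stage's output under `torch.no_grad()`. A self-contained sketch of that capture pattern applied to a plain nn.Sequential, with no fairscale involved.

# Self-contained sketch of the forward-hook capture pattern used in the test,
# applied to a plain nn.Sequential rather than a fairscale pipe.
import torch
from torch import nn

latent = None


def hook(module, inputs, output):
    # Forward hooks receive (module, inputs, output); stash the output.
    global latent
    latent = output


partition = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
partition.register_forward_hook(hook)

with torch.no_grad():
    out = partition(torch.randn(2, 4))

assert latent is not None
assert torch.equal(latent, out)
assert not latent.requires_grad  # no autograd graph was recorded under no_grad()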
tests/nn/pipe_process/test_transparency.py

@@ -60,13 +60,13 @@ def simple_linears(pipe_class):
    if model.group.rank() == 1:
        loss = outputs.mean()
        loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.partition.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
    else:
        model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].module.parameters())
+        grad_with_pipe = sum_grad(model.partition.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
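The transparency test sums the gradients of the local stage's parameters, now reached through `model.partition`, and checks them against a non-pipelined reference with `torch.allclose`. A minimal standalone version of that comparison; the `sum_grad` helper is reimplemented here for the sketch and a deepcopy stands in for the stage inside a pipe.

# Minimal standalone version of the gradient-transparency check: the same
# module run directly and via a copy must produce identical gradients.
import copy

import torch
from torch import nn


def sum_grad(parameters):
    # Sum of all gradient elements across the given parameters.
    return sum(p.grad.sum() for p in parameters if p.grad is not None)


torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
pipelined = copy.deepcopy(reference)  # stands in for the stage inside a pipe

inputs = torch.randn(8, 4)
reference(inputs).mean().backward()
pipelined(inputs).mean().backward()

grad_without_pipe = sum_grad(reference.parameters())
grad_with_pipe = sum_grad(pipelined.parameters())

# Both grads should be identical.
assert torch.allclose(grad_with_pipe, grad_without_pipe)
print(float(grad_with_pipe), float(grad_without_pipe))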