Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
39675773
Unverified
Commit
39675773
authored
Feb 03, 2021
by
msbaines
Committed by
GitHub
Feb 03, 2021
Browse files
[refactor] multiprocess_pipe: cleanup __init__ (#357)
parent
de713d1e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
24 deletions
+11
-24
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+10
-16
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+1
-8
No files found.
fairscale/nn/pipe/multiprocess_pipe.py
View file @
39675773
...
@@ -220,9 +220,6 @@ class MultiProcessPipe(Module):
...
@@ -220,9 +220,6 @@ class MultiProcessPipe(Module):
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
chunks
=
int
(
chunks
)
checkpoint
=
str
(
checkpoint
)
if
chunks
<=
0
:
if
chunks
<=
0
:
raise
ValueError
(
"number of chunks must be positive integer"
)
raise
ValueError
(
"number of chunks must be positive integer"
)
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
if
checkpoint
not
in
[
"always"
,
"except_last"
,
"never"
]:
...
@@ -259,10 +256,18 @@ class MultiProcessPipe(Module):
...
@@ -259,10 +256,18 @@ class MultiProcessPipe(Module):
f
"
{
len
(
self
.
balance
)
}
)"
f
"
{
len
(
self
.
balance
)
}
)"
)
)
if
isinstance
(
module
,
nn
.
Sequential
):
local_partitions
=
split_module
(
module
,
self
.
balance
)
self
.
_skip_layout
=
inspect_skip_layout
(
local_partitions
)
else
:
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
rank
=
self
.
group
.
rank
()
rank
=
self
.
group
.
rank
()
self
.
final_stage
=
rank
==
len
(
self
.
balance
)
-
1
if
rank
>=
len
(
self
.
balance
):
if
rank
>=
len
(
self
.
balance
):
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
warnings
.
warn
(
"More ranks than partitions, some ranks unused"
)
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
self
.
partitions
:
List
[
ModuleWrapper
]
=
[]
self
.
pipeline
=
None
else
:
else
:
self
.
partitions
=
self
.
instantiate_partition
(
module
,
self
.
balance
,
self
.
group
)
self
.
partitions
=
self
.
instantiate_partition
(
module
,
self
.
balance
,
self
.
group
)
if
deferred_batch_norm
:
if
deferred_batch_norm
:
...
@@ -270,21 +275,10 @@ class MultiProcessPipe(Module):
...
@@ -270,21 +275,10 @@ class MultiProcessPipe(Module):
part
.
module
=
DeferredBatchNorm
.
convert_deferred_batch_norm
(
part
.
module
,
chunks
)
part
.
module
=
DeferredBatchNorm
.
convert_deferred_batch_norm
(
part
.
module
,
chunks
)
for
name
,
part
in
enumerate
(
self
.
partitions
):
for
name
,
part
in
enumerate
(
self
.
partitions
):
self
.
add_module
(
str
(
name
),
part
.
module
)
self
.
add_module
(
str
(
name
),
part
.
module
)
if
isinstance
(
module
,
nn
.
Sequential
):
self
.
create_pipeline
()
local_partitions
=
split_module
(
module
,
self
.
balance
)
self
.
_skip_layout
=
inspect_skip_layout
(
local_partitions
)
else
:
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
rank
=
self
.
group
.
rank
()
del
module
if
rank
>=
len
(
self
.
balance
):
self
.
pipeline
=
None
self
.
final_stage
=
False
else
:
self
.
final_stage
=
rank
==
len
(
self
.
balance
)
-
1
self
.
create_pipeline
()
del
module
if
self
.
pipelined_backward
is
None
:
if
self
.
pipelined_backward
is
None
:
if
get_model_parallel_world_size
()
>
1
:
if
get_model_parallel_world_size
()
>
1
:
self
.
pipelined_backward
=
True
self
.
pipelined_backward
=
True
...
...
tests/nn/pipe_process/test_pipe.py
View file @
39675773
...
@@ -109,16 +109,9 @@ def mpi():
...
@@ -109,16 +109,9 @@ def mpi():
@
torch_spawn
([
1
])
@
torch_spawn
([
1
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
def
public_attrs
(
pipe_class
):
def
public_attrs
(
pipe_class
):
class
MyString
:
def
__init__
(
self
,
value
):
self
.
value
=
value
def
__str__
(
self
):
return
self
.
value
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
1
,
1
))
pipe
=
pipe_class
(
model
,
balance
=
(
1
,),
worker_map
=
get_worker_map
(),
chunks
=
42
.000
,
checkpoint
=
MyString
(
"always"
)
,)
pipe
=
pipe_class
(
model
,
balance
=
(
1
,),
worker_map
=
get_worker_map
(),
chunks
=
42
,
checkpoint
=
"always"
,)
assert
pipe
.
balance
==
[
1
]
assert
pipe
.
balance
==
[
1
]
assert
pipe
.
chunks
==
42
assert
pipe
.
chunks
==
42
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment