Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e3a20fef
Unverified
Commit
e3a20fef
authored
Feb 03, 2021
by
msbaines
Committed by
GitHub
Feb 03, 2021
Browse files
[refactor] multiprocess_pipe: focus on LazyModule usage (#360)
parent
d624b81a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
422 deletions
+12
-422
.circleci/config.yml
.circleci/config.yml
+1
-1
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+8
-88
tests/nn/pipe_process/skip/__init__.py
tests/nn/pipe_process/skip/__init__.py
+0
-18
tests/nn/pipe_process/skip/test_gpipe.py
tests/nn/pipe_process/skip/test_gpipe.py
+0
-180
tests/nn/pipe_process/skip/test_leak.py
tests/nn/pipe_process/skip/test_leak.py
+0
-133
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+3
-2
No files found.
.circleci/config.yml
View file @
e3a20fef
...
@@ -172,7 +172,7 @@ run_mp_pipe_benchmark: &run_mp_pipe_benchmark
...
@@ -172,7 +172,7 @@ run_mp_pipe_benchmark: &run_mp_pipe_benchmark
-
run
:
-
run
:
name
:
Run Multiprocess Pipe Benchmark
name
:
Run Multiprocess Pipe Benchmark
command
:
|
command
:
|
python benchmarks/pipe.py --multiprocess
python benchmarks/pipe.py --multiprocess
--lazy-construction
run_oss_benchmark
:
&run_oss_benchmark
run_oss_benchmark
:
&run_oss_benchmark
-
run
:
-
run
:
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
e3a20fef
...
@@ -35,8 +35,7 @@ from .async_schedule import Location, ModuleWrapper
...
@@ -35,8 +35,7 @@ from .async_schedule import Location, ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
from
.phony
import
get_phony
from
.skip.layout
import
SkipLayout
,
inspect_skip_layout
from
.skip.layout
import
SkipLayout
from
.skip.skippable
import
Skippable
,
verify_skippables
from
.types
import
LazyModule
from
.types
import
LazyModule
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
...
@@ -68,43 +67,6 @@ def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[
...
@@ -68,43 +67,6 @@ def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[
raise
ValueError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
raise
ValueError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
def
split_module
(
module
:
nn
.
Sequential
,
balance
:
List
[
int
])
->
List
[
nn
.
Sequential
]:
"""Splits a module into multiple partitions.
Returns:
partitions
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
j
=
0
partitions
=
[]
layers
:
NamedModules
=
OrderedDict
()
for
name
,
layer
in
module
.
named_children
():
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
# Group buffered layers as a partition.
partition
=
nn
.
Sequential
(
layers
)
partitions
.
append
(
partition
)
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
return
partitions
MOVING_DENIED
=
TypeError
(
"denied to move parameters and buffers, because Pipe should manage device placement"
)
MOVING_DENIED
=
TypeError
(
"denied to move parameters and buffers, because Pipe should manage device placement"
)
...
@@ -225,11 +187,6 @@ class MultiProcessPipe(Module):
...
@@ -225,11 +187,6 @@ class MultiProcessPipe(Module):
verify_module
(
module
)
verify_module
(
module
)
check_balance
(
module
,
self
.
balance
)
check_balance
(
module
,
self
.
balance
)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if
isinstance
(
module
,
nn
.
Sequential
):
verify_skippables
(
module
)
self
.
chunks
=
chunks
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
checkpoint
=
checkpoint
self
.
pipelined_backward
=
pipelined_backward
self
.
pipelined_backward
=
pipelined_backward
...
@@ -251,11 +208,7 @@ class MultiProcessPipe(Module):
...
@@ -251,11 +208,7 @@ class MultiProcessPipe(Module):
f
"
{
len
(
self
.
balance
)
}
)"
f
"
{
len
(
self
.
balance
)
}
)"
)
)
if
isinstance
(
module
,
nn
.
Sequential
):
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
local_partitions
=
split_module
(
module
,
self
.
balance
)
self
.
_skip_layout
=
inspect_skip_layout
(
local_partitions
)
else
:
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
rank
=
self
.
group
.
rank
()
rank
=
self
.
group
.
rank
()
self
.
final_stage
=
rank
==
len
(
self
.
balance
)
-
1
self
.
final_stage
=
rank
==
len
(
self
.
balance
)
-
1
...
@@ -297,45 +250,12 @@ class MultiProcessPipe(Module):
...
@@ -297,45 +250,12 @@ class MultiProcessPipe(Module):
def
instantiate_partition
(
def
instantiate_partition
(
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
List
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
List
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
)
->
List
[
ModuleWrapper
]:
)
->
List
[
ModuleWrapper
]:
layers
:
NamedModules
=
OrderedDict
()
rank
=
group
.
rank
()
first_layer
=
sum
(
balance
[:
rank
])
def
maybe_realize
(
layer
:
Any
)
->
nn
.
Module
:
num_layers
=
balance
[
rank
]
if
isinstance
(
layer
,
nn
.
Module
):
layers
=
module
[
first_layer
:
first_layer
+
num_layers
]
return
layer
instantiated_layers
=
[
l
if
isinstance
(
l
,
nn
.
Module
)
else
l
()
for
l
in
layers
]
elif
callable
(
layer
):
return
[
ModuleWrapper
(
nn
.
Sequential
(
*
instantiated_layers
),
Location
(
rank
,
0
))]
return
layer
()
else
:
raise
TypeError
(
f
"layer must be nn.Module or callable, is
{
type
(
layer
)
}
"
)
def
iterate_module
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
Iterable
[
Tuple
[
Any
,
nn
.
Module
]]:
if
isinstance
(
module
,
nn
.
Sequential
):
yield
from
module
.
named_children
()
else
:
yield
from
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
module
))
j
=
0
for
name
,
layer
in
iterate_module
(
module
):
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
if
j
==
group
.
rank
():
for
key
in
layers
:
layers
[
key
]
=
maybe_realize
(
layers
[
key
])
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
layers
.
values
():
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
return
[
ModuleWrapper
(
nn
.
Sequential
(
layers
),
Location
(
j
,
0
))]
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
raise
ValueError
(
"Souldn't get here, more ranks than partitions"
)
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
"""Counts the length of the underlying sequential module."""
"""Counts the length of the underlying sequential module."""
...
...
tests/nn/pipe_process/skip/__init__.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/nn/pipe_process/skip/test_gpipe.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
torch
import
nn
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe.skip
import
pop
,
skippable
,
stash
from
fairscale.nn.pipe.skip.portal
import
PortalBlue
,
PortalCopy
,
PortalOrange
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
@
torch_spawn
([
3
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"balance"
,
[[
3
],
[
1
,
2
],
[
2
,
1
],
[
1
,
1
,
1
]],
ids
=
[
"3"
,
"1:2"
,
"2:1"
,
"1:1:1"
])
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"never"
,
"always"
,
"except_last"
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
def
x1to3
(
balance
,
checkpoint
,
pipe_class
):
torch
.
manual_seed
(
0
)
if
pipe_class
==
AsyncPipe
and
len
(
balance
)
>
1
:
print
(
f
"skipping yarg"
)
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
@
skippable
(
stash
=
[
"1to3"
])
class
Layer1
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
yield
stash
(
"1to3"
,
input
)
output
=
self
.
conv
(
input
)
return
output
class
Layer2
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
output
=
self
.
conv
(
input
)
return
output
@
skippable
(
pop
=
[
"1to3"
])
class
Layer3
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
skip_1to3
=
yield
pop
(
"1to3"
)
output
=
self
.
conv
(
input
)
+
skip_1to3
return
output
model
=
nn
.
Sequential
(
Layer1
(),
Layer2
(),
Layer3
())
model
=
pipe_class
(
model
,
balance
,
chunks
=
3
,
checkpoint
=
checkpoint
,
input_device
=
torch
.
cuda
.
current_device
(),
worker_map
=
get_worker_map
(),
pipelined_backward
=
False
,
).
cuda
()
input
=
torch
.
rand
(
30
,
3
,
224
,
224
,
requires_grad
=
True
).
cuda
()
input
.
retain_grad
()
output
=
model
(
input
)
if
model
.
group
.
rank
()
==
len
(
balance
)
-
1
:
loss
=
output
.
mean
()
loss
.
backward
()
elif
model
.
group
.
rank
()
<
len
(
balance
)
-
1
:
model
.
back_helper
(
output
)
if
model
.
group
.
rank
()
==
len
(
balance
)
-
1
:
# TODO(tom) the single-process test uses 2e-1 but for some reason
# mutli-process is more noisy, need to investigate why
assert
torch
.
allclose
(
output
.
norm
(),
torch
.
tensor
(
1039.0
).
cuda
(),
atol
=
4e-1
)
if
model
.
group
.
rank
()
==
0
:
assert
torch
.
allclose
(
input
.
grad
.
norm
(),
torch
.
tensor
(
0.0004533053
).
cuda
())
torch
.
distributed
.
barrier
()
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skip
(
reason
=
"flaky test"
)
def
none_skip
(
pipe_class
):
if
pipe_class
==
AsyncPipe
:
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
@
skippable
(
stash
=
[
"none"
])
class
Stash
(
nn
.
Module
):
def
forward
(
self
,
input
):
yield
stash
(
"none"
,
None
)
return
input
@
skippable
(
pop
=
[
"none"
])
class
Pop
(
nn
.
Module
):
def
forward
(
self
,
input
):
none
=
yield
pop
(
"none"
)
assert
none
is
None
return
input
model
=
nn
.
Sequential
(
Stash
(),
Pop
())
model
=
pipe_class
(
model
,
[
1
,
1
],
worker_map
=
get_worker_map
(),
input_device
=
torch
.
cuda
.
current_device
(),
chunks
=
5
,
).
cuda
()
input
=
torch
.
rand
(
10
,
requires_grad
=
True
).
cuda
()
input
.
retain_grad
()
output
=
model
(
input
)
def
assert_grad_fn_is_not_portal
(
grad_fn
,
visited
=
set
()):
if
grad_fn
in
visited
or
grad_fn
is
None
:
return
assert
not
isinstance
(
grad_fn
,
PortalBlue
.
_backward_cls
)
assert
not
isinstance
(
grad_fn
,
PortalCopy
.
_backward_cls
)
assert
not
isinstance
(
grad_fn
,
PortalOrange
.
_backward_cls
)
visited
.
add
(
grad_fn
)
for
next_grad_fn
,
_
in
grad_fn
.
next_functions
:
assert_grad_fn_is_not_portal
(
next_grad_fn
,
visited
)
if
model
.
group
.
rank
()
==
1
:
assert_grad_fn_is_not_portal
(
output
.
grad_fn
)
output
.
sum
().
backward
()
else
:
model
.
back_helper
(
output
)
assert
input
.
grad
.
mean
().
item
()
==
1
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
def
lazy_skippable_error
(
pipe_class
):
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
@
skippable
(
stash
=
[
"1to3"
])
class
Layer1
(
nn
.
Linear
):
pass
@
skippable
(
pop
=
[
"1to3"
])
class
Layer3
(
nn
.
Linear
):
pass
model
=
[
LazyModule
(
lambda
:
Layer1
(
10
,
10
)),
LazyModule
(
lambda
:
nn
.
Linear
(
10
,
10
)),
LazyModule
(
lambda
:
Layer3
(
10
,
10
)),
]
with
pytest
.
raises
(
ValueError
,
match
=
"Can't use Skippable layers with multi-process pipe and lazy construction"
):
pipe_class
(
model
,
[
2
,
1
],
worker_map
=
get_worker_map
(),
)
tests/nn/pipe_process/skip/test_leak.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
torch
import
nn
from
fairscale.nn.pipe
import
AsyncPipe
,
MultiProcessPipe
,
is_checkpointing
,
is_recomputing
from
fairscale.nn.pipe.skip
import
pop
,
skippable
,
stash
from
fairscale.nn.pipe.skip.tracker
import
current_skip_tracker
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
@
skippable
(
stash
=
[
"skip"
])
class
Stash
(
nn
.
Module
):
def
forward
(
self
,
input
):
yield
stash
(
"skip"
,
input
)
return
input
@
skippable
(
pop
=
[
"skip"
])
class
Pop
(
nn
.
Module
):
def
forward
(
self
,
input
):
skip
=
yield
pop
(
"skip"
)
return
input
+
skip
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"train"
,
[
True
,
False
],
ids
=
[
"train"
,
"eval"
])
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"always"
,
"except_last"
,
"never"
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
def
delete_portal_tensor
(
train
,
checkpoint
,
pipe_class
):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
if
pipe_class
==
AsyncPipe
:
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
def
portal_tensor_life_is
(
tensor_life
,
skip_tracker
=
None
):
if
skip_tracker
is
None
:
skip_tracker
=
current_skip_tracker
()
# Get the current portal.
portal
=
list
(
skip_tracker
.
portals
.
values
())[
0
]
if
tensor_life
==
0
:
return
portal
.
tensor_life
==
0
and
portal
.
tensor
is
None
else
:
return
portal
.
tensor_life
==
tensor_life
and
portal
.
tensor
is
not
None
# Check the portal tensor after 'Stash'.
stash_
=
Stash
()
@
stash_
.
register_forward_hook
def
check_portal_tensor_after_stash
(
*
_
):
if
is_checkpointing
():
assert
portal_tensor_life_is
(
2
)
elif
is_recomputing
():
assert
portal_tensor_life_is
(
0
)
else
:
assert
portal_tensor_life_is
(
1
)
pop_
=
Pop
()
@
pop_
.
register_forward_hook
def
check_portal_tensor_after_pop
(
*
_
):
if
is_checkpointing
():
assert
portal_tensor_life_is
(
1
)
elif
is_recomputing
():
assert
portal_tensor_life_is
(
0
)
else
:
assert
portal_tensor_life_is
(
0
)
class
NoPortalTensorAtBackward
(
nn
.
Module
):
class
F
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
input
):
ctx
.
skip_tracker
=
current_skip_tracker
()
return
input
.
detach
()
@
staticmethod
def
backward
(
ctx
,
grad
):
assert
portal_tensor_life_is
(
0
,
skip_tracker
=
ctx
.
skip_tracker
)
return
grad
def
forward
(
self
,
input
):
return
self
.
F
.
apply
(
input
)
model
=
nn
.
Sequential
(
NoPortalTensorAtBackward
(),
stash_
,
pop_
)
model
=
pipe_class
(
model
,
balance
=
[
2
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
checkpoint
,)
input
=
torch
.
rand
(
10
,
requires_grad
=
True
)
if
train
:
model
.
train
()
output
=
model
(
input
)
if
model
.
group
.
rank
()
==
1
:
output
.
norm
().
backward
()
else
:
model
.
back_helper
(
output
)
else
:
model
.
eval
()
with
torch
.
no_grad
():
model
(
input
)
torch
.
distributed
.
barrier
()
tests/nn/pipe_process/test_pipe.py
View file @
e3a20fef
...
@@ -629,9 +629,9 @@ def partitions(pipe_class):
...
@@ -629,9 +629,9 @@ def partitions(pipe_class):
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
if
model
.
group
.
rank
()
==
0
:
if
model
.
group
.
rank
()
==
0
:
assert
"0.0.weight"
in
model
.
state_dict
()
assert
model
[
0
].
weight
==
a
.
weight
else
:
else
:
assert
"0.1.weight"
in
model
.
state_dict
()
assert
model
[
0
].
weight
==
b
.
weight
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
...
@@ -677,6 +677,7 @@ def empty_module(pipe_class):
...
@@ -677,6 +677,7 @@ def empty_module(pipe_class):
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skip
(
reason
=
"TODO(msb) handle named_children"
)
def
named_children
(
pipe_class
):
def
named_children
(
pipe_class
):
a
=
nn
.
Linear
(
1
,
1
)
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment