Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e3a20fef
Unverified
Commit
e3a20fef
authored
Feb 03, 2021
by
msbaines
Committed by
GitHub
Feb 03, 2021
Browse files
[refactor] multiprocess_pipe: focus on LazyModule usage (#360)
parent
d624b81a
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
422 deletions
+12
-422
.circleci/config.yml
.circleci/config.yml
+1
-1
fairscale/nn/pipe/multiprocess_pipe.py
fairscale/nn/pipe/multiprocess_pipe.py
+8
-88
tests/nn/pipe_process/skip/__init__.py
tests/nn/pipe_process/skip/__init__.py
+0
-18
tests/nn/pipe_process/skip/test_gpipe.py
tests/nn/pipe_process/skip/test_gpipe.py
+0
-180
tests/nn/pipe_process/skip/test_leak.py
tests/nn/pipe_process/skip/test_leak.py
+0
-133
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
+3
-2
No files found.
.circleci/config.yml
View file @
e3a20fef
...
@@ -172,7 +172,7 @@ run_mp_pipe_benchmark: &run_mp_pipe_benchmark
...
@@ -172,7 +172,7 @@ run_mp_pipe_benchmark: &run_mp_pipe_benchmark
-
run
:
-
run
:
name
:
Run Multiprocess Pipe Benchmark
name
:
Run Multiprocess Pipe Benchmark
command
:
|
command
:
|
python benchmarks/pipe.py --multiprocess
python benchmarks/pipe.py --multiprocess
--lazy-construction
run_oss_benchmark
:
&run_oss_benchmark
run_oss_benchmark
:
&run_oss_benchmark
-
run
:
-
run
:
...
...
fairscale/nn/pipe/multiprocess_pipe.py
View file @
e3a20fef
...
@@ -35,8 +35,7 @@ from .async_schedule import Location, ModuleWrapper
...
@@ -35,8 +35,7 @@ from .async_schedule import Location, ModuleWrapper
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.multiprocess_pipeline
import
MultiProcessPipeline
from
.phony
import
get_phony
from
.phony
import
get_phony
from
.skip.layout
import
SkipLayout
,
inspect_skip_layout
from
.skip.layout
import
SkipLayout
from
.skip.skippable
import
Skippable
,
verify_skippables
from
.types
import
LazyModule
from
.types
import
LazyModule
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
__all__
=
[
"MultiProcessPipe"
,
"LazyModule"
]
...
@@ -68,43 +67,6 @@ def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[
...
@@ -68,43 +67,6 @@ def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[
raise
ValueError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
raise
ValueError
(
f
"all balance numbers must be positive integer (balance:
{
balance
}
)"
)
def
split_module
(
module
:
nn
.
Sequential
,
balance
:
List
[
int
])
->
List
[
nn
.
Sequential
]:
"""Splits a module into multiple partitions.
Returns:
partitions
Partitions are represented as a :class:`~torch.nn.ModuleList` whose
item is a partition. All layers in a partition are placed in the
same device.
Raises:
BalanceError:
wrong balance
IndexError:
the number of devices is fewer than the number of partitions.
"""
j
=
0
partitions
=
[]
layers
:
NamedModules
=
OrderedDict
()
for
name
,
layer
in
module
.
named_children
():
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
# Group buffered layers as a partition.
partition
=
nn
.
Sequential
(
layers
)
partitions
.
append
(
partition
)
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
return
partitions
MOVING_DENIED
=
TypeError
(
"denied to move parameters and buffers, because Pipe should manage device placement"
)
MOVING_DENIED
=
TypeError
(
"denied to move parameters and buffers, because Pipe should manage device placement"
)
...
@@ -225,11 +187,6 @@ class MultiProcessPipe(Module):
...
@@ -225,11 +187,6 @@ class MultiProcessPipe(Module):
verify_module
(
module
)
verify_module
(
module
)
check_balance
(
module
,
self
.
balance
)
check_balance
(
module
,
self
.
balance
)
# Verify if the underlying skippable modules satisfy integrity. The
# integrity can be verified before forward() because it is static.
if
isinstance
(
module
,
nn
.
Sequential
):
verify_skippables
(
module
)
self
.
chunks
=
chunks
self
.
chunks
=
chunks
self
.
checkpoint
=
checkpoint
self
.
checkpoint
=
checkpoint
self
.
pipelined_backward
=
pipelined_backward
self
.
pipelined_backward
=
pipelined_backward
...
@@ -251,10 +208,6 @@ class MultiProcessPipe(Module):
...
@@ -251,10 +208,6 @@ class MultiProcessPipe(Module):
f
"
{
len
(
self
.
balance
)
}
)"
f
"
{
len
(
self
.
balance
)
}
)"
)
)
if
isinstance
(
module
,
nn
.
Sequential
):
local_partitions
=
split_module
(
module
,
self
.
balance
)
self
.
_skip_layout
=
inspect_skip_layout
(
local_partitions
)
else
:
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
self
.
_skip_layout
=
SkipLayout
(
len
(
module
),
{})
# FIXME(tom)
rank
=
self
.
group
.
rank
()
rank
=
self
.
group
.
rank
()
...
@@ -297,45 +250,12 @@ class MultiProcessPipe(Module):
...
@@ -297,45 +250,12 @@ class MultiProcessPipe(Module):
def
instantiate_partition
(
def
instantiate_partition
(
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
List
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
self
,
module
:
Union
[
nn
.
Sequential
,
List
[
LazyModule
]],
balance
:
List
[
int
],
group
:
torch
.
distributed
.
ProcessGroup
,
)
->
List
[
ModuleWrapper
]:
)
->
List
[
ModuleWrapper
]:
layers
:
NamedModules
=
OrderedDict
()
rank
=
group
.
rank
()
first_layer
=
sum
(
balance
[:
rank
])
def
maybe_realize
(
layer
:
Any
)
->
nn
.
Module
:
num_layers
=
balance
[
rank
]
if
isinstance
(
layer
,
nn
.
Module
):
layers
=
module
[
first_layer
:
first_layer
+
num_layers
]
return
layer
instantiated_layers
=
[
l
if
isinstance
(
l
,
nn
.
Module
)
else
l
()
for
l
in
layers
]
elif
callable
(
layer
):
return
[
ModuleWrapper
(
nn
.
Sequential
(
*
instantiated_layers
),
Location
(
rank
,
0
))]
return
layer
()
else
:
raise
TypeError
(
f
"layer must be nn.Module or callable, is
{
type
(
layer
)
}
"
)
def
iterate_module
(
module
:
Union
[
nn
.
Sequential
,
list
])
->
Iterable
[
Tuple
[
Any
,
nn
.
Module
]]:
if
isinstance
(
module
,
nn
.
Sequential
):
yield
from
module
.
named_children
()
else
:
yield
from
((
str
(
k
),
v
)
for
k
,
v
in
enumerate
(
module
))
j
=
0
for
name
,
layer
in
iterate_module
(
module
):
layers
[
name
]
=
layer
if
len
(
layers
)
==
balance
[
j
]:
if
j
==
group
.
rank
():
for
key
in
layers
:
layers
[
key
]
=
maybe_realize
(
layers
[
key
])
if
not
isinstance
(
module
,
nn
.
Sequential
):
for
layer
in
layers
.
values
():
if
isinstance
(
layer
,
Skippable
):
raise
ValueError
(
"Can't use Skippable layers with multi-process pipe and lazy construction"
)
return
[
ModuleWrapper
(
nn
.
Sequential
(
layers
),
Location
(
j
,
0
))]
# Prepare for the next partition.
layers
.
clear
()
j
+=
1
raise
ValueError
(
"Souldn't get here, more ranks than partitions"
)
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
"""Counts the length of the underlying sequential module."""
"""Counts the length of the underlying sequential module."""
...
...
tests/nn/pipe_process/skip/__init__.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
tests/nn/pipe_process/skip/test_gpipe.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
torch
import
nn
from
fairscale.nn.pipe
import
AsyncPipe
,
LazyModule
,
MultiProcessPipe
from
fairscale.nn.pipe.skip
import
pop
,
skippable
,
stash
from
fairscale.nn.pipe.skip.portal
import
PortalBlue
,
PortalCopy
,
PortalOrange
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
@
torch_spawn
([
3
])
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"balance"
,
[[
3
],
[
1
,
2
],
[
2
,
1
],
[
1
,
1
,
1
]],
ids
=
[
"3"
,
"1:2"
,
"2:1"
,
"1:1:1"
])
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"never"
,
"always"
,
"except_last"
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
def
x1to3
(
balance
,
checkpoint
,
pipe_class
):
torch
.
manual_seed
(
0
)
if
pipe_class
==
AsyncPipe
and
len
(
balance
)
>
1
:
print
(
f
"skipping yarg"
)
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
@
skippable
(
stash
=
[
"1to3"
])
class
Layer1
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
yield
stash
(
"1to3"
,
input
)
output
=
self
.
conv
(
input
)
return
output
class
Layer2
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
output
=
self
.
conv
(
input
)
return
output
@
skippable
(
pop
=
[
"1to3"
])
class
Layer3
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
3
,
1
)
def
forward
(
self
,
input
):
skip_1to3
=
yield
pop
(
"1to3"
)
output
=
self
.
conv
(
input
)
+
skip_1to3
return
output
model
=
nn
.
Sequential
(
Layer1
(),
Layer2
(),
Layer3
())
model
=
pipe_class
(
model
,
balance
,
chunks
=
3
,
checkpoint
=
checkpoint
,
input_device
=
torch
.
cuda
.
current_device
(),
worker_map
=
get_worker_map
(),
pipelined_backward
=
False
,
).
cuda
()
input
=
torch
.
rand
(
30
,
3
,
224
,
224
,
requires_grad
=
True
).
cuda
()
input
.
retain_grad
()
output
=
model
(
input
)
if
model
.
group
.
rank
()
==
len
(
balance
)
-
1
:
loss
=
output
.
mean
()
loss
.
backward
()
elif
model
.
group
.
rank
()
<
len
(
balance
)
-
1
:
model
.
back_helper
(
output
)
if
model
.
group
.
rank
()
==
len
(
balance
)
-
1
:
# TODO(tom) the single-process test uses 2e-1 but for some reason
# mutli-process is more noisy, need to investigate why
assert
torch
.
allclose
(
output
.
norm
(),
torch
.
tensor
(
1039.0
).
cuda
(),
atol
=
4e-1
)
if
model
.
group
.
rank
()
==
0
:
assert
torch
.
allclose
(
input
.
grad
.
norm
(),
torch
.
tensor
(
0.0004533053
).
cuda
())
torch
.
distributed
.
barrier
()
@
torch_spawn
([
2
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skip
(
reason
=
"flaky test"
)
def
none_skip
(
pipe_class
):
if
pipe_class
==
AsyncPipe
:
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
@
skippable
(
stash
=
[
"none"
])
class
Stash
(
nn
.
Module
):
def
forward
(
self
,
input
):
yield
stash
(
"none"
,
None
)
return
input
@
skippable
(
pop
=
[
"none"
])
class
Pop
(
nn
.
Module
):
def
forward
(
self
,
input
):
none
=
yield
pop
(
"none"
)
assert
none
is
None
return
input
model
=
nn
.
Sequential
(
Stash
(),
Pop
())
model
=
pipe_class
(
model
,
[
1
,
1
],
worker_map
=
get_worker_map
(),
input_device
=
torch
.
cuda
.
current_device
(),
chunks
=
5
,
).
cuda
()
input
=
torch
.
rand
(
10
,
requires_grad
=
True
).
cuda
()
input
.
retain_grad
()
output
=
model
(
input
)
def
assert_grad_fn_is_not_portal
(
grad_fn
,
visited
=
set
()):
if
grad_fn
in
visited
or
grad_fn
is
None
:
return
assert
not
isinstance
(
grad_fn
,
PortalBlue
.
_backward_cls
)
assert
not
isinstance
(
grad_fn
,
PortalCopy
.
_backward_cls
)
assert
not
isinstance
(
grad_fn
,
PortalOrange
.
_backward_cls
)
visited
.
add
(
grad_fn
)
for
next_grad_fn
,
_
in
grad_fn
.
next_functions
:
assert_grad_fn_is_not_portal
(
next_grad_fn
,
visited
)
if
model
.
group
.
rank
()
==
1
:
assert_grad_fn_is_not_portal
(
output
.
grad_fn
)
output
.
sum
().
backward
()
else
:
model
.
back_helper
(
output
)
assert
input
.
grad
.
mean
().
item
()
==
1
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
def
lazy_skippable_error
(
pipe_class
):
"""Using skippable layers in combination with lazy construction is currently
not supported, check that it raises an Exception"""
@
skippable
(
stash
=
[
"1to3"
])
class
Layer1
(
nn
.
Linear
):
pass
@
skippable
(
pop
=
[
"1to3"
])
class
Layer3
(
nn
.
Linear
):
pass
model
=
[
LazyModule
(
lambda
:
Layer1
(
10
,
10
)),
LazyModule
(
lambda
:
nn
.
Linear
(
10
,
10
)),
LazyModule
(
lambda
:
Layer3
(
10
,
10
)),
]
with
pytest
.
raises
(
ValueError
,
match
=
"Can't use Skippable layers with multi-process pipe and lazy construction"
):
pipe_class
(
model
,
[
2
,
1
],
worker_map
=
get_worker_map
(),
)
tests/nn/pipe_process/skip/test_leak.py
deleted
100644 → 0
View file @
d624b81a
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
pytest
import
torch
from
torch
import
nn
from
fairscale.nn.pipe
import
AsyncPipe
,
MultiProcessPipe
,
is_checkpointing
,
is_recomputing
from
fairscale.nn.pipe.skip
import
pop
,
skippable
,
stash
from
fairscale.nn.pipe.skip.tracker
import
current_skip_tracker
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
@
skippable
(
stash
=
[
"skip"
])
class
Stash
(
nn
.
Module
):
def
forward
(
self
,
input
):
yield
stash
(
"skip"
,
input
)
return
input
@
skippable
(
pop
=
[
"skip"
])
class
Pop
(
nn
.
Module
):
def
forward
(
self
,
input
):
skip
=
yield
pop
(
"skip"
)
return
input
+
skip
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"train"
,
[
True
,
False
],
ids
=
[
"train"
,
"eval"
])
@
pytest
.
mark
.
parametrize
(
"checkpoint"
,
[
"always"
,
"except_last"
,
"never"
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skipif
(
"OMPI_COMM_WORLD_RANK"
in
os
.
environ
,
reason
=
"broken on mpi"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
def
delete_portal_tensor
(
train
,
checkpoint
,
pipe_class
):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
# +----------+ +------------+
#
# With checkpointing:
# +- Stash --+ +--- Pop ----+ +--- Pop'----+ +- Stash'--+
# | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
# +----------+ +------------+ +------------+ +----------+
if
pipe_class
==
AsyncPipe
:
pytest
.
skip
(
"Skip tensors NYI for AsyncPipe"
)
def
portal_tensor_life_is
(
tensor_life
,
skip_tracker
=
None
):
if
skip_tracker
is
None
:
skip_tracker
=
current_skip_tracker
()
# Get the current portal.
portal
=
list
(
skip_tracker
.
portals
.
values
())[
0
]
if
tensor_life
==
0
:
return
portal
.
tensor_life
==
0
and
portal
.
tensor
is
None
else
:
return
portal
.
tensor_life
==
tensor_life
and
portal
.
tensor
is
not
None
# Check the portal tensor after 'Stash'.
stash_
=
Stash
()
@
stash_
.
register_forward_hook
def
check_portal_tensor_after_stash
(
*
_
):
if
is_checkpointing
():
assert
portal_tensor_life_is
(
2
)
elif
is_recomputing
():
assert
portal_tensor_life_is
(
0
)
else
:
assert
portal_tensor_life_is
(
1
)
pop_
=
Pop
()
@
pop_
.
register_forward_hook
def
check_portal_tensor_after_pop
(
*
_
):
if
is_checkpointing
():
assert
portal_tensor_life_is
(
1
)
elif
is_recomputing
():
assert
portal_tensor_life_is
(
0
)
else
:
assert
portal_tensor_life_is
(
0
)
class
NoPortalTensorAtBackward
(
nn
.
Module
):
class
F
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
input
):
ctx
.
skip_tracker
=
current_skip_tracker
()
return
input
.
detach
()
@
staticmethod
def
backward
(
ctx
,
grad
):
assert
portal_tensor_life_is
(
0
,
skip_tracker
=
ctx
.
skip_tracker
)
return
grad
def
forward
(
self
,
input
):
return
self
.
F
.
apply
(
input
)
model
=
nn
.
Sequential
(
NoPortalTensorAtBackward
(),
stash_
,
pop_
)
model
=
pipe_class
(
model
,
balance
=
[
2
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
2
,
checkpoint
=
checkpoint
,)
input
=
torch
.
rand
(
10
,
requires_grad
=
True
)
if
train
:
model
.
train
()
output
=
model
(
input
)
if
model
.
group
.
rank
()
==
1
:
output
.
norm
().
backward
()
else
:
model
.
back_helper
(
output
)
else
:
model
.
eval
()
with
torch
.
no_grad
():
model
(
input
)
torch
.
distributed
.
barrier
()
tests/nn/pipe_process/test_pipe.py
View file @
e3a20fef
...
@@ -629,9 +629,9 @@ def partitions(pipe_class):
...
@@ -629,9 +629,9 @@ def partitions(pipe_class):
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
assert
isinstance
(
model
.
partitions
[
0
].
module
,
nn
.
Sequential
)
if
model
.
group
.
rank
()
==
0
:
if
model
.
group
.
rank
()
==
0
:
assert
"0.0.weight"
in
model
.
state_dict
()
assert
model
[
0
].
weight
==
a
.
weight
else
:
else
:
assert
"0.1.weight"
in
model
.
state_dict
()
assert
model
[
0
].
weight
==
b
.
weight
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
...
@@ -677,6 +677,7 @@ def empty_module(pipe_class):
...
@@ -677,6 +677,7 @@ def empty_module(pipe_class):
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
parametrize
(
"pipe_class"
,
[
MultiProcessPipe
,
AsyncPipe
])
@
pytest
.
mark
.
skip
(
reason
=
"TODO(msb) handle named_children"
)
def
named_children
(
pipe_class
):
def
named_children
(
pipe_class
):
a
=
nn
.
Linear
(
1
,
1
)
a
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
b
=
nn
.
Linear
(
1
,
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment