OpenDAS / nni / Commits / ac892fc7

Unverified commit ac892fc7, authored Aug 12, 2022 by Yuge Zhang, committed by GitHub on Aug 12, 2022.
Parent: 802650ff

Model space hub enhancements (v2.9) (#5050)
Showing 2 changed files with 60 additions and 48 deletions.

  nni/nas/hub/pytorch/nasnet.py       +49  -37
  nni/nas/hub/pytorch/shufflenet.py   +11  -11
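At a glance, the commit does four things: it threads a new drop_path_prob argument through the NDS base class and all of its subclasses (NASNet, ENAS, AmoebaNet, PNAS, DARTS), wrapping every non-identity cell operation in DropPath_ when the probability is positive; it deletes the SequentialBreakdown helper and reworks the auxiliary-loss forward pass to iterate stage blocks directly, and only during training; it comments out the 'none' candidate in DARTS_OPS; and it appends # type: ignore to the ShuffleNet weight-initialization calls.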
nni/nas/hub/pytorch/nasnet.py (view file @ ac892fc7)
@@ -7,7 +7,6 @@ The implementation is based on NDS.
 It's called ``nasnet.py`` simply because NASNet is the first to propose such structure.
 """
-from collections import OrderedDict
 from functools import partial
 from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
@@ -235,20 +234,6 @@ class AuxiliaryHead(nn.Module):
         return x


-class SequentialBreakdown(nn.Sequential):
-    """Return all layers of a sequential."""
-
-    def __init__(self, sequential: nn.Sequential):
-        super().__init__(OrderedDict(sequential.named_children()))
-
-    def forward(self, inputs):
-        result = []
-        for module in self:
-            inputs = module(inputs)
-            result.append(inputs)
-        return result
-
-
 class CellPreprocessor(nn.Module):
     """
     Aligning the shape of predecessors.
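The deleted SequentialBreakdown wrapper ran an nn.Sequential while collecting every intermediate output, solely so the auxiliary head could tap the first cell of the last stage; the reworked forward pass (see the @@ -595 hunk below) achieves the same by iterating the stage's blocks in place. A minimal sketch contrasting the two patterns, using a toy stage (illustrative only, not the NDS modules):

import torch
import torch.nn as nn

stage = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
x = torch.randn(1, 4)

# Old pattern: wrapper semantics -- run the whole stage, keep every output.
outputs = []
h = x
for module in stage:
    h = module(h)
    outputs.append(h)
first, last = outputs[0], outputs[-1]

# New pattern: iterate blocks in place and tap only what is needed,
# with no wrapper class and no list of intermediates.
h = x
for block_idx, block in enumerate(stage):
    h = block(h)
    if block_idx == 0:
        tapped = h  # e.g. feed an auxiliary head here
assert torch.equal(h, last) and torch.equal(tapped, first)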
@@ -296,7 +281,8 @@ class CellBuilder:
                  C: nn.MaybeChoice[int],
                  num_nodes: int,
                  merge_op: Literal['all', 'loose_end'],
-                 first_cell_reduce: bool, last_cell_reduce: bool):
+                 first_cell_reduce: bool, last_cell_reduce: bool,
+                 drop_path_prob: float):
         self.C_prev_in = C_prev_in  # This is the out channels of the cell before last cell.
         self.C_in = C_in  # This is the out channels of last cell.
         self.C = C  # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
@@ -305,6 +291,7 @@ class CellBuilder:
         self.merge_op: Literal['all', 'loose_end'] = merge_op
         self.first_cell_reduce = first_cell_reduce
         self.last_cell_reduce = last_cell_reduce
+        self.drop_path_prob = drop_path_prob
         self._expect_idx = 0

         # It takes an index that is the index in the repeat.
@@ -318,11 +305,16 @@ class CellBuilder:
                    op: str, channels: int, is_reduction_cell: bool):
         if is_reduction_cell and (
             input_index is None or input_index < self.num_predecessors
         ):  # could be none when constructing search space
             stride = 2
         else:
             stride = 1
-        return OPS[op](channels, stride, True)
+        operation = OPS[op](channels, stride, True)
+        if self.drop_path_prob > 0 and not isinstance(operation, nn.Identity):
+            # Omit drop-path when operation is skip connect.
+            # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model.py#L54
+            return nn.Sequential(operation, DropPath_(self.drop_path_prob))
+        return operation

     def __call__(self, repeat_idx: int):
         if self._expect_idx != repeat_idx:
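CellBuilder.op_factory now wraps every sampled operation except skip connects in DropPath_. Drop path (stochastic depth) zeroes an operation's entire output per sample during training and rescales the survivors, regularizing deep cell-based networks. A minimal sketch of the mechanism, following the standard formulation rather than NNI's actual DropPath_ module:

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    """Randomly zero the whole branch per sample with probability drop_prob."""

    def __init__(self, drop_prob: float = 0.2):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast over channel/spatial dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.dim() - 1)).bernoulli_(keep_prob)
        return x.div(keep_prob) * mask

Dividing by keep_prob keeps the expected activation scale unchanged, so inference needs no correction.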
@@ -483,6 +475,8 @@ class NDS(nn.Module):
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
     num_nodes_per_cell : int
         See :class:`~nni.retiarii.nn.pytorch.Cell`.
+    drop_path_prob : float
+        Apply drop path. Enabled when it's set to be greater than 0.
     """

     def __init__(self,
@@ -492,12 +486,14 @@ class NDS(nn.Module):
                  width: Union[Tuple[int, ...], int] = 16,
                  num_cells: Union[Tuple[int, ...], int] = 20,
                  dataset: Literal['cifar', 'imagenet'] = 'imagenet',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__()

         self.dataset = dataset
         self.num_labels = 10 if dataset == 'cifar' else 1000
         self.auxiliary_loss = auxiliary_loss
+        self.drop_path_prob = drop_path_prob

         # preprocess the specified width and depth
         if isinstance(width, Iterable):
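With the argument threaded through the constructor, a space (or a fixed architecture produced from it) can enable drop path at construction time. A hedged usage sketch; the import path matches this repository's layout, and 0.2 is an arbitrary illustrative rate:

from nni.nas.hub.pytorch import DARTS

# drop_path_prob defaults to 0. (disabled), per the signature above.
model_space = DARTS(
    width=16,
    num_cells=8,
    dataset='cifar',
    drop_path_prob=0.2,
)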
@@ -546,7 +542,7 @@ class NDS(nn.Module):
             # C_curr is number of channels for each operator in current stage.
             # C_out is usually `C * num_nodes_per_cell` because of concat operator.
             cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
-                                       merge_op, stage_idx > 0, last_cell_reduce)
+                                       merge_op, stage_idx > 0, last_cell_reduce, drop_path_prob)
             stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
             if isinstance(stage, NDSStage):
@@ -581,7 +577,6 @@ class NDS(nn.Module):
         if auxiliary_loss:
             assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
-            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
             self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore

         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
@@ -595,12 +590,13 @@ class NDS(nn.Module):
         s0 = s1 = self.stem(inputs)

         for stage_idx, stage in enumerate(self.stages):
-            if stage_idx == 2 and self.auxiliary_loss:
-                s = list(stage([s0, s1]).values())
-                s0, s1 = s[-1]
-                if self.training:
-                    # auxiliary loss is attached to the first cell of the last stage.
-                    logits_aux = self.auxiliary_head(s[0][1])
+            if stage_idx == 2 and self.auxiliary_loss and self.training:
+                assert isinstance(stage, nn.Sequential), 'Auxiliary loss is only supported for fixed architecture.'
+                for block_idx, block in enumerate(stage):
+                    # auxiliary loss is attached to the first cell of the last stage.
+                    s0, s1 = block([s0, s1])
+                    if block_idx == 0:
+                        logits_aux = self.auxiliary_head(s1)
             else:
                 s0, s1 = stage([s0, s1])
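Because the auxiliary branch is now gated on self.training as well, evaluation always takes the plain else-path, and logits_aux only exists while training. The usual recipe folds the auxiliary logits into the loss with a small weight. A hedged training-step sketch; the 0.4 weight is the DARTS-paper convention, not something this commit prescribes:

import torch.nn.functional as F

def training_step(model, images, labels, aux_weight=0.4):
    model.train()
    # With auxiliary_loss=True, the training-mode forward yields two logits.
    logits, logits_aux = model(images)
    loss = F.cross_entropy(logits, labels)
    return loss + aux_weight * F.cross_entropy(logits_aux, labels)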
@@ -655,14 +651,16 @@ class NASNet(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.NASNET_OPS,
                          merge_op='loose_end',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)


 @model_wrapper
@@ -686,14 +684,16 @@ class ENAS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.ENAS_OPS,
                          merge_op='loose_end',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)


 @model_wrapper
@@ -721,7 +721,8 @@ class AmoebaNet(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.AMOEBA_OPS,
                          merge_op='loose_end',
@@ -729,7 +730,8 @@ class AmoebaNet(NDS):
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)


 @model_wrapper
@@ -757,14 +759,16 @@ class PNAS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.PNAS_OPS,
                          merge_op='all',
                          num_nodes_per_cell=5,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)


 @model_wrapper
@@ -774,10 +778,16 @@ class DARTS(NDS):
     It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
     Its operator candidates are :attribute:`~DARTS.DARTS_OPS`.
     It has 4 nodes per cell, and the output is concatenation of all nodes in the cell.

+    .. note::
+
+        ``none`` is not included in the operator candidates.
+        It has already been handled in the differentiable implementation of cell.
     """ + _INIT_PARAMETER_DOCS

     DARTS_OPS = [
-        'none',
+        # 'none',
         'max_pool_3x3',
         'avg_pool_3x3',
         'skip_connect',
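Dropping 'none' from the explicit candidates reflects the note added above: in a differentiable (DARTS-style) strategy, the zero op belongs to the search machinery rather than the space, so listing it as a candidate would double-count it. A generic sketch of how such a strategy mixes candidates on an edge; this is the textbook formulation, not NNI's cell implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOpSketch(nn.Module):
    """Weight candidate ops by a softmax over learnable architecture logits."""

    def __init__(self, ops: list):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        self.alpha = nn.Parameter(torch.zeros(len(ops)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = F.softmax(self.alpha, dim=0)
        # In DARTS-style search the zero op is handled by the engine/cell
        # implementation itself, which is why the space omits 'none' above.
        return sum(w * op(x) for w, op in zip(weights, self.ops))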
@@ -791,14 +801,16 @@ class DARTS(NDS):
                  width: Union[Tuple[int, ...], int] = (16, 24, 32),
                  num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
-                 auxiliary_loss: bool = False):
+                 auxiliary_loss: bool = False,
+                 drop_path_prob: float = 0.):
         super().__init__(self.DARTS_OPS,
                          merge_op='all',
                          num_nodes_per_cell=4,
                          width=width,
                          num_cells=num_cells,
                          dataset=dataset,
-                         auxiliary_loss=auxiliary_loss)
+                         auxiliary_loss=auxiliary_loss,
+                         drop_path_prob=drop_path_prob)

     @classmethod
     def load_searched_model(
nni/nas/hub/pytorch/shufflenet.py (view file @ ac892fc7)
@@ -224,29 +224,29 @@ class ShuffleNetSpace(nn.Module):
         for name, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 if 'first' in name:
-                    torch.nn.init.normal_(m.weight, 0, 0.01)
+                    torch.nn.init.normal_(m.weight, 0, 0.01)  # type: ignore
                 else:
-                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
+                    torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0)
+                    torch.nn.init.constant_(m.bias, 0)  # type: ignore
             elif isinstance(m, nn.BatchNorm2d):
                 if m.weight is not None:
-                    torch.nn.init.constant_(m.weight, 1)
+                    torch.nn.init.constant_(m.weight, 1)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0.0001)
+                    torch.nn.init.constant_(m.bias, 0.0001)  # type: ignore
                 if m.running_mean is not None:
-                    torch.nn.init.constant_(m.running_mean, 0)
+                    torch.nn.init.constant_(m.running_mean, 0)  # type: ignore
             elif isinstance(m, nn.BatchNorm1d):
                 if m.weight is not None:
-                    torch.nn.init.constant_(m.weight, 1)
+                    torch.nn.init.constant_(m.weight, 1)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0.0001)
+                    torch.nn.init.constant_(m.bias, 0.0001)  # type: ignore
                 if m.running_mean is not None:
-                    torch.nn.init.constant_(m.running_mean, 0)
+                    torch.nn.init.constant_(m.running_mean, 0)  # type: ignore
             elif isinstance(m, nn.Linear):
-                torch.nn.init.normal_(m.weight, 0, 0.01)
+                torch.nn.init.normal_(m.weight, 0, 0.01)  # type: ignore
                 if m.bias is not None:
-                    torch.nn.init.constant_(m.bias, 0)
+                    torch.nn.init.constant_(m.bias, 0)  # type: ignore

     @classmethod
     def fixed_arch(cls, arch: dict) -> FixedFactory:
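The shufflenet.py change is behavior-preserving: each torch.nn.init call gains a # type: ignore annotation, presumably to silence static type checkers, which cannot narrow m.weight / m.bias / m.running_mean to non-None tensors on the branches guarding them. No initialization values change.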