Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b372abf8
"tools/nni_cmd/ssh_utils.py" did not exist on "a1014619300dfe47582f0ff1cf92976f6671fe57"
Unverified
Commit
b372abf8
authored
Nov 04, 2020
by
HeekangPark
Committed by
GitHub
Nov 04, 2020
Browse files
Fix Error in nas SPOS trainer, apply_fixed_architecture (#3051)
parent
45e82b3e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
5 deletions
+13
-5
nni/algorithms/nas/pytorch/spos/trainer.py
nni/algorithms/nas/pytorch/spos/trainer.py
+2
-0
nni/nas/pytorch/fixed.py
nni/nas/pytorch/fixed.py
+11
-5
No files found.
nni/algorithms/nas/pytorch/spos/trainer.py
View file @
b372abf8
...
...
@@ -63,6 +63,7 @@ class SPOSSupernetTrainer(Trainer):
self
.
model
.
train
()
meters
=
AverageMeterGroup
()
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
train_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
optimizer
.
zero_grad
()
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
...
...
@@ -82,6 +83,7 @@ class SPOSSupernetTrainer(Trainer):
meters
=
AverageMeterGroup
()
with
torch
.
no_grad
():
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
x
,
y
=
x
.
to
(
self
.
device
),
y
.
to
(
self
.
device
)
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
loss
=
self
.
loss
(
logits
,
y
)
...
...
nni/nas/pytorch/fixed.py
View file @
b372abf8
...
...
@@ -24,11 +24,14 @@ class FixedArchitecture(Mutator):
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
):
def
__init__
(
self
,
model
,
fixed_arc
,
strict
=
True
,
verbose
=
True
):
super
().
__init__
(
model
)
self
.
_fixed_arc
=
fixed_arc
self
.
verbose
=
verbose
mutable_keys
=
set
([
mutable
.
key
for
mutable
in
self
.
mutables
if
not
isinstance
(
mutable
,
MutableScope
)])
fixed_arc_keys
=
set
(
self
.
_fixed_arc
.
keys
())
...
...
@@ -99,10 +102,11 @@ class FixedArchitecture(Mutator):
if
sum
(
chosen
)
==
1
and
max
(
chosen
)
==
1
and
not
mutable
.
return_mask
:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
if
self
.
verbose
:
_logger
.
info
(
"Replacing %s with candidate number %d."
,
global_name
,
chosen
.
index
(
1
))
setattr
(
module
,
name
,
mutable
[
chosen
.
index
(
1
)])
else
:
if
mutable
.
return_mask
:
if
mutable
.
return_mask
and
self
.
verbose
:
_logger
.
info
(
"`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, "
\
"LayerChoice will not be replaced."
)
# remove unused parameters
...
...
@@ -113,7 +117,7 @@ class FixedArchitecture(Mutator):
self
.
replace_layer_choice
(
mutable
,
global_name
)
def
apply_fixed_architecture
(
model
,
fixed_arc
):
def
apply_fixed_architecture
(
model
,
fixed_arc
,
verbose
=
True
):
"""
Load architecture from `fixed_arc` and apply to model.
...
...
@@ -123,6 +127,8 @@ def apply_fixed_architecture(model, fixed_arc):
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
...
...
@@ -133,7 +139,7 @@ def apply_fixed_architecture(model, fixed_arc):
if
isinstance
(
fixed_arc
,
str
):
with
open
(
fixed_arc
)
as
f
:
fixed_arc
=
json
.
load
(
f
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
)
architecture
=
FixedArchitecture
(
model
,
fixed_arc
,
verbose
)
architecture
.
reset
()
# for the convenience of parameters counting
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment