Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e7fccfb4
Unverified
Commit
e7fccfb4
authored
Aug 12, 2020
by
liuzhe-lz
Committed by
GitHub
Aug 12, 2020
Browse files
TF NAS fix: avoid checking member during forward (#2781)
Co-authored-by:
liuzhe
<
zhliu1@microsoft.com
>
parent
5623dbf3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
20 deletions
+15
-20
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
+4
-4
src/sdk/pynni/nni/nas/tensorflow/mutables.py
src/sdk/pynni/nni/nas/tensorflow/mutables.py
+11
-16
No files found.
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
View file @
e7fccfb4
...
...
@@ -136,10 +136,10 @@ class EnasTrainer:
meters
=
AverageMeterGroup
()
for
x
,
y
in
test_loader
:
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
logits
=
self
.
model
(
x
,
training
=
False
)
if
isinstance
(
logits
,
tuple
):
logits
,
_
=
logits
metrics
=
self
.
metrics
(
logits
,
y
)
metrics
=
self
.
metrics
(
y
,
logits
)
loss
=
self
.
loss
(
y
,
logits
)
metrics
[
'loss'
]
=
tf
.
reduce_mean
(
loss
).
numpy
()
meters
.
update
(
metrics
)
...
...
@@ -151,8 +151,8 @@ class EnasTrainer:
def
_create_train_loader
(
self
):
train_set
=
self
.
train_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
)
test_set
=
self
.
test
_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
)
test_set
=
self
.
valid
_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
)
return
iter
(
train_set
),
iter
(
test_set
)
def
_create_validate_loader
(
self
):
return
iter
(
self
.
test_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
))
return
iter
(
self
.
test_set
.
shuffle
(
1000000
).
batch
(
self
.
batch_size
))
src/sdk/pynni/nni/nas/tensorflow/mutables.py
View file @
e7fccfb4
...
...
@@ -28,20 +28,19 @@ class Mutable(Model):
def
__deepcopy__
(
self
,
memodict
=
None
):
raise
NotImplementedError
(
"Deep copy doesn't work for mutables."
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
self
.
_check_built
()
return
super
().
__call__
(
*
args
,
**
kwargs
)
def
set_mutator
(
self
,
mutator
):
if
'mutator'
in
self
.
__dict__
:
if
hasattr
(
self
,
'mutator'
)
:
raise
RuntimeError
(
'`set_mutator is called more than once. '
'Did you parse the search space multiple times? '
'Or did you apply multiple fixed architectures?'
)
self
.
__dict__
[
'
mutator
'
]
=
mutator
self
.
mutator
=
mutator
def
call
(
self
,
*
inputs
):
raise
NotImplementedError
(
'Method `call` of Mutable must be overridden'
)
def
build
(
self
,
input_shape
):
self
.
_check_built
()
@
property
def
key
(
self
):
return
self
.
_key
...
...
@@ -68,7 +67,6 @@ class Mutable(Model):
class
MutableScope
(
Mutable
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
try
:
self
.
_check_built
()
self
.
mutator
.
enter_mutable_scope
(
self
)
return
super
().
__call__
(
*
args
,
**
kwargs
)
finally
:
...
...
@@ -80,7 +78,7 @@ class LayerChoice(Mutable):
super
().
__init__
(
key
=
key
)
self
.
names
=
[]
if
isinstance
(
op_candidates
,
OrderedDict
):
for
name
,
_
in
op_candidates
.
items
()
:
for
name
in
op_candidates
:
assert
name
not
in
[
"length"
,
"reduction"
,
"return_mask"
,
"_key"
,
"key"
,
"names"
],
\
"Please don't use a reserved name '{}' for your module."
.
format
(
name
)
self
.
names
.
append
(
name
)
...
...
@@ -94,21 +92,18 @@ class LayerChoice(Mutable):
self
.
choices
=
op_candidates
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
self
.
_built
=
False
def
call
(
self
,
*
inputs
):
if
not
self
.
_built
:
for
op
in
self
.
choices
:
if
len
(
inputs
)
>
1
:
# FIXME: not tested
op
.
build
([
inp
.
shape
for
inp
in
inputs
])
elif
len
(
inputs
)
==
1
:
op
.
build
(
inputs
[
0
].
shape
)
self
.
_built
=
True
out
,
mask
=
self
.
mutator
.
on_forward_layer_choice
(
self
,
*
inputs
)
if
self
.
return_mask
:
return
out
,
mask
return
out
def
build
(
self
,
input_shape
):
self
.
_check_built
()
for
op
in
self
.
choices
:
op
.
build
(
input_shape
)
def
__len__
(
self
):
return
len
(
self
.
choices
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment