Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4e2d8cd8
Unverified
Commit
4e2d8cd8
authored
Oct 19, 2020
by
liuzhe-lz
Committed by
GitHub
Oct 19, 2020
Browse files
Fix TF NAS naive example (#2948)
parent
cd23bc41
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
19 deletions
+21
-19
examples/nas/naive-tf/train.py
examples/nas/naive-tf/train.py
+21
-19
No files found.
examples/nas/naive-tf/train.py
View file @
4e2d8cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.keras
import
Model
from
tensorflow.keras
import
Model
from
tensorflow.keras.layers
import
(
AveragePooling2D
,
BatchNormalization
,
Conv2D
,
Dense
,
MaxPool2D
)
from
tensorflow.keras.layers
import
(
AveragePooling2D
,
BatchNormalization
,
Conv2D
,
Dense
,
MaxPool2D
)
...
@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
...
@@ -7,8 +10,6 @@ from tensorflow.keras.optimizers import SGD
from
nni.nas.tensorflow.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.tensorflow.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.tensorflow.enas
import
EnasTrainer
from
nni.nas.tensorflow.enas
import
EnasTrainer
tf
.
get_logger
().
setLevel
(
'ERROR'
)
class
Net
(
Model
):
class
Net
(
Model
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -53,35 +54,36 @@ class Net(Model):
...
@@ -53,35 +54,36 @@ class Net(Model):
return
x
return
x
def
accuracy
(
output
,
target
):
def
accuracy
(
truth
,
logits
):
bs
=
target
.
shape
[
0
]
truth
=
tf
.
reshape
(
truth
,
-
1
)
predicted
=
tf
.
cast
(
tf
.
argmax
(
output
,
1
),
target
.
dtype
)
predicted
=
tf
.
cast
(
tf
.
math
.
argmax
(
logits
,
axis
=
1
),
truth
.
dtype
)
target
=
tf
.
reshape
(
target
,
[
-
1
])
equal
=
tf
.
cast
(
predicted
==
truth
,
tf
.
int32
)
return
sum
(
tf
.
cast
(
predicted
==
target
,
tf
.
float32
))
/
bs
return
tf
.
math
.
reduce_sum
(
equal
).
numpy
()
/
equal
.
shape
[
0
]
def
accuracy_metrics
(
truth
,
logits
):
acc
=
accuracy
(
truth
,
logits
)
return
{
'accuracy'
:
acc
}
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
cifar10
=
tf
.
keras
.
datasets
.
cifar10
cifar10
=
tf
.
keras
.
datasets
.
cifar10
(
x_train
,
y_train
),
(
x_test
,
y_test
)
=
cifar10
.
load_data
()
(
x_train
,
y_train
),
(
x_valid
,
y_valid
)
=
cifar10
.
load_data
()
x_train
,
x_test
=
x_train
/
255.0
,
x_test
/
255.0
x_train
,
x_valid
=
x_train
/
255.0
,
x_valid
/
255.0
split
=
int
(
len
(
x_train
)
*
0.9
)
train_set
=
(
x_train
,
y_train
)
dataset_train
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
x_train
[:
split
],
y_train
[:
split
])).
batch
(
64
)
valid_set
=
(
x_valid
,
y_valid
)
dataset_valid
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
x_train
[
split
:],
y_train
[
split
:])).
batch
(
64
)
dataset_test
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
x_test
,
y_test
)).
batch
(
64
)
net
=
Net
()
net
=
Net
()
trainer
=
EnasTrainer
(
trainer
=
EnasTrainer
(
net
,
net
,
loss
=
SparseCategoricalCrossentropy
(
reduction
=
Reduction
.
SUM
),
loss
=
SparseCategoricalCrossentropy
(
from_logits
=
True
,
reduction
=
Reduction
.
NONE
),
metrics
=
accuracy
,
metrics
=
accuracy
_metrics
,
reward_function
=
accuracy
,
reward_function
=
accuracy
,
optimizer
=
SGD
(
learning_rate
=
0.001
,
momentum
=
0.9
),
optimizer
=
SGD
(
learning_rate
=
0.001
,
momentum
=
0.9
),
batch_size
=
64
,
batch_size
=
64
,
num_epochs
=
2
,
num_epochs
=
2
,
dataset_train
=
dataset_train
,
dataset_train
=
train_set
,
dataset_valid
=
dataset_valid
,
dataset_valid
=
valid_set
dataset_test
=
dataset_test
)
)
trainer
.
train
()
trainer
.
train
()
#trainer.export('checkpoint')
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment