Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
3aae9d06
Unverified
Commit
3aae9d06
authored
May 25, 2020
by
liuzhe-lz
Committed by
GitHub
May 25, 2020
Browse files
Fix tensorflow import (#2481)
parent
241b364b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
8 deletions
+6
-8
examples/nas/enas-tf/datasets.py
examples/nas/enas-tf/datasets.py
+0
-1
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
+6
-7
No files found.
examples/nas/enas-tf/datasets.py
View file @
3aae9d06
...
...
@@ -2,7 +2,6 @@
# Licensed under the MIT license.
import
tensorflow
as
tf
from
tensorflow.data
import
Dataset
def
get_dataset
():
(
x_train
,
y_train
),
(
x_valid
,
y_valid
)
=
tf
.
keras
.
datasets
.
cifar10
.
load_data
()
...
...
src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py
View file @
3aae9d06
...
...
@@ -4,7 +4,6 @@
import
logging
import
tensorflow
as
tf
from
tensorflow.data
import
Dataset
from
tensorflow.keras.optimizers
import
Adam
from
nni.nas.tensorflow.utils
import
AverageMeterGroup
,
fill_zero_grads
...
...
@@ -39,9 +38,9 @@ class EnasTrainer:
x
,
y
=
dataset_train
split
=
int
(
len
(
x
)
*
0.9
)
self
.
train_set
=
Dataset
.
from_tensor_slices
((
x
[:
split
],
y
[:
split
]))
self
.
valid_set
=
Dataset
.
from_tensor_slices
((
x
[
split
:],
y
[
split
:]))
self
.
test_set
=
Dataset
.
from_tensor_slices
(
dataset_valid
)
self
.
train_set
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
x
[:
split
],
y
[:
split
]))
self
.
valid_set
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
x
[
split
:],
y
[
split
:]))
self
.
test_set
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
dataset_valid
)
self
.
mutator
=
EnasMutator
(
model
)
self
.
mutator_optim
=
Adam
(
learning_rate
=
mutator_lr
)
...
...
@@ -151,9 +150,9 @@ class EnasTrainer:
def
_create_train_loader
(
self
):
train_set
=
self
.
train_set
.
shuffle
(
1000000
).
batch
(
self
.
batch_size
)
test_set
=
self
.
test_set
.
shuffle
(
1000000
).
batch
(
self
.
batch_size
)
train_set
=
self
.
train_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
)
test_set
=
self
.
test_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
)
return
iter
(
train_set
),
iter
(
test_set
)
def
_create_validate_loader
(
self
):
return
iter
(
self
.
test_set
.
shuffle
(
1000000
).
batch
(
self
.
batch_size
))
return
iter
(
self
.
test_set
.
shuffle
(
1000000
).
repeat
().
batch
(
self
.
batch_size
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment