Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c2666cea
Commit
c2666cea
authored
May 19, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
May 19, 2020
Browse files
[Clean up] Remove enable_eager in the session config: Model garden is TF2 only now.
Remove is_v2_0 PiperOrigin-RevId: 312336907
parent
4ec2ee97
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
15 additions
and
165 deletions
+15
-165
official/benchmark/models/resnet_cifar_main.py
official/benchmark/models/resnet_cifar_main.py
+0
-1
official/benchmark/models/resnet_cifar_test.py
official/benchmark/models/resnet_cifar_test.py
+0
-7
official/benchmark/models/resnet_imagenet_main.py
official/benchmark/models/resnet_imagenet_main.py
+0
-1
official/benchmark/models/resnet_imagenet_test.py
official/benchmark/models/resnet_imagenet_test.py
+0
-16
official/benchmark/models/resnet_imagenet_test_tpu.py
official/benchmark/models/resnet_imagenet_test_tpu.py
+0
-5
official/benchmark/models/shakespeare/shakespeare_main.py
official/benchmark/models/shakespeare/shakespeare_main.py
+0
-2
official/benchmark/shakespeare_benchmark.py
official/benchmark/shakespeare_benchmark.py
+0
-4
official/nlp/bert/run_classifier.py
official/nlp/bert/run_classifier.py
+1
-1
official/nlp/bert/run_squad_helper.py
official/nlp/bert/run_squad_helper.py
+1
-1
official/r1/boosted_trees/train_higgs.py
official/r1/boosted_trees/train_higgs.py
+1
-1
official/r1/mnist/mnist.py
official/r1/mnist/mnist.py
+1
-1
official/r1/mnist/mnist_eager_test.py
official/r1/mnist/mnist_eager_test.py
+0
-95
official/r1/mnist/mnist_test.py
official/r1/mnist/mnist_test.py
+2
-9
official/r1/resnet/cifar10_test.py
official/r1/resnet/cifar10_test.py
+1
-3
official/r1/resnet/imagenet_test.py
official/r1/resnet/imagenet_test.py
+1
-3
official/r1/utils/data/file_io_test.py
official/r1/utils/data/file_io_test.py
+1
-3
official/r1/wide_deep/census_dataset.py
official/r1/wide_deep/census_dataset.py
+1
-1
official/r1/wide_deep/census_main.py
official/r1/wide_deep/census_main.py
+1
-1
official/r1/wide_deep/census_test.py
official/r1/wide_deep/census_test.py
+3
-9
official/r1/wide_deep/movielens_dataset.py
official/r1/wide_deep/movielens_dataset.py
+1
-1
No files found.
official/benchmark/models/resnet_cifar_main.py
View file @
c2666cea
...
...
@@ -119,7 +119,6 @@ def run(flags_obj):
Dictionary of training and eval stats.
"""
keras_utils
.
set_session_config
(
enable_eager
=
flags_obj
.
enable_eager
,
enable_xla
=
flags_obj
.
enable_xla
)
# Execute flag override logic for better model performance
...
...
official/benchmark/models/resnet_cifar_test.py
View file @
c2666cea
...
...
@@ -26,7 +26,6 @@ from tensorflow.python.eager import context
from
tensorflow.python.platform
import
googletest
from
official.benchmark.models
import
cifar_preprocessing
from
official.benchmark.models
import
resnet_cifar_main
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
...
...
@@ -60,8 +59,6 @@ class KerasCifarTest(googletest.TestCase):
def
test_end_to_end_no_dist_strat
(
self
):
"""Test Keras model with 1 GPU, no distribution strategy."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"off"
,
...
...
@@ -94,8 +91,6 @@ class KerasCifarTest(googletest.TestCase):
def
test_end_to_end_1_gpu
(
self
):
"""Test Keras model with 1 GPU."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
1
:
self
.
skipTest
(
...
...
@@ -140,8 +135,6 @@ class KerasCifarTest(googletest.TestCase):
def
test_end_to_end_2_gpu
(
self
):
"""Test Keras model with 2 GPUs."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
...
...
official/benchmark/models/resnet_imagenet_main.py
View file @
c2666cea
...
...
@@ -51,7 +51,6 @@ def run(flags_obj):
Dictionary of training and eval stats.
"""
keras_utils
.
set_session_config
(
enable_eager
=
flags_obj
.
enable_eager
,
enable_xla
=
flags_obj
.
enable_xla
)
# Execute flag override logic for better model performance
...
...
official/benchmark/models/resnet_imagenet_test.py
View file @
c2666cea
...
...
@@ -23,7 +23,6 @@ import tensorflow as tf
from
tensorflow.python.eager
import
context
from
official.benchmark.models
import
resnet_imagenet_main
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
...
...
@@ -85,8 +84,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_no_dist_strat
(
self
,
flags_key
):
"""Test Keras model with 1 GPU, no distribution strategy."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"off"
,
...
...
@@ -115,8 +112,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_1_gpu
(
self
,
flags_key
):
"""Test Keras model with 1 GPU."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
1
:
self
.
skipTest
(
...
...
@@ -138,8 +133,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_1_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with 1 GPU and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
1
:
self
.
skipTest
(
...
...
@@ -164,8 +157,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_2_gpu
(
self
,
flags_key
):
"""Test Keras model with 2 GPUs."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
...
...
@@ -186,8 +177,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_xla_2_gpu
(
self
,
flags_key
):
"""Test Keras model with XLA and 2 GPUs."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
...
...
@@ -209,8 +198,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_2_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with 2 GPUs and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
...
...
@@ -235,9 +222,6 @@ class KerasImagenetTest(tf.test.TestCase):
def
test_end_to_end_xla_2_gpu_fp16
(
self
,
flags_key
):
"""Test Keras model with XLA, 2 GPUs and fp16."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
if
context
.
num_gpus
()
<
2
:
self
.
skipTest
(
"{} GPUs are not available for this test. {} GPUs are available"
.
...
...
official/benchmark/models/resnet_imagenet_test_tpu.py
View file @
c2666cea
...
...
@@ -21,7 +21,6 @@ from __future__ import print_function
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.benchmark.models
import
resnet_imagenet_main
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.vision.image_classification.resnet
import
imagenet_preprocessing
...
...
@@ -70,8 +69,6 @@ class KerasImagenetTest(tf.test.TestCase, parameterized.TestCase):
])
def
test_end_to_end_tpu
(
self
,
flags_key
):
"""Test Keras model with TPU distribution strategy."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"tpu"
,
...
...
@@ -89,8 +86,6 @@ class KerasImagenetTest(tf.test.TestCase, parameterized.TestCase):
@
parameterized
.
parameters
([
"resnet"
])
def
test_end_to_end_tpu_bf16
(
self
,
flags_key
):
"""Test Keras model with TPU and bfloat16 activation."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"tpu"
,
...
...
official/benchmark/models/shakespeare/shakespeare_main.py
View file @
c2666cea
...
...
@@ -139,7 +139,6 @@ def build_model(vocab_size,
Returns:
A Keras Model.
"""
assert
keras_utils
.
is_v2_0
()
LSTM
=
functools
.
partial
(
tf
.
keras
.
layers
.
LSTM
,
implementation
=
2
)
# By indirecting the activation through a lambda layer, the logic to dispatch
...
...
@@ -275,7 +274,6 @@ def run(flags_obj):
tf
.
keras
.
mixed_precision
.
experimental
.
set_policy
(
policy
)
keras_utils
.
set_session_config
(
enable_eager
=
flags_obj
.
enable_eager
,
enable_xla
=
flags_obj
.
enable_xla
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
...
...
official/benchmark/shakespeare_benchmark.py
View file @
c2666cea
...
...
@@ -273,7 +273,6 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
FLAGS
.
num_gpus
=
1
FLAGS
.
batch_size
=
64
FLAGS
.
cudnn
=
False
FLAGS
.
enable_eager
=
keras_utils
.
is_v2_0
()
self
.
_run_and_report_benchmark
()
def
benchmark_1_gpu_no_ds
(
self
):
...
...
@@ -307,7 +306,6 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
FLAGS
.
num_gpus
=
1
FLAGS
.
batch_size
=
64
FLAGS
.
cudnn
=
False
FLAGS
.
enable_eager
=
keras_utils
.
is_v2_0
()
FLAGS
.
enable_xla
=
True
self
.
_run_and_report_benchmark
()
...
...
@@ -326,7 +324,6 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
FLAGS
.
batch_size
=
64
*
8
FLAGS
.
log_steps
=
10
FLAGS
.
cudnn
=
False
FLAGS
.
enable_eager
=
keras_utils
.
is_v2_0
()
self
.
_run_and_report_benchmark
()
def
benchmark_xla_8_gpu
(
self
):
...
...
@@ -345,7 +342,6 @@ class ShakespeareKerasBenchmarkReal(ShakespeareBenchmarkBase):
FLAGS
.
batch_size
=
64
*
8
FLAGS
.
log_steps
=
10
FLAGS
.
cudnn
=
False
FLAGS
.
enable_eager
=
keras_utils
.
is_v2_0
()
FLAGS
.
enable_xla
=
True
self
.
_run_and_report_benchmark
()
...
...
official/nlp/bert/run_classifier.py
View file @
c2666cea
...
...
@@ -347,7 +347,7 @@ def run_bert(strategy,
if
FLAGS
.
mode
!=
'train_and_eval'
:
raise
ValueError
(
'Unsupported mode is specified: %s'
%
FLAGS
.
mode
)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_config
_v2
(
FLAGS
.
enable_xla
)
keras_utils
.
set_
session_
config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
epochs
=
FLAGS
.
num_train_epochs
...
...
official/nlp/bert/run_squad_helper.py
View file @
c2666cea
...
...
@@ -227,7 +227,7 @@ def train_squad(strategy,
logging
.
info
(
'Training using customized training loop with distribution'
' strategy.'
)
# Enables XLA in Session Config. Should not be set for TPU.
keras_utils
.
set_config
_v2
(
FLAGS
.
enable_xla
)
keras_utils
.
set_
session_
config
(
FLAGS
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
common_flags
.
dtype
())
epochs
=
FLAGS
.
num_train_epochs
...
...
official/r1/boosted_trees/train_higgs.py
View file @
c2666cea
...
...
@@ -48,7 +48,7 @@ import os
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
from
official.r1.utils.logs
import
logger
from
official.utils.flags
import
core
as
flags_core
...
...
official/r1/mnist/mnist.py
View file @
c2666cea
...
...
@@ -21,7 +21,7 @@ from absl import app as absl_app
from
absl
import
flags
from
absl
import
logging
from
six.moves
import
range
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
from
official.r1.mnist
import
dataset
from
official.r1.utils.logs
import
hooks_helper
...
...
official/r1/mnist/mnist_eager_test.py
deleted
100644 → 0
View file @
4ec2ee97
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
tensorflow.python
import
eager
as
tfe
# pylint: disable=g-bad-import-order
from
official.r1.mnist
import
mnist
from
official.r1.mnist
import
mnist_eager
from
official.utils.misc
import
keras_utils
def
device
():
return
'/device:GPU:0'
if
tfe
.
context
.
num_gpus
()
else
'/device:CPU:0'
def
data_format
():
return
'channels_first'
if
tfe
.
context
.
num_gpus
()
else
'channels_last'
def
random_dataset
():
batch_size
=
64
images
=
tf
.
random_normal
([
batch_size
,
784
])
labels
=
tf
.
random_uniform
([
batch_size
],
minval
=
0
,
maxval
=
10
,
dtype
=
tf
.
int32
)
return
tf
.
data
.
Dataset
.
from_tensors
((
images
,
labels
))
def
train
(
defun
=
False
):
model
=
mnist
.
create_model
(
data_format
())
if
defun
:
model
.
call
=
tf
.
function
(
model
.
call
)
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
learning_rate
=
0.01
)
dataset
=
random_dataset
()
with
tf
.
device
(
device
()):
mnist_eager
.
train
(
model
,
optimizer
,
dataset
,
step_counter
=
tf
.
train
.
get_or_create_global_step
())
def
evaluate
(
defun
=
False
):
model
=
mnist
.
create_model
(
data_format
())
dataset
=
random_dataset
()
if
defun
:
model
.
call
=
tf
.
function
(
model
.
call
)
with
tf
.
device
(
device
()):
mnist_eager
.
test
(
model
,
dataset
)
class
MNISTTest
(
tf
.
test
.
TestCase
):
"""Run tests for MNIST eager loop.
MNIST eager uses contrib and will not work with TF 2.0. All tests are
disabled if using TF 2.0.
"""
def
setUp
(
self
):
if
not
keras_utils
.
is_v2_0
():
tf
.
compat
.
v1
.
enable_v2_behavior
()
super
(
MNISTTest
,
self
).
setUp
()
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_train
(
self
):
train
(
defun
=
False
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_evaluate
(
self
):
evaluate
(
defun
=
False
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_train_with_defun
(
self
):
train
(
defun
=
True
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_evaluate_with_defun
(
self
):
evaluate
(
defun
=
True
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/r1/mnist/mnist_test.py
View file @
c2666cea
...
...
@@ -18,12 +18,10 @@ from __future__ import division
from
__future__
import
print_function
import
time
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
import
tensorflow
.compat.v1
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.r1.mnist
import
mnist
from
official.utils.misc
import
keras_utils
BATCH_SIZE
=
100
...
...
@@ -51,7 +49,6 @@ class Tests(tf.test.TestCase):
using TF 2.0.
"""
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_mnist
(
self
):
classifier
=
make_estimator
()
classifier
.
train
(
input_fn
=
dummy_input_fn
,
steps
=
2
)
...
...
@@ -71,7 +68,6 @@ class Tests(tf.test.TestCase):
self
.
assertEqual
(
predictions
[
'probabilities'
].
shape
,
(
10
,))
self
.
assertEqual
(
predictions
[
'classes'
].
shape
,
())
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
mnist_model_fn_helper
(
self
,
mode
,
multi_gpu
=
False
):
features
,
labels
=
dummy_input_fn
()
image_count
=
features
.
shape
[
0
]
...
...
@@ -99,19 +95,15 @@ class Tests(tf.test.TestCase):
self
.
assertEqual
(
eval_metric_ops
[
'accuracy'
][
0
].
dtype
,
tf
.
float32
)
self
.
assertEqual
(
eval_metric_ops
[
'accuracy'
][
1
].
dtype
,
tf
.
float32
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_mnist_model_fn_train_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
TRAIN
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_mnist_model_fn_train_mode_multi_gpu
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
TRAIN
,
multi_gpu
=
True
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_mnist_model_fn_eval_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
EVAL
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_mnist_model_fn_predict_mode
(
self
):
self
.
mnist_model_fn_helper
(
tf
.
estimator
.
ModeKeys
.
PREDICT
)
...
...
@@ -144,4 +136,5 @@ class Benchmarks(tf.test.Benchmark):
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
ERROR
)
tf
.
disable_v2_behavior
()
tf
.
test
.
main
()
official/r1/resnet/cifar10_test.py
View file @
c2666cea
...
...
@@ -24,7 +24,6 @@ import numpy as np
import
tensorflow
as
tf
from
official.r1.resnet
import
cifar10_main
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
logging
.
set_verbosity
(
logging
.
ERROR
)
...
...
@@ -44,8 +43,7 @@ class BaseTest(tf.test.TestCase):
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
BaseTest
,
cls
).
setUpClass
()
if
keras_utils
.
is_v2_0
:
tf
.
compat
.
v1
.
disable_eager_execution
()
tf
.
compat
.
v1
.
disable_eager_execution
()
cifar10_main
.
define_cifar_flags
()
def
setUp
(
self
):
...
...
official/r1/resnet/imagenet_test.py
View file @
c2666cea
...
...
@@ -23,7 +23,6 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from
absl
import
logging
from
official.r1.resnet
import
imagenet_main
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
logging
.
set_verbosity
(
logging
.
ERROR
)
...
...
@@ -43,8 +42,7 @@ class BaseTest(tf.test.TestCase):
def
setUp
(
self
):
super
(
BaseTest
,
self
).
setUp
()
if
keras_utils
.
is_v2_0
:
tf
.
compat
.
v1
.
disable_eager_execution
()
tf
.
compat
.
v1
.
disable_eager_execution
()
self
.
_num_validation_images
=
imagenet_main
.
NUM_IMAGES
[
'validation'
]
imagenet_main
.
NUM_IMAGES
[
'validation'
]
=
4
...
...
official/r1/utils/data/file_io_test.py
View file @
c2666cea
...
...
@@ -28,7 +28,6 @@ import tensorflow as tf
# pylint: enable=wrong-import-order
from
official.r1.utils.data
import
file_io
from
official.utils.misc
import
keras_utils
_RAW_ROW
=
"raw_row"
...
...
@@ -108,8 +107,7 @@ class BaseTest(tf.test.TestCase):
def
setUp
(
self
):
super
(
BaseTest
,
self
).
setUp
()
if
keras_utils
.
is_v2_0
:
tf
.
compat
.
v1
.
disable_eager_execution
()
tf
.
compat
.
v1
.
disable_eager_execution
()
def
_test_sharding
(
self
,
row_count
,
cpu_count
,
expected
):
df
=
pd
.
DataFrame
({
_DUMMY_COL
:
list
(
range
(
row_count
))})
...
...
official/r1/wide_deep/census_dataset.py
View file @
c2666cea
...
...
@@ -26,7 +26,7 @@ from absl import app as absl_app
from
absl
import
flags
from
six.moves
import
urllib
from
six.moves
import
zip
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
# pylint: enable=wrong-import-order
from
official.utils.flags
import
core
as
flags_core
...
...
official/r1/wide_deep/census_main.py
View file @
c2666cea
...
...
@@ -18,7 +18,7 @@ import os
from
absl
import
app
as
absl_app
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
from
official.r1.utils.logs
import
logger
from
official.r1.wide_deep
import
census_dataset
from
official.r1.wide_deep
import
wide_deep_run_loop
...
...
official/r1/wide_deep/census_test.py
View file @
c2666cea
...
...
@@ -18,15 +18,13 @@ from __future__ import division
from
__future__
import
print_function
import
os
import
unittest
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
absl
import
logging
import
tensorflow.compat.v1
as
tf
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.r1.wide_deep
import
census_dataset
from
official.r1.wide_deep
import
census_main
from
official.utils.testing
import
integration
logging
.
set_verbosity
(
logging
.
ERROR
)
...
...
@@ -73,7 +71,6 @@ class BaseTest(tf.test.TestCase):
os
.
path
.
join
(
self
.
temp_dir
,
fname
),
'w'
)
as
test_csv
:
test_csv
.
write
(
test_csv_contents
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_input_fn
(
self
):
dataset
=
census_dataset
.
input_fn
(
self
.
input_csv
,
1
,
False
,
1
)
features
,
labels
=
dataset
.
make_one_shot_iterator
().
get_next
()
...
...
@@ -127,11 +124,9 @@ class BaseTest(tf.test.TestCase):
initial_results
[
'auc_precision_recall'
])
self
.
assertGreater
(
final_results
[
'accuracy'
],
initial_results
[
'accuracy'
])
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_wide_deep_estimator_training
(
self
):
self
.
build_and_test_estimator
(
'wide_deep'
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_end_to_end_wide
(
self
):
integration
.
run_synthetic
(
main
=
census_main
.
main
,
tmp_root
=
self
.
get_temp_dir
(),
...
...
@@ -142,7 +137,6 @@ class BaseTest(tf.test.TestCase):
],
synth
=
False
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_end_to_end_deep
(
self
):
integration
.
run_synthetic
(
main
=
census_main
.
main
,
tmp_root
=
self
.
get_temp_dir
(),
...
...
@@ -153,7 +147,6 @@ class BaseTest(tf.test.TestCase):
],
synth
=
False
)
@
unittest
.
skipIf
(
keras_utils
.
is_v2_0
(),
'TF 1.0 only test.'
)
def
test_end_to_end_wide_deep
(
self
):
integration
.
run_synthetic
(
main
=
census_main
.
main
,
tmp_root
=
self
.
get_temp_dir
(),
...
...
@@ -166,4 +159,5 @@ class BaseTest(tf.test.TestCase):
if
__name__
==
'__main__'
:
tf
.
disable_eager_execution
()
tf
.
test
.
main
()
official/r1/wide_deep/movielens_dataset.py
View file @
c2666cea
...
...
@@ -25,7 +25,7 @@ import os
from
absl
import
app
as
absl_app
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
.compat.v1
as
tf
# pylint: enable=wrong-import-order
from
official.recommendation
import
movielens
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment