Commit 78c43ef1 authored by Gunho Park's avatar Gunho Park
Browse files

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
......@@ -18,8 +18,8 @@ from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking import task
from official.recommendation.ranking.data import data_pipeline
class TaskTest(parameterized.TestCase, tf.test.TestCase):
......@@ -34,6 +34,8 @@ class TaskTest(parameterized.TestCase, tf.test.TestCase):
params.task.train_data.global_batch_size = 16
params.task.validation_data.global_batch_size = 16
params.task.model.vocab_sizes = [40, 12, 11, 13, 2, 5]
params.task.model.embedding_dim = 8
params.task.model.bottom_mlp = [64, 32, 8]
params.task.use_synthetic_data = True
params.task.model.num_dense_features = 5
......
......@@ -20,15 +20,14 @@ from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.common import distribute_utils
from official.core import base_trainer
from official.core import train_lib
from official.core import train_utils
from official.recommendation.ranking import common
from official.recommendation.ranking.task import RankingTask
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
......@@ -86,7 +85,7 @@ def main(_) -> None:
enable_tensorboard = params.trainer.callbacks.enable_tensorboard
strategy = distribution_utils.get_distribution_strategy(
strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
......@@ -95,6 +94,21 @@ def main(_) -> None:
with strategy.scope():
model = task.build_model()
def get_dataset_fn(params):
return lambda input_context: task.build_inputs(params, input_context)
train_dataset = None
if 'train' in mode:
train_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.train_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
validation_dataset = None
if 'eval' in mode:
validation_dataset = strategy.distribute_datasets_from_function(
get_dataset_fn(params.task.validation_data),
options=tf.distribute.InputOptions(experimental_fetch_to_device=False))
if params.trainer.use_orbit:
with strategy.scope():
checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
......@@ -106,6 +120,8 @@ def main(_) -> None:
optimizer=model.optimizer,
train='train' in mode,
evaluate='eval' in mode,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
checkpoint_exporter=checkpoint_exporter)
train_lib.run_experiment(
......@@ -117,16 +133,6 @@ def main(_) -> None:
trainer=trainer)
else: # Compile/fit
train_dataset = None
if 'train' in mode:
train_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.train_data)
eval_dataset = None
if 'eval' in mode:
eval_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.validation_data)
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
......@@ -169,7 +175,7 @@ def main(_) -> None:
initial_epoch=initial_epoch,
epochs=num_epochs,
steps_per_epoch=params.trainer.validation_interval,
validation_data=eval_dataset,
validation_data=validation_dataset,
validation_steps=eval_steps,
callbacks=callbacks,
)
......@@ -177,7 +183,7 @@ def main(_) -> None:
logging.info('Train history: %s', history.history)
elif mode == 'eval':
logging.info('Evaluation started')
validation_output = model.evaluate(eval_dataset, steps=eval_steps)
validation_output = model.evaluate(validation_dataset, steps=eval_steps)
logging.info('Evaluation output: %s', validation_output)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
......
......@@ -40,6 +40,8 @@ def _get_params_override(vocab_sizes,
'task': {
'model': {
'vocab_sizes': vocab_sizes,
'embedding_dim': [8] * len(vocab_sizes),
'bottom_mlp': [64, 32, 8],
'interaction': interaction,
},
'train_data': {
......
six
google-api-python-client>=1.6.7
google-cloud-bigquery>=0.31.0
kaggle>=1.3.9
numpy>=1.15.4
oauth2client
......
......@@ -75,6 +75,7 @@ ResNet-RS-350 | 320x320 | 164.3 | 84.2 | 96.9 | [config](https://github.c
backbone | resolution | epochs | FLOPs (B) | params (M) | box AP | download
------------ | :--------: | -----: | --------: | ---------: | -----: | -------:
MobileNetv2 | 256x256 | 600 | - | 2.27 | 23.5 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/retinanet/coco_mobilenetv2_tpu.yaml) |
Mobile SpineNet-49 | 384x384 | 600 | 1.0 | 2.32 | 28.1 | [config](https://github.com/tensorflow/models/blob/master/official/vision/beta/configs/experiments/retinanet/coco_spinenet49_mobile_tpu.yaml) |
......
This directory contains the new design of TF model garden vision framework.
Stay tuned.
......@@ -80,6 +80,11 @@ class SpineNetMobile(hyperparams.Config):
expand_ratio: int = 6
min_level: int = 3
max_level: int = 7
# If use_keras_upsampling_2d is True, model uses UpSampling2D keras layer
# instead of optimized custom TF op. It makes model be more keras style. We
# set this flag to True when we apply QAT from model optimization toolkit
# that requires the model should use keras layers.
use_keras_upsampling_2d: bool = False
@dataclasses.dataclass
......
# --experiment_type=retinanet_mobile_coco
# COCO AP 23.5%
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
task:
losses:
l2_weight_decay: 3.0e-05
model:
anchor:
anchor_size: 3
aspect_ratios: [0.5, 1.0, 2.0]
num_scales: 3
backbone:
mobilenet:
model_id: 'MobileNetV2'
filter_size_scale: 1.0
type: 'mobilenet'
decoder:
type: 'fpn'
fpn:
num_filters: 128
use_separable_conv: true
head:
num_convs: 4
num_filters: 128
use_separable_conv: true
input_size: [256, 256, 3]
max_level: 7
min_level: 3
norm_activation:
activation: 'relu6'
norm_epsilon: 0.001
norm_momentum: 0.99
use_sync_bn: true
train_data:
dtype: 'bfloat16'
global_batch_size: 256
is_training: true
parser:
aug_rand_hflip: true
aug_scale_max: 2.0
aug_scale_min: 0.5
validation_data:
dtype: 'bfloat16'
global_batch_size: 8
is_training: false
trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
warmup:
linear:
warmup_learning_rate: 0.0067
warmup_steps: 2000
steps_per_loop: 462
train_steps: 277200
validation_interval: 462
validation_steps: 625
# --experiment_type=retinanet_spinenet_mobile_coco
# --experiment_type=retinanet_mobile_coco
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -47,7 +47,7 @@ trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [265650, 272580]
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
warmup:
......
# --experiment_type=retinanet_spinenet_mobile_coco
# --experiment_type=retinanet_mobile_coco
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -47,7 +47,7 @@ trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [265650, 272580]
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
warmup:
......
# --experiment_type=retinanet_spinenet_mobile_coco
# --experiment_type=retinanet_mobile_coco
runtime:
distribution_strategy: 'tpu'
mixed_precision_dtype: 'bfloat16'
......@@ -47,7 +47,7 @@ trainer:
optimizer_config:
learning_rate:
stepwise:
boundaries: [265650, 272580]
boundaries: [263340, 272580]
values: [0.32, 0.032, 0.0032]
type: 'stepwise'
warmup:
......
......@@ -78,6 +78,7 @@ class DataConfig(cfg.DataConfig):
parser: Parser = Parser()
shuffle_buffer_size: int = 10000
file_type: str = 'tfrecord'
drop_remainder: bool = True
@dataclasses.dataclass
......@@ -215,7 +216,8 @@ class Losses(hyperparams.Config):
class MaskRCNNTask(cfg.TaskConfig):
model: MaskRCNN = MaskRCNN()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
validation_data: DataConfig = DataConfig(is_training=False,
drop_remainder=False)
losses: Losses = Losses()
init_checkpoint: Optional[str] = None
init_checkpoint_modules: str = 'all' # all or backbone
......@@ -260,7 +262,8 @@ def fasterrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -324,7 +327,8 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=22500,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -401,7 +405,8 @@ def maskrcnn_spinenet_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 350,
validation_steps=coco_val_samples // eval_batch_size,
......@@ -486,7 +491,8 @@ def cascadercnn_spinenet_coco() -> cfg.ExperimentConfig:
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
is_training=False,
global_batch_size=eval_batch_size)),
global_batch_size=eval_batch_size,
drop_remainder=False)),
trainer=cfg.TrainerConfig(
train_steps=steps_per_epoch * 500,
validation_steps=coco_val_samples // eval_batch_size,
......
......@@ -130,6 +130,13 @@ class RetinaNet(hyperparams.Config):
norm_activation: common.NormActivation = common.NormActivation()
@dataclasses.dataclass
class ExportConfig(hyperparams.Config):
output_normalized_coordinates: bool = False
cast_num_detections_to_float: bool = False
cast_detection_classes_to_float: bool = False
@dataclasses.dataclass
class RetinaNetTask(cfg.TaskConfig):
model: RetinaNet = RetinaNet()
......@@ -140,6 +147,7 @@ class RetinaNetTask(cfg.TaskConfig):
init_checkpoint_modules: str = 'all' # all or backbone
annotation_file: Optional[str] = None
per_category_metrics: bool = False
export_config: ExportConfig = ExportConfig()
@exp_factory.register_config_factory('retinanet')
......@@ -318,9 +326,9 @@ def retinanet_spinenet_coco() -> cfg.ExperimentConfig:
return config
@exp_factory.register_config_factory('retinanet_spinenet_mobile_coco')
@exp_factory.register_config_factory('retinanet_mobile_coco')
def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
"""COCO object detection with RetinaNet using Mobile SpineNet backbone."""
"""COCO object detection with mobile RetinaNet."""
train_batch_size = 256
eval_batch_size = 8
steps_per_epoch = COCO_TRAIN_EXAMPLES // train_batch_size
......@@ -338,7 +346,8 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
max_level=7,
use_keras_upsampling_2d=False)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
head=RetinaNetHead(num_filters=48, use_separable_conv=True),
......@@ -398,8 +407,6 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
'task.model.min_level == task.model.backbone.spinenet_mobile.min_level',
'task.model.max_level == task.model.backbone.spinenet_mobile.max_level',
])
return config
......@@ -28,7 +28,7 @@ class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('retinanet_resnetfpn_coco',),
('retinanet_spinenet_coco',),
('retinanet_spinenet_mobile_coco',),
('retinanet_mobile_coco',),
)
def test_retinanet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
......
......@@ -3,17 +3,19 @@
# Processes the COCO few-shot benchmark into TFRecord files. Requires `wget`.
tmp_dir=$(mktemp -d -t coco-XXXXXXXXXX)
base_image_dir="/tmp/coco_images"
output_dir="/tmp/coco_few_shot"
while getopts "o:" o; do
while getopts ":i:o:" o; do
case "${o}" in
o) output_dir=${OPTARG} ;;
*) echo "Usage: ${0} [-o <output_dir>]" 1>&2; exit 1 ;;
i) base_image_dir=${OPTARG} ;;
*) echo "Usage: ${0} [-i <base_image_dir>] [-o <output_dir>]" 1>&2; exit 1 ;;
esac
done
cocosplit_url="dl.yf.io/fs-det/datasets/cocosplit"
wget --recursive --no-parent -q --show-progress --progress=bar:force:noscroll \
-P "${tmp_dir}" -A "5k.json,*10shot*.json,*30shot*.json" \
-P "${tmp_dir}" -A "trainvalno5k.json,5k.json,*10shot*.json,*30shot*.json" \
"http://${cocosplit_url}/"
mv "${tmp_dir}/${cocosplit_url}/"* "${tmp_dir}"
rm -rf "${tmp_dir}/${cocosplit_url}/"
......@@ -25,8 +27,8 @@ for seed in {0..9}; do
for shots in 10 30; do
python create_coco_tf_record.py \
--logtostderr \
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
--image_dir="${base_image_dir}/train2014" \
--image_dir="${base_image_dir}/val2014" \
--image_info_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
--object_annotations_file="${tmp_dir}/${shots}shot_seed${seed}.json" \
--caption_annotations_file="" \
......@@ -37,12 +39,32 @@ done
python create_coco_tf_record.py \
--logtostderr \
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/train2014 \
--image_dir=/namespace/vale-project/datasets/mscoco_raw/images/val2014 \
--image_dir="${base_image_dir}/train2014" \
--image_dir="${base_image_dir}/val2014" \
--image_info_file="${tmp_dir}/datasplit/5k.json" \
--object_annotations_file="${tmp_dir}/datasplit/5k.json" \
--caption_annotations_file="" \
--output_file_prefix="${output_dir}/5k" \
--num_shards=10
python create_coco_tf_record.py \
--logtostderr \
--image_dir="${base_image_dir}/train2014" \
--image_dir="${base_image_dir}/val2014" \
--image_info_file="${tmp_dir}/datasplit/trainvalno5k_base.json" \
--object_annotations_file="${tmp_dir}/datasplit/trainvalno5k_base.json" \
--caption_annotations_file="" \
--output_file_prefix="${output_dir}/trainvalno5k_base" \
--num_shards=200
python create_coco_tf_record.py \
--logtostderr \
--image_dir="${base_image_dir}/train2014" \
--image_dir="${base_image_dir}/val2014" \
--image_info_file="${tmp_dir}/datasplit/5k_base.json" \
--object_annotations_file="${tmp_dir}/datasplit/5k_base.json" \
--caption_annotations_file="" \
--output_file_prefix="${output_dir}/5k_base" \
--num_shards=10
rm -rf "${tmp_dir}"
......@@ -76,10 +76,30 @@ for _seed, _shots in itertools.product(SEEDS, SHOTS):
_shots,
_category))
# Base class IDs, as defined in
# https://github.com/ucbdrive/few-shot-object-detection/blob/master/fsdet/evaluation/coco_evaluation.py#L60-L65
BASE_CLASS_IDS = [8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 65, 70, 73, 74, 75,
76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
def main(unused_argv):
workdir = FLAGS.workdir
# Filter novel class annotations from the training and validation sets.
for name in ('trainvalno5k', '5k'):
file_path = os.path.join(workdir, 'datasplit', '{}.json'.format(name))
with tf.io.gfile.GFile(file_path, 'r') as f:
json_dict = json.load(f)
json_dict['annotations'] = [a for a in json_dict['annotations']
if a['category_id'] in BASE_CLASS_IDS]
output_path = os.path.join(
workdir, 'datasplit', '{}_base.json'.format(name))
with tf.io.gfile.GFile(output_path, 'w') as f:
json.dump(json_dict, f)
for seed, shots in itertools.product(SEEDS, SHOTS):
# Retrieve all examples for a given seed and shots setting.
file_paths = [os.path.join(workdir, suffix)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFDS factory functions."""
from official.vision.beta.dataloaders import decoder as base_decoder
from official.vision.beta.dataloaders import tfds_detection_decoders
from official.vision.beta.dataloaders import tfds_segmentation_decoders
from official.vision.beta.dataloaders import tfds_classification_decoders
def get_classification_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets classification decoder.
Args:
tfds_name: `str`, name of the tfds classification decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_classification_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(
f'TFDS Classification {tfds_name} is not supported')
return decoder
def get_detection_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets detection decoder.
Args:
tfds_name: `str`, name of the tfds detection decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_detection_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(f'TFDS Detection {tfds_name} is not supported')
return decoder
def get_segmentation_decoder(tfds_name: str) -> base_decoder.Decoder:
"""Gets segmentation decoder.
Args:
tfds_name: `str`, name of the tfds segmentation decoder.
Returns:
`base_decoder.Decoder` instance.
Raises:
ValueError if the tfds_name doesn't exist in the available decoders.
"""
if tfds_name in tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP:
decoder = tfds_segmentation_decoders.TFDS_ID_TO_DECODER_MAP[tfds_name]()
else:
raise ValueError(f'TFDS Segmentation {tfds_name} is not supported')
return decoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfds factory functions."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.dataloaders import decoder as base_decoder
from official.vision.beta.dataloaders import tfds_factory
class TFDSFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('imagenet2012'),
('cifar10'),
('cifar100'),
)
def test_classification_decoder(self, tfds_name):
decoder = tfds_factory.get_classification_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('flowers'),
('coco'),
)
def test_doesnt_exit_classification_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_classification_decoder(tfds_name)
@parameterized.parameters(
('coco'),
('coco/2014'),
('coco/2017'),
)
def test_detection_decoder(self, tfds_name):
decoder = tfds_factory.get_detection_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('pascal'),
('cityscapes'),
)
def test_doesnt_exit_detection_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_detection_decoder(tfds_name)
@parameterized.parameters(
('cityscapes'),
('cityscapes/semantic_segmentation'),
('cityscapes/semantic_segmentation_extra'),
)
def test_segmentation_decoder(self, tfds_name):
decoder = tfds_factory.get_segmentation_decoder(tfds_name)
self.assertIsInstance(decoder, base_decoder.Decoder)
@parameterized.parameters(
('coco'),
('imagenet'),
)
def test_doesnt_exit_segmentation_decoder(self, tfds_name):
with self.assertRaises(ValueError):
_ = tfds_factory.get_segmentation_decoder(tfds_name)
if __name__ == '__main__':
tf.test.main()
......@@ -143,3 +143,24 @@ def create_classification_example(
int64_list=tf.train.Int64List(value=labels))),
})).SerializeToString()
return serialized_example
def create_3d_image_test_example(image_height: int, image_width: int,
image_volume: int,
image_channel: int) -> tf.train.Example:
"""Creates 3D image and label."""
images = np.random.rand(image_height, image_width, image_volume,
image_channel)
images = images.astype(np.float32)
labels = np.random.randint(
low=2, size=(image_height, image_width, image_volume, image_channel))
labels = labels.astype(np.float32)
feature = {
IMAGE_KEY: (tf.train.Feature(
bytes_list=tf.train.BytesList(value=[images.tobytes()]))),
CLASSIFICATION_LABEL_KEY: (tf.train.Feature(
bytes_list=tf.train.BytesList(value=[labels.tobytes()])))
}
return tf.train.Example(features=tf.train.Features(feature=feature))
......@@ -592,8 +592,9 @@ class MobileNet(tf.keras.Model):
x, endpoints, next_endpoint_level = self._mobilenet_base(inputs=inputs)
endpoints[str(next_endpoint_level)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
# Don't include the final layer in `self._output_specs` to support decoders.
endpoints[str(next_endpoint_level)] = x
super(MobileNet, self).__init__(
inputs=inputs, outputs=endpoints, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment