Commit afd5579f authored by Kaushik Shivakumar

Merge remote-tracking branch 'upstream/master' into context_tf2

parents dcd96e02 567bd18d
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Orbit package definition."""
from orbit import utils
from orbit.controller import Controller
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,14 +15,8 @@
# ==============================================================================
"""A light weight utilities to train TF2 models."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
import time
from typing import Callable, Optional, Text, Union
from absl import logging
from orbit import runner
from orbit import utils
@@ -43,7 +38,7 @@ def _validate_interval(interval: Optional[int], steps_per_loop: Optional[int],
interval_name, interval, steps_per_loop))
-class Controller(object):
+class Controller:
"""Class that facilitates training and evaluation of models."""
def __init__(
@@ -396,7 +391,7 @@ class Controller(object):
return False
-class StepTimer(object):
+class StepTimer:
"""Utility class for measuring steps/second."""
def __init__(self, step):
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,10 +15,6 @@
# ==============================================================================
"""Tests for orbit.controller."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
import os
from absl import logging
from absl.testing import parameterized
@@ -203,7 +200,7 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
class ControllerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
-super(ControllerTest, self).setUp()
+super().setUp()
self.model_dir = self.get_temp_dir()
def test_no_checkpoint(self):
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,19 +15,12 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
import abc
from typing import Dict, Optional, Text
-import six
import tensorflow as tf
-@six.add_metaclass(abc.ABCMeta)
-class AbstractTrainer(tf.Module):
+class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for training."""
@abc.abstractmethod
@@ -56,8 +50,7 @@ class AbstractTrainer(tf.Module):
pass
-@six.add_metaclass(abc.ABCMeta)
-class AbstractEvaluator(tf.Module):
+class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
"""An abstract class defining the APIs required for evaluation."""
@abc.abstractmethod
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,21 +15,14 @@
# ==============================================================================
"""An abstraction that users can easily handle their custom training loops."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
import abc
from typing import Any, Dict, Optional, Text
from orbit import runner
from orbit import utils
-import six
import tensorflow as tf
-@six.add_metaclass(abc.ABCMeta)
-class StandardTrainer(runner.AbstractTrainer):
+class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""
def __init__(self,
@@ -145,8 +139,7 @@ class StandardTrainer(runner.AbstractTrainer):
self._train_iter = None
-@six.add_metaclass(abc.ABCMeta)
-class StandardEvaluator(runner.AbstractEvaluator):
+class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def __init__(self, eval_dataset, use_tf_function=True):
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
+# Lint as: python3
# Copyright 2020 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,18 +15,12 @@
# ==============================================================================
"""Some layered modules/functions to help users writing custom training loop."""
-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
import abc
import contextlib
import functools
import inspect
import numpy as np
-import six
import tensorflow as tf
@@ -132,10 +127,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
# names, pass `ctx` as the value of `input_context` when calling
# `dataset_or_fn`. Otherwise `ctx` will not be used when calling
# `dataset_or_fn`.
-if six.PY3:
-  argspec = inspect.getfullargspec(dataset_or_fn)
-else:
-  argspec = inspect.getargspec(dataset_or_fn)  # pylint: disable=deprecated-method
+argspec = inspect.getfullargspec(dataset_or_fn)
args_names = argspec.args
if "input_context" in args_names:
@@ -146,7 +138,7 @@ def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
return strategy.experimental_distribute_datasets_from_function(dataset_fn)
-class SummaryManager(object):
+class SummaryManager:
"""A class manages writing summaries."""
def __init__(self, summary_dir, summary_fn, global_step=None):
@@ -201,8 +193,7 @@ class SummaryManager(object):
self._summary_fn(name, tensor, step=self._global_step)
-@six.add_metaclass(abc.ABCMeta)
-class Trigger(object):
+class Trigger(metaclass=abc.ABCMeta):
"""An abstract class representing a "trigger" for some event."""
@abc.abstractmethod
@@ -263,7 +254,7 @@ class IntervalTrigger(Trigger):
self._last_trigger_value = 0
-class EpochHelper(object):
+class EpochHelper:
"""A Helper class to handle epochs in Customized Training Loop."""
def __init__(self, epoch_steps, global_step):
...
-## Attention-based Extraction of Structured Information from Street View Imagery
+# Attention-based Extraction of Structured Information from Street View Imagery
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/attention-based-extraction-of-structured/optical-character-recognition-on-fsns-test)](https://paperswithcode.com/sota/optical-character-recognition-on-fsns-test?p=attention-based-extraction-of-structured)
[![Paper](http://img.shields.io/badge/paper-arXiv.1704.03549-B3181B.svg)](https://arxiv.org/abs/1704.03549)
@@ -7,14 +7,20 @@
*A TensorFlow model for real-world image text extraction problems.*
This folder contains the code needed to train a new Attention OCR model on the
-[FSNS dataset][FSNS] dataset to transcribe street names in France. You can
-also use it to train it on your own data.
+[FSNS dataset][FSNS] to transcribe street names in France. You can also train it on your own data.
More details can be found in our paper:
["Attention-based Extraction of Structured Information from Street View
Imagery"](https://arxiv.org/abs/1704.03549)
+## Description
+* The paper presents a model based on ConvNets, RNNs and a novel attention mechanism.
+It achieves **84.2%** on FSNS, beating the previous benchmark (**72.46%**), and also
+studies the speed/accuracy trade-off that results from using CNN feature extractors
+of different depths.
## Contacts
Authors
@@ -22,7 +28,18 @@ Authors
* Zbigniew Wojna (zbigniewwojna@gmail.com)
* Alexander Gorban (gorban@google.com)
-Maintainer: Xavier Gibert [@xavigibert](https://github.com/xavigibert)
+Maintainer
+* Xavier Gibert ([@xavigibert](https://github.com/xavigibert))
+## Table of Contents
+* [Requirements](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#requirements)
+* [Dataset](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#dataset)
+* [How to use this code](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-this-code)
+* [Using your own image data](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#using-your-own-image-data)
+* [How to use a pre-trained model](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#how-to-use-a-pre-trained-model)
+* [Disclaimer](https://github.com/tensorflow/models/blob/master/research/attention_ocr/README.md#disclaimer)
## Requirements
@@ -49,6 +66,42 @@ cd ..
[TF]: https://www.tensorflow.org/install/
[FSNS]: https://github.com/tensorflow/models/tree/master/research/street
+## Dataset
+The French Street Name Signs (FSNS) dataset is split into subsets,
+each of which is composed of multiple files. Note that these datasets
+are very large. The approximate sizes are:
+* Train: 512 files of 300MB each.
+* Validation: 64 files of 40MB each.
+* Test: 64 files of 50MB each.
+* The dataset download also includes a `testdata` directory containing
+some small datasets that are sufficient to check that models can
+actually learn something.
+* Total: around 158GB
+The download paths are in the following list:
+```
+https://download.tensorflow.org/data/fsns-20160927/charset_size=134.txt
+https://download.tensorflow.org/data/fsns-20160927/test/test-00000-of-00064
+...
+https://download.tensorflow.org/data/fsns-20160927/test/test-00063-of-00064
+https://download.tensorflow.org/data/fsns-20160927/testdata/arial-32-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/mnist-sample-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/testdata/numbers-16-00000-of-00001
+https://download.tensorflow.org/data/fsns-20160927/train/train-00000-of-00512
+...
+https://download.tensorflow.org/data/fsns-20160927/train/train-00511-of-00512
+https://download.tensorflow.org/data/fsns-20160927/validation/validation-00000-of-00064
+...
+https://download.tensorflow.org/data/fsns-20160927/validation/validation-00063-of-00064
+```
+All URLs are stored in the [research/street](https://github.com/tensorflow/models/tree/master/research/street)
+repository in the text file `python/fsns_urls.txt`.
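Since every shard name follows the fixed `name-%05d-of-%05d` pattern listed above, the whole dataset can be fetched with a short script. The snippet below is a minimal sketch, not part of the original README: the target directory layout and the choice of `urllib` are illustrative assumptions, and running it downloads roughly 158GB.

```python
import os
import urllib.request

BASE = 'https://download.tensorflow.org/data/fsns-20160927'
# (split name, number of shards) as listed in the URLs above.
SPLITS = [('train', 512), ('validation', 64), ('test', 64)]

os.makedirs('fsns', exist_ok=True)
urllib.request.urlretrieve('%s/charset_size=134.txt' % BASE,
                           'fsns/charset_size=134.txt')
for split, num_shards in SPLITS:
    os.makedirs(os.path.join('fsns', split), exist_ok=True)
    for i in range(num_shards):
        shard = '%s-%05d-of-%05d' % (split, i, num_shards)
        urllib.request.urlretrieve('%s/%s/%s' % (BASE, split, shard),
                                   os.path.join('fsns', split, shard))
```

Alternatively, the `python/fsns_urls.txt` file mentioned above can be fed to any bulk downloader.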
## How to use this code
To run all unit tests:
@@ -80,7 +133,7 @@ tar xf attention_ocr_2017_08_09.tar.gz
python train.py --checkpoint=model.ckpt-399731
```
-## How to use your own image data to train the model
+## Using your own image data
You need to define a new dataset. There are two options:
...
@@ -56,14 +56,14 @@ def augment_image(image):
Returns:
Distorted Tensor image of the same shape.
"""
-with tf.variable_scope('AugmentImage'):
+with tf.compat.v1.variable_scope('AugmentImage'):
height = image.get_shape().dims[0].value
width = image.get_shape().dims[1].value
# Random crop cut from the street sign image, resized to the same size.
# Assures that the crop is covers at least 0.8 area of the input image.
bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
-tf.shape(image),
+image_size=tf.shape(input=image),
bounding_boxes=tf.zeros([0, 0, 4]),
min_object_covered=0.8,
aspect_ratio_range=[0.8, 1.2],
@@ -74,7 +74,7 @@ def augment_image(image):
# Randomly chooses one of the 4 interpolation methods
distorted_image = inception_preprocessing.apply_with_random_selector(
distorted_image,
-lambda x, method: tf.image.resize_images(x, [height, width], method),
+lambda x, method: tf.image.resize(x, [height, width], method),
num_cases=4)
distorted_image.set_shape([height, width, 3])
@@ -99,9 +99,10 @@ def central_crop(image, crop_size):
Returns:
A tensor of shape [crop_height, crop_width, channels].
"""
-with tf.variable_scope('CentralCrop'):
+with tf.compat.v1.variable_scope('CentralCrop'):
target_width, target_height = crop_size
-image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
+image_height, image_width = tf.shape(
+    input=image)[0], tf.shape(input=image)[1]
assert_op1 = tf.Assert(
tf.greater_equal(image_height, target_height),
['image_height < target_height', image_height, target_height])
@@ -129,7 +130,7 @@ def preprocess_image(image, augment=False, central_crop_size=None,
A float32 tensor of shape [H x W x 3] with RGB values in the required
range.
"""
-with tf.variable_scope('PreprocessImage'):
+with tf.compat.v1.variable_scope('PreprocessImage'):
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
if augment or central_crop_size:
if num_towers == 1:
@@ -182,7 +183,7 @@ def get_data(dataset,
image_orig, augment, central_crop_size, num_towers=dataset.num_of_views)
label_one_hot = slim.one_hot_encoding(label, dataset.num_char_classes)
-images, images_orig, labels, labels_one_hot = (tf.train.shuffle_batch(
+images, images_orig, labels, labels_one_hot = (tf.compat.v1.train.shuffle_batch(
[image, image_orig, label, label_one_hot],
batch_size=batch_size,
num_threads=shuffle_config.num_batching_threads,
...
@@ -72,7 +72,7 @@ def read_charset(filename, null_character=u'\u2591'):
"""
pattern = re.compile(r'(\d+)\t(.+)')
charset = {}
-with tf.gfile.GFile(filename) as f:
+with tf.io.gfile.GFile(filename) as f:
for i, line in enumerate(f):
m = pattern.match(line)
if m is None:
@@ -96,9 +96,9 @@ class _NumOfViewsHandler(slim.tfexample_decoder.ItemHandler):
self._num_of_views = num_of_views
def tensors_to_item(self, keys_to_tensors):
-return tf.to_int64(
-self._num_of_views * keys_to_tensors[self._original_width_key] /
-keys_to_tensors[self._width_key])
+return tf.cast(
+self._num_of_views * keys_to_tensors[self._original_width_key] /
+keys_to_tensors[self._width_key], dtype=tf.int64)
def get_split(split_name, dataset_dir=None, config=None):
@@ -133,19 +133,19 @@ def get_split(split_name, dataset_dir=None, config=None):
zero = tf.zeros([1], dtype=tf.int64)
keys_to_features = {
'image/encoded':
-tf.FixedLenFeature((), tf.string, default_value=''),
+tf.io.FixedLenFeature((), tf.string, default_value=''),
'image/format':
-tf.FixedLenFeature((), tf.string, default_value='png'),
+tf.io.FixedLenFeature((), tf.string, default_value='png'),
'image/width':
-tf.FixedLenFeature([1], tf.int64, default_value=zero),
+tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/orig_width':
-tf.FixedLenFeature([1], tf.int64, default_value=zero),
+tf.io.FixedLenFeature([1], tf.int64, default_value=zero),
'image/class':
-tf.FixedLenFeature([config['max_sequence_length']], tf.int64),
+tf.io.FixedLenFeature([config['max_sequence_length']], tf.int64),
'image/unpadded_class':
-tf.VarLenFeature(tf.int64),
+tf.io.VarLenFeature(tf.int64),
'image/text':
-tf.FixedLenFeature([1], tf.string, default_value=''),
+tf.io.FixedLenFeature([1], tf.string, default_value=''),
}
items_to_handlers = {
'image':
@@ -171,7 +171,7 @@ def get_split(split_name, dataset_dir=None, config=None):
config['splits'][split_name]['pattern'])
return slim.dataset.Dataset(
data_sources=file_pattern,
-reader=tf.TFRecordReader,
+reader=tf.compat.v1.TFRecordReader,
decoder=decoder,
num_samples=config['splits'][split_name]['size'],
items_to_descriptions=config['items_to_descriptions'],
...
@@ -91,7 +91,7 @@ class FsnsTest(tf.test.TestCase):
image_tf, label_tf = provider.get(['image', 'label'])
with self.test_session() as sess:
-sess.run(tf.global_variables_initializer())
+sess.run(tf.compat.v1.global_variables_initializer())
with slim.queues.QueueRunners(sess):
image_np, label_np = sess.run([image_tf, label_tf])
...
@@ -10,7 +10,8 @@ KEEP_NUM_RECORDS = 5
print('Downloading %s ...' % URL)
urllib.request.urlretrieve(URL, DST_ORIG)
-print('Writing %d records from %s to %s ...' % (KEEP_NUM_RECORDS, DST_ORIG, DST))
+print('Writing %d records from %s to %s ...' %
+      (KEEP_NUM_RECORDS, DST_ORIG, DST))
with tf.io.TFRecordWriter(DST) as writer:
-for raw_record in itertools.islice(tf.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
+for raw_record in itertools.islice(tf.compat.v1.python_io.tf_record_iterator(DST_ORIG), KEEP_NUM_RECORDS):
writer.write(raw_record)
@@ -49,7 +49,7 @@ def load_images(file_pattern, batch_size, dataset_name):
for i in range(batch_size):
path = file_pattern % i
print("Reading %s" % path)
-pil_image = PIL.Image.open(tf.gfile.GFile(path, 'rb'))
+pil_image = PIL.Image.open(tf.io.gfile.GFile(path, 'rb'))
images_actual_data[i, ...] = np.asarray(pil_image)
return images_actual_data
@@ -58,12 +58,13 @@ def create_model(batch_size, dataset_name):
width, height = get_dataset_image_size(dataset_name)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(
num_char_classes=dataset.num_char_classes,
seq_length=dataset.max_sequence_length,
num_views=dataset.num_of_views,
null_code=dataset.null_code,
charset=dataset.charset)
-raw_images = tf.placeholder(tf.uint8, shape=[batch_size, height, width, 3])
+raw_images = tf.compat.v1.placeholder(
+    tf.uint8, shape=[batch_size, height, width, 3])
images = tf.map_fn(data_provider.preprocess_image, raw_images,
dtype=tf.float32)
endpoints = model.create_base(images, labels_one_hot=None)
@@ -76,9 +77,9 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
images_data = load_images(image_path_pattern, batch_size,
dataset_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=checkpoint)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
predictions = sess.run(endpoints.predicted_text,
feed_dict={images_placeholder: images_data})
return [pr_bytes.decode('utf-8') for pr_bytes in predictions.tolist()]
@@ -87,10 +88,10 @@ def run(checkpoint, batch_size, dataset_name, image_path_pattern):
def main(_):
print("Predicted strings:")
predictions = run(FLAGS.checkpoint, FLAGS.batch_size, FLAGS.dataset_name,
FLAGS.image_path_pattern)
for line in predictions:
print(line)
if __name__ == '__main__':
-tf.app.run()
+tf.compat.v1.app.run()
@@ -14,12 +14,13 @@ class DemoInferenceTest(tf.test.TestCase):
super(DemoInferenceTest, self).setUp()
for suffix in ['.meta', '.index', '.data-00000-of-00001']:
filename = _CHECKPOINT + suffix
-self.assertTrue(tf.gfile.Exists(filename),
+self.assertTrue(tf.io.gfile.exists(filename),
msg='Missing checkpoint file %s. '
'Please download and extract it from %s' %
(filename, _CHECKPOINT_URL))
self._batch_size = 32
-tf.flags.FLAGS.dataset_dir = os.path.join(os.path.dirname(__file__), 'datasets/testdata/fsns')
+tf.flags.FLAGS.dataset_dir = os.path.join(
+    os.path.dirname(__file__), 'datasets/testdata/fsns')
def test_moving_variables_properly_loaded_from_a_checkpoint(self):
batch_size = 32
@@ -30,15 +31,15 @@ class DemoInferenceTest(tf.test.TestCase):
images_data = demo_inference.load_images(image_path_pattern, batch_size,
dataset_name)
tensor_name = 'AttentionOcr_v1/conv_tower_fn/INCE/InceptionV3/Conv2d_2a_3x3/BatchNorm/moving_mean'
-moving_mean_tf = tf.get_default_graph().get_tensor_by_name(
+moving_mean_tf = tf.compat.v1.get_default_graph().get_tensor_by_name(
tensor_name + ':0')
-reader = tf.train.NewCheckpointReader(_CHECKPOINT)
+reader = tf.compat.v1.train.NewCheckpointReader(_CHECKPOINT)
moving_mean_expected = reader.get_tensor(tensor_name)
session_creator = monitored_session.ChiefSessionCreator(
checkpoint_filename_with_path=_CHECKPOINT)
with monitored_session.MonitoredSession(
session_creator=session_creator) as sess:
moving_mean_np = sess.run(moving_mean_tf,
feed_dict={images_placeholder: images_data})
@@ -50,38 +51,38 @@ class DemoInferenceTest(tf.test.TestCase):
'fsns',
image_path_pattern)
self.assertEqual([
u'Boulevard de Lunel░░░░░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Port Maria░░░░░░░░░░░░░░░░░░░░',
'Avenue Charles Gounod░░░░░░░░░░░░░░░░',
'Rue de l‘Aurore░░░░░░░░░░░░░░░░░░░░░░',
'Rue de Beuzeville░░░░░░░░░░░░░░░░░░░░',
'Rue d‘Orbey░░░░░░░░░░░░░░░░░░░░░░░░░░',
'Rue Victor Schoulcher░░░░░░░░░░░░░░░░',
'Rue de la Gare░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Tulipes░░░░░░░░░░░░░░░░░░░░░░',
'Rue André Maginot░░░░░░░░░░░░░░░░░░░░',
'Route de Pringy░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Landelles░░░░░░░░░░░░░░░░░░░░',
'Rue des Ilettes░░░░░░░░░░░░░░░░░░░░░░',
'Avenue de Maurin░░░░░░░░░░░░░░░░░░░░░',
'Rue Théresa░░░░░░░░░░░░░░░░░░░░░░░░░░',  # GT='Rue Thérésa'
'Route de la Balme░░░░░░░░░░░░░░░░░░░░',
'Rue Hélène Roederer░░░░░░░░░░░░░░░░░░',
'Rue Emile Bernard░░░░░░░░░░░░░░░░░░░░',
'Place de la Mairie░░░░░░░░░░░░░░░░░░░',
'Rue des Perrots░░░░░░░░░░░░░░░░░░░░░░',
'Rue de la Libération░░░░░░░░░░░░░░░░░',
'Impasse du Capcir░░░░░░░░░░░░░░░░░░░░',
'Avenue de la Grand Mare░░░░░░░░░░░░░░',
'Rue Pierre Brossolette░░░░░░░░░░░░░░░',
'Rue de Provence░░░░░░░░░░░░░░░░░░░░░░',
'Rue du Docteur Mourre░░░░░░░░░░░░░░░░',
'Rue d‘Ortheuil░░░░░░░░░░░░░░░░░░░░░░░',
'Rue des Sarments░░░░░░░░░░░░░░░░░░░░░',
'Rue du Centre░░░░░░░░░░░░░░░░░░░░░░░░',
'Impasse Pierre Mourgues░░░░░░░░░░░░░░',
'Rue Marcel Dassault░░░░░░░░░░░░░░░░░░'
], predictions)
...
@@ -45,8 +45,8 @@ flags.DEFINE_integer('number_of_steps', None,
def main(_):
-if not tf.gfile.Exists(FLAGS.eval_log_dir):
-tf.gfile.MakeDirs(FLAGS.eval_log_dir)
+if not tf.io.gfile.exists(FLAGS.eval_log_dir):
+tf.io.gfile.makedirs(FLAGS.eval_log_dir)
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(dataset.num_char_classes,
@@ -62,7 +62,7 @@ def main(_):
eval_ops = model.create_summaries(
data, endpoints, dataset.charset, is_training=False)
slim.get_or_create_global_step()
-session_config = tf.ConfigProto(device_count={"GPU": 0})
+session_config = tf.compat.v1.ConfigProto(device_count={"GPU": 0})
slim.evaluation.evaluation_loop(
master=FLAGS.master,
checkpoint_dir=FLAGS.train_log_dir,
...
@@ -38,7 +38,7 @@ def apply_with_random_selector(x, func, num_cases):
The result of func(x, sel), where func receives the value of the
selector as a python integer, but sel is sampled dynamically.
"""
-sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
+sel = tf.random.uniform([], maxval=num_cases, dtype=tf.int32)
# Pass the real x only to one of the func calls.
return control_flow_ops.merge([
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
@@ -64,7 +64,7 @@ def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
Raises:
ValueError: if color_ordering not in [0, 3]
"""
-with tf.name_scope(scope, 'distort_color', [image]):
+with tf.compat.v1.name_scope(scope, 'distort_color', [image]):
if fast_mode:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
@@ -131,7 +131,7 @@ def distorted_bounding_box_crop(image,
Returns:
A tuple, a 3-D Tensor cropped_image and the distorted bbox
"""
-with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
+with tf.compat.v1.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].
@@ -143,7 +143,7 @@ def distorted_bounding_box_crop(image,
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
-tf.shape(image),
+image_size=tf.shape(input=image),
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
@@ -188,7 +188,7 @@ def preprocess_for_train(image,
Returns:
3-D float Tensor of distorted image used for training with range [-1, 1].
"""
-with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
+with tf.compat.v1.name_scope(scope, 'distort_image', [image, height, width, bbox]):
if bbox is None:
bbox = tf.constant(
[0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
@@ -198,7 +198,7 @@ def preprocess_for_train(image,
# the coordinates are ordered [ymin, xmin, ymax, xmax].
image_with_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), bbox)
-tf.summary.image('image_with_bounding_boxes', image_with_box)
+tf.compat.v1.summary.image('image_with_bounding_boxes', image_with_box)
distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
# Restore the shape since the dynamic slice based upon the bbox_size loses
@@ -206,8 +206,8 @@ def preprocess_for_train(image,
distorted_image.set_shape([None, None, 3])
image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distorted_bbox)
-tf.summary.image('images_with_distorted_bounding_box',
+tf.compat.v1.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box)
# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
@@ -218,11 +218,11 @@ def preprocess_for_train(image,
num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
-lambda x, method: tf.image.resize_images(x, [height, width], method=method),
+lambda x, method: tf.image.resize(x, [height, width], method=method),
num_cases=num_resize_cases)
-tf.summary.image('cropped_resized_image',
+tf.compat.v1.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0))
# Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image)
@@ -233,8 +233,8 @@ def preprocess_for_train(image,
lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=4)
-tf.summary.image('final_distorted_image',
+tf.compat.v1.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0))
distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image
@@ -265,7 +265,7 @@ def preprocess_for_eval(image,
Returns:
3-D float Tensor of prepared image.
"""
-with tf.name_scope(scope, 'eval_image', [image, height, width]):
+with tf.compat.v1.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# Crop the central region of the image with an area containing 87.5% of
@@ -276,8 +276,8 @@ def preprocess_for_eval(image,
if height and width:
# Resize the image to the specified height and width.
image = tf.expand_dims(image, 0)
-image = tf.image.resize_bilinear(
-image, [height, width], align_corners=False)
+image = tf.image.resize(
+image, [height, width], method=tf.image.ResizeMethod.BILINEAR)
image = tf.squeeze(image, [0])
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
...
@@ -34,20 +34,21 @@ def char_accuracy(predictions, targets, rej_char, streaming=False):
a update_ops for execution and value tensor whose value on evaluation
returns the total character accuracy.
"""
-with tf.variable_scope('CharAccuracy'):
+with tf.compat.v1.variable_scope('CharAccuracy'):
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
-targets = tf.to_int32(targets)
+targets = tf.cast(targets, dtype=tf.int32)
const_rej_char = tf.constant(rej_char, shape=targets.get_shape())
-weights = tf.to_float(tf.not_equal(targets, const_rej_char))
-correct_chars = tf.to_float(tf.equal(predictions, targets))
-accuracy_per_example = tf.div(
-tf.reduce_sum(tf.multiply(correct_chars, weights), 1),
-tf.reduce_sum(weights, 1))
+weights = tf.cast(tf.not_equal(targets, const_rej_char), dtype=tf.float32)
+correct_chars = tf.cast(tf.equal(predictions, targets), dtype=tf.float32)
+accuracy_per_example = tf.compat.v1.div(
+tf.reduce_sum(input_tensor=tf.multiply(
+correct_chars, weights), axis=1),
+tf.reduce_sum(input_tensor=weights, axis=1))
if streaming:
return tf.contrib.metrics.streaming_mean(accuracy_per_example)
else:
-return tf.reduce_mean(accuracy_per_example)
+return tf.reduce_mean(input_tensor=accuracy_per_example)
def sequence_accuracy(predictions, targets, rej_char, streaming=False):
@@ -66,25 +67,26 @@ def sequence_accuracy(predictions, targets, rej_char, streaming=False):
returns the total sequence accuracy.
"""
-with tf.variable_scope('SequenceAccuracy'):
+with tf.compat.v1.variable_scope('SequenceAccuracy'):
predictions.get_shape().assert_is_compatible_with(targets.get_shape())
-targets = tf.to_int32(targets)
+targets = tf.cast(targets, dtype=tf.int32)
const_rej_char = tf.constant(
rej_char, shape=targets.get_shape(), dtype=tf.int32)
include_mask = tf.not_equal(targets, const_rej_char)
-include_predictions = tf.to_int32(
-tf.where(include_mask, predictions,
-tf.zeros_like(predictions) + rej_char))
-correct_chars = tf.to_float(tf.equal(include_predictions, targets))
+include_predictions = tf.cast(
+tf.compat.v1.where(include_mask, predictions,
+tf.zeros_like(predictions) + rej_char), dtype=tf.int32)
+correct_chars = tf.cast(
+tf.equal(include_predictions, targets), dtype=tf.float32)
correct_chars_counts = tf.cast(
-tf.reduce_sum(correct_chars, reduction_indices=[1]), dtype=tf.int32)
+tf.reduce_sum(input_tensor=correct_chars, axis=[1]), dtype=tf.int32)
target_length = targets.get_shape().dims[1].value
target_chars_counts = tf.constant(
target_length, shape=correct_chars_counts.get_shape())
-accuracy_per_example = tf.to_float(
-tf.equal(correct_chars_counts, target_chars_counts))
+accuracy_per_example = tf.cast(
+tf.equal(correct_chars_counts, target_chars_counts), dtype=tf.float32)
if streaming:
return tf.contrib.metrics.streaming_mean(accuracy_per_example)
else:
-return tf.reduce_mean(accuracy_per_example)
+return tf.reduce_mean(input_tensor=accuracy_per_example)
@@ -38,8 +38,8 @@ class AccuracyTest(tf.test.TestCase):
A session object that should be used as a context manager.
"""
with self.cached_session() as sess:
-sess.run(tf.global_variables_initializer())
-sess.run(tf.local_variables_initializer())
+sess.run(tf.compat.v1.global_variables_initializer())
+sess.run(tf.compat.v1.local_variables_initializer())
yield sess
def _fake_labels(self):
@@ -55,7 +55,7 @@ class AccuracyTest(tf.test.TestCase):
return incorrect
def test_sequence_accuracy_identical_samples(self):
-labels_tf = tf.convert_to_tensor(self._fake_labels())
+labels_tf = tf.convert_to_tensor(value=self._fake_labels())
accuracy_tf = metrics.sequence_accuracy(labels_tf, labels_tf,
self.rej_char)
@@ -66,9 +66,9 @@ class AccuracyTest(tf.test.TestCase):
def test_sequence_accuracy_one_char_difference(self):
ground_truth_np = self._fake_labels()
-ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
+ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
prediction_tf = tf.convert_to_tensor(
-self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
+value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
accuracy_tf = metrics.sequence_accuracy(prediction_tf, ground_truth_tf,
self.rej_char)
@@ -80,9 +80,9 @@ class AccuracyTest(tf.test.TestCase):
def test_char_accuracy_one_char_difference_with_padding(self):
ground_truth_np = self._fake_labels()
-ground_truth_tf = tf.convert_to_tensor(ground_truth_np)
+ground_truth_tf = tf.convert_to_tensor(value=ground_truth_np)
prediction_tf = tf.convert_to_tensor(
-self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
+value=self._incorrect_copy(ground_truth_np, bad_indexes=((0, 0))))
accuracy_tf = metrics.char_accuracy(prediction_tf, ground_truth_tf,
self.rej_char)
...
...@@ -92,8 +92,8 @@ class CharsetMapper(object): ...@@ -92,8 +92,8 @@ class CharsetMapper(object):
Args: Args:
ids: a tensor with shape [batch_size, max_sequence_length] ids: a tensor with shape [batch_size, max_sequence_length]
""" """
return tf.reduce_join( return tf.strings.reduce_join(
self.table.lookup(tf.to_int64(ids)), reduction_indices=1) inputs=self.table.lookup(tf.cast(ids, dtype=tf.int64)), axis=1)
def get_softmax_loss_fn(label_smoothing): def get_softmax_loss_fn(label_smoothing):
...@@ -110,7 +110,7 @@ def get_softmax_loss_fn(label_smoothing): ...@@ -110,7 +110,7 @@ def get_softmax_loss_fn(label_smoothing):
def loss_fn(labels, logits): def loss_fn(labels, logits):
return (tf.nn.softmax_cross_entropy_with_logits( return (tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=labels)) logits=logits, labels=tf.stop_gradient(labels)))
else: else:
def loss_fn(labels, logits): def loss_fn(labels, logits):
...@@ -140,7 +140,7 @@ def get_tensor_dimensions(tensor): ...@@ -140,7 +140,7 @@ def get_tensor_dimensions(tensor):
raise ValueError( raise ValueError(
'Incompatible shape: len(tensor.get_shape().dims) != 4 (%d != 4)' % 'Incompatible shape: len(tensor.get_shape().dims) != 4 (%d != 4)' %
len(tensor.get_shape().dims)) len(tensor.get_shape().dims))
batch_size = tf.shape(tensor)[0] batch_size = tf.shape(input=tensor)[0]
height = tensor.get_shape().dims[1].value height = tensor.get_shape().dims[1].value
width = tensor.get_shape().dims[2].value width = tensor.get_shape().dims[2].value
num_features = tensor.get_shape().dims[3].value num_features = tensor.get_shape().dims[3].value
...@@ -161,8 +161,8 @@ def lookup_indexed_value(indices, row_vecs): ...@@ -161,8 +161,8 @@ def lookup_indexed_value(indices, row_vecs):
A tensor of shape (batch, ) formed by row_vecs[i, indices[i]]. A tensor of shape (batch, ) formed by row_vecs[i, indices[i]].
""" """
gather_indices = tf.stack((tf.range( gather_indices = tf.stack((tf.range(
tf.shape(row_vecs)[0], dtype=tf.int32), tf.cast(indices, tf.int32)), tf.shape(input=row_vecs)[0], dtype=tf.int32), tf.cast(indices, tf.int32)),
axis=1) axis=1)
return tf.gather_nd(row_vecs, gather_indices) return tf.gather_nd(row_vecs, gather_indices)
...@@ -181,7 +181,7 @@ def max_char_logprob_cumsum(char_log_prob): ...@@ -181,7 +181,7 @@ def max_char_logprob_cumsum(char_log_prob):
so the same function can be used regardless whether use_length_predictions so the same function can be used regardless whether use_length_predictions
is true or false. is true or false.
""" """
max_char_log_prob = tf.reduce_max(char_log_prob, reduction_indices=2) max_char_log_prob = tf.reduce_max(input_tensor=char_log_prob, axis=2)
# For an input array [a, b, c]) tf.cumsum returns [a, a + b, a + b + c] if # For an input array [a, b, c]) tf.cumsum returns [a, a + b, a + b + c] if
# exclusive set to False (default). # exclusive set to False (default).
return tf.cumsum(max_char_log_prob, axis=1, exclusive=False) return tf.cumsum(max_char_log_prob, axis=1, exclusive=False)
...@@ -203,7 +203,7 @@ def find_length_by_null(predicted_chars, null_code): ...@@ -203,7 +203,7 @@ def find_length_by_null(predicted_chars, null_code):
A [batch, ] tensor which stores the sequence length for each sample. A [batch, ] tensor which stores the sequence length for each sample.
""" """
return tf.reduce_sum( return tf.reduce_sum(
tf.cast(tf.not_equal(null_code, predicted_chars), tf.int32), axis=1) input_tensor=tf.cast(tf.not_equal(null_code, predicted_chars), tf.int32), axis=1)
def axis_pad(tensor, axis, before=0, after=0, constant_values=0.0): def axis_pad(tensor, axis, before=0, after=0, constant_values=0.0):
...@@ -248,7 +248,8 @@ def null_based_length_prediction(chars_log_prob, null_code): ...@@ -248,7 +248,8 @@ def null_based_length_prediction(chars_log_prob, null_code):
element #seq_length - is the probability of length=seq_length. element #seq_length - is the probability of length=seq_length.
predicted_length is a tensor with shape [batch]. predicted_length is a tensor with shape [batch].
""" """
predicted_chars = tf.to_int32(tf.argmax(chars_log_prob, axis=2)) predicted_chars = tf.cast(
tf.argmax(input=chars_log_prob, axis=2), dtype=tf.int32)
# We do right pad to support sequences with seq_length elements. # We do right pad to support sequences with seq_length elements.
text_log_prob = max_char_logprob_cumsum( text_log_prob = max_char_logprob_cumsum(
axis_pad(chars_log_prob, axis=1, after=1)) axis_pad(chars_log_prob, axis=1, after=1))
...@@ -334,9 +335,9 @@ class Model(object): ...@@ -334,9 +335,9 @@ class Model(object):
""" """
mparams = self._mparams['conv_tower_fn'] mparams = self._mparams['conv_tower_fn']
logging.debug('Using final_endpoint=%s', mparams.final_endpoint) logging.debug('Using final_endpoint=%s', mparams.final_endpoint)
with tf.variable_scope('conv_tower_fn/INCE'): with tf.compat.v1.variable_scope('conv_tower_fn/INCE'):
if reuse: if reuse:
tf.get_variable_scope().reuse_variables() tf.compat.v1.get_variable_scope().reuse_variables()
with slim.arg_scope(inception.inception_v3_arg_scope()): with slim.arg_scope(inception.inception_v3_arg_scope()):
with slim.arg_scope([slim.batch_norm, slim.dropout], with slim.arg_scope([slim.batch_norm, slim.dropout],
is_training=is_training): is_training=is_training):
...@@ -372,7 +373,7 @@ class Model(object): ...@@ -372,7 +373,7 @@ class Model(object):
def sequence_logit_fn(self, net, labels_one_hot): def sequence_logit_fn(self, net, labels_one_hot):
mparams = self._mparams['sequence_logit_fn'] mparams = self._mparams['sequence_logit_fn']
# TODO(gorban): remove /alias suffixes from the scopes. # TODO(gorban): remove /alias suffixes from the scopes.
with tf.variable_scope('sequence_logit_fn/SQLR'): with tf.compat.v1.variable_scope('sequence_logit_fn/SQLR'):
layer_class = sequence_layers.get_layer_class(mparams.use_attention, layer_class = sequence_layers.get_layer_class(mparams.use_attention,
mparams.use_autoregression) mparams.use_autoregression)
layer = layer_class(net, labels_one_hot, self._params, mparams) layer = layer_class(net, labels_one_hot, self._params, mparams)
@@ -392,7 +393,7 @@ class Model(object):
     ]
     xy_flat_shape = (batch_size, 1, height * width, num_features)
     nets_for_merge = []
-    with tf.variable_scope('max_pool_views', values=nets_list):
+    with tf.compat.v1.variable_scope('max_pool_views', values=nets_list):
       for net in nets_list:
         nets_for_merge.append(tf.reshape(net, xy_flat_shape))
       merged_net = tf.concat(nets_for_merge, 1)
@@ -413,10 +414,11 @@ class Model(object):
     Returns:
       A tensor of shape [batch_size, seq_length, features_size].
     """
-    with tf.variable_scope('pool_views_fn/STCK'):
+    with tf.compat.v1.variable_scope('pool_views_fn/STCK'):
       net = tf.concat(nets, 1)
-      batch_size = tf.shape(net)[0]
-      image_size = net.get_shape().dims[1].value * net.get_shape().dims[2].value
+      batch_size = tf.shape(input=net)[0]
+      image_size = net.get_shape().dims[1].value * \
+          net.get_shape().dims[2].value
       feature_size = net.get_shape().dims[3].value
       return tf.reshape(net, tf.stack([batch_size, image_size, feature_size]))
@@ -438,11 +440,13 @@ class Model(object):
       with shape [batch_size x seq_length].
     """
     log_prob = utils.logits_to_log_prob(chars_logit)
-    ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars')
+    ids = tf.cast(tf.argmax(input=log_prob, axis=2),
+                  name='predicted_chars', dtype=tf.int32)
     mask = tf.cast(
         slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool)
     all_scores = tf.nn.softmax(chars_logit)
-    selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores')
+    selected_scores = tf.boolean_mask(
+        tensor=all_scores, mask=mask, name='char_scores')
     scores = tf.reshape(
         selected_scores,
         shape=(-1, self._params.seq_length),
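Other rewrites in this hunk only add the keyword names the 2.x signatures use (tf.argmax(input=...), tf.boolean_mask(tensor=..., mask=...)); the computation is unchanged. A self-contained check with toy logits in place of chars_logit, and tf.one_hot instead of slim.one_hot_encoding:

import tensorflow as tf

logits = tf.constant([[[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]]])  # [batch=1, seq_length=2, classes=3]
ids = tf.cast(tf.argmax(input=logits, axis=2), dtype=tf.int32)
mask = tf.cast(tf.one_hot(ids, depth=3), tf.bool)  # marks the argmax per step
scores = tf.boolean_mask(tensor=tf.nn.softmax(logits), mask=mask)
print(ids.numpy())    # [[1 0]]
print(scores.shape)   # (2,) - one score per sequence position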
@@ -499,7 +503,7 @@ class Model(object):
     images = tf.subtract(images, 0.5)
     images = tf.multiply(images, 2.5)
-    with tf.variable_scope(scope, reuse=reuse):
+    with tf.compat.v1.variable_scope(scope, reuse=reuse):
       views = tf.split(
           value=images, num_or_size_splits=self._params.num_views, axis=2)
       logging.debug('Views=%d single view: %s', len(views), views[0])
@@ -566,7 +570,7 @@ class Model(object):
     # multiple losses including regularization losses.
     self.sequence_loss_fn(endpoints.chars_logit, data.labels)
     total_loss = slim.losses.get_total_loss()
-    tf.summary.scalar('TotalLoss', total_loss)
+    tf.compat.v1.summary.scalar('TotalLoss', total_loss)
     return total_loss
 
   def label_smoothing_regularization(self, chars_labels, weight=0.1):
@@ -605,7 +609,7 @@ class Model(object):
       A Tensor with shape [batch_size] - the log-perplexity for each sequence.
     """
     mparams = self._mparams['sequence_loss_fn']
-    with tf.variable_scope('sequence_loss_fn/SLF'):
+    with tf.compat.v1.variable_scope('sequence_loss_fn/SLF'):
       if mparams.label_smoothing > 0:
         smoothed_one_hot_labels = self.label_smoothing_regularization(
             chars_labels, mparams.label_smoothing)
@@ -625,7 +629,7 @@ class Model(object):
             shape=(batch_size, seq_length),
             dtype=tf.int64)
       known_char = tf.not_equal(chars_labels, reject_char)
-      weights = tf.to_float(known_char)
+      weights = tf.cast(known_char, dtype=tf.float32)
       logits_list = tf.unstack(chars_logits, axis=1)
       weights_list = tf.unstack(weights, axis=1)
@@ -635,7 +639,7 @@ class Model(object):
           weights_list,
           softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing),
           average_across_timesteps=mparams.average_across_timesteps)
-      tf.losses.add_loss(loss)
+      tf.compat.v1.losses.add_loss(loss)
       return loss
 
   def create_summaries(self, data, endpoints, charset, is_training):
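Both slim.losses.get_total_loss (in the create_loss hunk above) and tf.losses.add_loss sit on the v1 graph loss collection, which is why they survive only as tf.compat.v1 calls. A minimal, hypothetical sketch of that collection round-trip, with a constant standing in for the sequence loss:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  seq_loss = tf.constant(1.25)              # pretend sequence_loss output
  tf.compat.v1.losses.add_loss(seq_loss)    # registers it in the LOSSES collection
  total_loss = tf.compat.v1.losses.get_total_loss()  # folds in regularization losses

with tf.compat.v1.Session(graph=g) as sess:
  print(sess.run(total_loss))  # 1.25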
@@ -665,13 +669,14 @@ class Model(object):
     # tf.summary.text(sname('text/pr'), pr_text)
     # gt_text = charset_mapper.get_text(data.labels[:max_outputs,:])
     # tf.summary.text(sname('text/gt'), gt_text)
-    tf.summary.image(sname('image'), data.images, max_outputs=max_outputs)
+    tf.compat.v1.summary.image(
+        sname('image'), data.images, max_outputs=max_outputs)
     if is_training:
-      tf.summary.image(
+      tf.compat.v1.summary.image(
           sname('image/orig'), data.images_orig, max_outputs=max_outputs)
-      for var in tf.trainable_variables():
-        tf.summary.histogram(var.op.name, var)
+      for var in tf.compat.v1.trainable_variables():
+        tf.compat.v1.summary.histogram(var.op.name, var)
       return None
     else:
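These summary calls map onto tf.compat.v1.summary rather than the eager tf.summary writers because the 1.x API returns graph ops that are gathered by merge_all and written through a FileWriter. A hedged sketch of that flow, with invented tensors and log directory:

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

images = tf.compat.v1.placeholder(tf.float32, [None, 32, 32, 3])
tf.compat.v1.summary.image('train/image', images, max_outputs=4)
tf.compat.v1.summary.scalar('train/TotalLoss', tf.reduce_mean(images))
merged = tf.compat.v1.summary.merge_all()

with tf.compat.v1.Session() as sess:
  writer = tf.compat.v1.summary.FileWriter('/tmp/logs', sess.graph)
  batch = np.zeros((4, 32, 32, 3), np.float32)
  writer.add_summary(sess.run(merged, {images: batch}), global_step=0)
  writer.close()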
@@ -700,7 +705,8 @@ class Model(object):
       for name, value in names_to_values.items():
         summary_name = 'eval/' + name
-        tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name))
+        tf.compat.v1.summary.scalar(
+            summary_name, tf.compat.v1.Print(value, [value], summary_name))
       return list(names_to_updates.values())
 
   def create_init_fn_to_restore(self,
@@ -733,9 +739,9 @@ class Model(object):
     logging.info('variables_to_restore:\n%s',
                  utils.variables_to_restore().keys())
     logging.info('moving_average_variables:\n%s',
-                 [v.op.name for v in tf.moving_average_variables()])
+                 [v.op.name for v in tf.compat.v1.moving_average_variables()])
     logging.info('trainable_variables:\n%s',
-                 [v.op.name for v in tf.trainable_variables()])
+                 [v.op.name for v in tf.compat.v1.trainable_variables()])
     if master_checkpoint:
       assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint)
...
@@ -42,7 +42,8 @@ flags.DEFINE_integer(
     'image_height', None,
     'Image height used during training(or crop height if used)'
     ' If not set, the dataset default is used instead.')
-flags.DEFINE_string('work_dir', '/tmp', 'A directory to store temporary files.')
+flags.DEFINE_string('work_dir', '/tmp',
+                    'A directory to store temporary files.')
 flags.DEFINE_integer('version_number', 1, 'Version number of the model')
 flags.DEFINE_bool(
     'export_for_serving', True,
@@ -116,7 +117,7 @@ def export_model(export_dir,
   image_height = crop_image_height or dataset_image_height
 
   if export_for_serving:
-    images_orig = tf.placeholder(
+    images_orig = tf.compat.v1.placeholder(
         tf.string, shape=[batch_size], name='tf_example')
     images_orig_float = model_export_lib.generate_tfexample_image(
         images_orig,
@@ -126,22 +127,23 @@ def export_model(export_dir,
         name='float_images')
   else:
     images_shape = (batch_size, image_height, image_width, image_depth)
-    images_orig = tf.placeholder(
+    images_orig = tf.compat.v1.placeholder(
         tf.uint8, shape=images_shape, name='original_image')
     images_orig_float = tf.image.convert_image_dtype(
         images_orig, dtype=tf.float32, name='float_images')
   endpoints = model.create_base(images_orig_float, labels_one_hot=None)
-  sess = tf.Session()
-  saver = tf.train.Saver(slim.get_variables_to_restore(), sharded=True)
+  sess = tf.compat.v1.Session()
+  saver = tf.compat.v1.train.Saver(
+      slim.get_variables_to_restore(), sharded=True)
   saver.restore(sess, get_checkpoint_path())
-  tf.logging.info('Model restored successfully.')
+  tf.compat.v1.logging.info('Model restored successfully.')
 
   # Create model signature.
   if export_for_serving:
     input_tensors = {
-        tf.saved_model.signature_constants.CLASSIFY_INPUTS: images_orig
+        tf.saved_model.CLASSIFY_INPUTS: images_orig
     }
   else:
     input_tensors = {'images': images_orig}
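tf.compat.v1.placeholder and tf.compat.v1.Session only work with eager execution disabled, so a script exported this way still has to run in graph mode from the start. A rough sketch of the non-serving input path, with invented image dimensions in place of the dataset's:

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

batch_size, image_height, image_width, image_depth = 1, 150, 600, 3
images_shape = (batch_size, image_height, image_width, image_depth)
images_orig = tf.compat.v1.placeholder(
    tf.uint8, shape=images_shape, name='original_image')
images_orig_float = tf.image.convert_image_dtype(
    images_orig, dtype=tf.float32, name='float_images')

with tf.compat.v1.Session() as sess:
  out = sess.run(images_orig_float,
                 feed_dict={images_orig: np.zeros(images_shape, np.uint8)})
  print(out.shape, out.dtype)  # (1, 150, 600, 3) float32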
@@ -163,21 +165,21 @@ def export_model(export_dir,
           dataset.max_sequence_length)):
     output_tensors['attention_mask_%d' % i] = t
   signature_outputs = model_export_lib.build_tensor_info(output_tensors)
-  signature_def = tf.saved_model.signature_def_utils.build_signature_def(
+  signature_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
       signature_inputs, signature_outputs,
-      tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
+      tf.saved_model.CLASSIFY_METHOD_NAME)
 
   # Save model.
-  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
   builder.add_meta_graph_and_variables(
-      sess, [tf.saved_model.tag_constants.SERVING],
+      sess, [tf.saved_model.SERVING],
       signature_def_map={
-          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
               signature_def
       },
-      main_op=tf.tables_initializer(),
+      main_op=tf.compat.v1.tables_initializer(),
       strip_default_attrs=True)
   builder.save()
-  tf.logging.info('Model has been exported to %s' % export_dir)
+  tf.compat.v1.logging.info('Model has been exported to %s' % export_dir)
 
   return signature_def
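The signature and tag constants moved to the top of the tf.saved_model namespace in 2.x, while the graph-based builder and signature_def_utils remain v1-only, hence the mix of tf.saved_model.* and tf.compat.v1.saved_model.* spellings above. A quick check of the renamed constants:

import tensorflow as tf

print(tf.saved_model.SERVING)                            # 'serve'
print(tf.saved_model.CLASSIFY_INPUTS)                    # 'inputs'
print(tf.saved_model.CLASSIFY_METHOD_NAME)               # 'tensorflow/serving/classify'
print(tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)  # 'serving_default'

# The export path itself still goes through the v1 builder:
assert hasattr(tf.compat.v1.saved_model.builder, 'SavedModelBuilder')
assert hasattr(tf.compat.v1.saved_model.signature_def_utils, 'build_signature_def')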
...