Unverified Commit 43178d7f authored by Ayushman Kumar's avatar Ayushman Kumar Committed by GitHub
Browse files

Merge pull request #1 from tensorflow/master

Updated
parents 8b47aa3d 75d13042
...@@ -172,7 +172,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase): ...@@ -172,7 +172,7 @@ class BertSquadBenchmarkReal(BertSquadBenchmarkBase):
run_eagerly=False, run_eagerly=False,
ds_type='mirrored'): ds_type='mirrored'):
"""Runs the benchmark and reports various metrics.""" """Runs the benchmark and reports various metrics."""
if FLAGS.train_batch_size <= 4: if FLAGS.train_batch_size <= 4 or run_eagerly:
FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH FLAGS.input_meta_data_path = SQUAD_MEDIUM_INPUT_META_DATA_PATH
else: else:
FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH FLAGS.input_meta_data_path = SQUAD_LONG_INPUT_META_DATA_PATH
......
...@@ -143,12 +143,10 @@ class Config(params_dict.ParamsDict): ...@@ -143,12 +143,10 @@ class Config(params_dict.ParamsDict):
return subconfig_type return subconfig_type
def __post_init__(self, default_params, restrictions, *args, **kwargs): def __post_init__(self, default_params, restrictions, *args, **kwargs):
logging.error('DEBUG before init %r', type(self))
super().__init__(default_params=default_params, super().__init__(default_params=default_params,
restrictions=restrictions, restrictions=restrictions,
*args, *args,
**kwargs) **kwargs)
logging.error('DEBUG after init %r', type(self))
def _set(self, k, v): def _set(self, k, v):
"""Overrides same method in ParamsDict. """Overrides same method in ParamsDict.
...@@ -246,3 +244,71 @@ class Config(params_dict.ParamsDict): ...@@ -246,3 +244,71 @@ class Config(params_dict.ParamsDict):
default_params = {a: p for a, p in zip(attributes, args)} default_params = {a: p for a, p in zip(attributes, args)}
default_params.update(kwargs) default_params.update(kwargs)
return cls(default_params) return cls(default_params)
@dataclasses.dataclass
class RuntimeConfig(Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_eager: Whether or not to enable eager mode.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_threads_enabled: Whether or not GPU threads are enabled.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
"""
distribution_strategy: str = 'mirrored'
enable_eager: bool = False
enable_xla: bool = False
gpu_threads_enabled: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
@dataclasses.dataclass
class TensorboardConfig(Config):
"""Configuration for Tensorboard.
Attributes:
track_lr: Whether or not to track the learning rate in Tensorboard. Defaults
to True.
write_model_weights: Whether or not to write the model weights as
images in Tensorboard. Defaults to False.
"""
track_lr: bool = True
write_model_weights: bool = False
@dataclasses.dataclass
class CallbacksConfig(Config):
"""Configuration for Callbacks.
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
"""
enable_checkpoint_and_export: bool = True
enable_tensorboard: bool = True
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based attention layer.""" """Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -45,7 +45,7 @@ class Attention(tf.keras.layers.Layer): ...@@ -45,7 +45,7 @@ class Attention(tf.keras.layers.Layer):
interpolated by these probabilities, then concatenated back to a single interpolated by these probabilities, then concatenated back to a single
tensor and returned. tensor and returned.
Attributes: Arguments:
num_heads: Number of attention heads. num_heads: Number of attention heads.
head_size: Size of each attention head. head_size: Size of each attention head.
dropout: Dropout probability. dropout: Dropout probability.
...@@ -186,7 +186,7 @@ class Attention(tf.keras.layers.Layer): ...@@ -186,7 +186,7 @@ class Attention(tf.keras.layers.Layer):
class CachedAttention(Attention): class CachedAttention(Attention):
"""Attention layer with cache used for auto-agressive decoding. """Attention layer with cache used for auto-agressive decoding.
Attributes: Arguments:
num_heads: Number of attention heads. num_heads: Number of attention heads.
head_size: Size of each attention head. head_size: Size of each attention head.
**kwargs: Other keyword arguments inherit from `Attention` class. **kwargs: Other keyword arguments inherit from `Attention` class.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based einsum layer.""" """Keras-based einsum layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -30,7 +30,7 @@ class DenseEinsum(tf.keras.layers.Layer): ...@@ -30,7 +30,7 @@ class DenseEinsum(tf.keras.layers.Layer):
This layer can perform einsum calculations of arbitrary dimensionality. This layer can perform einsum calculations of arbitrary dimensionality.
Attributes: Arguments:
output_shape: Positive integer or tuple, dimensionality of the output space. output_shape: Positive integer or tuple, dimensionality of the output space.
num_summed_dimensions: The number of dimensions to sum over. Standard 2D num_summed_dimensions: The number of dimensions to sum over. Standard 2D
matmul should use 1, 3D matmul should use 2, and so forth. matmul should use 1, 3D matmul should use 2, and so forth.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based softmax layer with optional masking.""" """Keras-based softmax layer with optional masking."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -26,7 +26,7 @@ import tensorflow as tf ...@@ -26,7 +26,7 @@ import tensorflow as tf
class MaskedSoftmax(tf.keras.layers.Layer): class MaskedSoftmax(tf.keras.layers.Layer):
"""Performs a softmax with optional masking on a tensor. """Performs a softmax with optional masking on a tensor.
Attributes: Arguments:
mask_expansion_axes: Any axes that should be padded on the mask tensor. mask_expansion_axes: Any axes that should be padded on the mask tensor.
""" """
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based one-hot embedding layer.""" """Keras-based one-hot embedding layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -31,7 +31,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -31,7 +31,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
This layer uses either tf.gather or tf.one_hot to translate integer indices to This layer uses either tf.gather or tf.one_hot to translate integer indices to
float embeddings. float embeddings.
Attributes: Arguments:
vocab_size: Number of elements in the vocabulary. vocab_size: Number of elements in the vocabulary.
embedding_width: Output size of the embedding layer. embedding_width: Output size of the embedding layer.
initializer: The initializer to use for the embedding weights. Defaults to initializer: The initializer to use for the embedding weights. Defaults to
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based positional embedding layer.""" """Keras-based positional embedding layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -37,7 +37,7 @@ class PositionEmbedding(tf.keras.layers.Layer): ...@@ -37,7 +37,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
input size must be fixed. input size must be fixed.
Attributes: Arguments:
use_dynamic_slicing: Whether to use the dynamic slicing path. use_dynamic_slicing: Whether to use the dynamic slicing path.
max_sequence_length: The maximum size of the dynamic sequence. Only max_sequence_length: The maximum size of the dynamic sequence. Only
applicable if `use_dynamic_slicing` is True. applicable if `use_dynamic_slicing` is True.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based transformer block layer.""" """Keras-based transformer block layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -32,7 +32,7 @@ class Transformer(tf.keras.layers.Layer): ...@@ -32,7 +32,7 @@ class Transformer(tf.keras.layers.Layer):
This layer implements the Transformer from "Attention Is All You Need". This layer implements the Transformer from "Attention Is All You Need".
(https://arxiv.org/abs/1706.03762). (https://arxiv.org/abs/1706.03762).
Attributes: Arguments:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer. intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer. intermediate_activation: Activation for the intermediate layer.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Keras-based transformer scaffold layer.""" """Keras-based transformer scaffold layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -35,7 +35,7 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -35,7 +35,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
`attention_cfg`, in which case the scaffold will instantiate the class with `attention_cfg`, in which case the scaffold will instantiate the class with
the config, or pass a class instance to `attention_cls`. the config, or pass a class instance to `attention_cls`.
Attributes: Arguments:
num_attention_heads: Number of attention heads. num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer. intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer. intermediate_activation: Activation for the intermediate layer.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""ALBERT (https://arxiv.org/abs/1810.04805) text encoder network.""" """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -41,7 +41,7 @@ class AlbertTransformerEncoder(network.Network): ...@@ -41,7 +41,7 @@ class AlbertTransformerEncoder(network.Network):
The default values for this object are taken from the ALBERT-Base The default values for this object are taken from the ALBERT-Base
implementation described in the paper. implementation described in the paper.
Attributes: Arguments:
vocab_size: The size of the token vocabulary. vocab_size: The size of the token vocabulary.
embedding_width: The width of the word embeddings. If the embedding width embedding_width: The width of the word embeddings. If the embedding width
is not equal to hidden size, embedding parameters will be factorized into is not equal to hidden size, embedding parameters will be factorized into
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -36,7 +36,7 @@ class BertClassifier(tf.keras.Model): ...@@ -36,7 +36,7 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes` instantiates a classification network based on the passed `num_classes`
argument. argument.
Attributes: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a "get_embedding_table" method.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -37,7 +37,7 @@ class BertPretrainer(tf.keras.Model): ...@@ -37,7 +37,7 @@ class BertPretrainer(tf.keras.Model):
instantiates the masked language model and classification networks that are instantiates the masked language model and classification networks that are
used to create the training objectives. used to create the training objectives.
Attributes: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a "get_embedding_table" method.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -35,7 +35,7 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -35,7 +35,7 @@ class BertSpanLabeler(tf.keras.Model):
The BertSpanLabeler allows a user to pass in a transformer stack, and The BertSpanLabeler allows a user to pass in a transformer stack, and
instantiates a span labeling network based on a single dense layer. instantiates a span labeling network based on a single dense layer.
Attributes: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method. table via a "get_embedding_table" method.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Classification network.""" """Classification network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -31,7 +31,7 @@ class Classification(network.Network): ...@@ -31,7 +31,7 @@ class Classification(network.Network):
This network implements a simple classifier head based on a dense layer. This network implements a simple classifier head based on a dense layer.
Attributes: Arguments:
input_width: The innermost dimension of the input tensor to this network. input_width: The innermost dimension of the input tensor to this network.
num_classes: The number of classes that this network should classify to. num_classes: The number of classes that this network should classify to.
activation: The activation, if any, for the dense layer in this network. activation: The activation, if any, for the dense layer in this network.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Transformer-based text encoder network.""" """Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -46,7 +46,7 @@ class EncoderScaffold(network.Network): ...@@ -46,7 +46,7 @@ class EncoderScaffold(network.Network):
If the hidden_cls is not overridden, a default transformer layer will be If the hidden_cls is not overridden, a default transformer layer will be
instantiated. instantiated.
Attributes: Arguments:
num_output_classes: The output size of the classification layer. num_output_classes: The output size of the classification layer.
classification_layer_initializer: The initializer for the classification classification_layer_initializer: The initializer for the classification
layer. layer.
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Masked language model network.""" """Masked language model network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -32,7 +32,7 @@ class MaskedLM(network.Network): ...@@ -32,7 +32,7 @@ class MaskedLM(network.Network):
This network implements a masked language model based on the provided network. This network implements a masked language model based on the provided network.
It assumes that the network being passed has a "get_embedding_table()" method. It assumes that the network being passed has a "get_embedding_table()" method.
Attributes: Arguments:
input_width: The innermost dimension of the input tensor to this network. input_width: The innermost dimension of the input tensor to this network.
num_predictions: The number of predictions to make per sequence. num_predictions: The number of predictions to make per sequence.
source_network: The network with the embedding layer to use for the source_network: The network with the embedding layer to use for the
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Span labeling network.""" """Span labeling network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -31,7 +31,7 @@ class SpanLabeling(network.Network): ...@@ -31,7 +31,7 @@ class SpanLabeling(network.Network):
This network implements a simple single-span labeler based on a dense layer. This network implements a simple single-span labeler based on a dense layer.
Attributes: Arguments:
input_width: The innermost dimension of the input tensor to this network. input_width: The innermost dimension of the input tensor to this network.
activation: The activation, if any, for the dense layer in this network. activation: The activation, if any, for the dense layer in this network.
initializer: The intializer for the dense layer in this network. Defaults to initializer: The intializer for the dense layer in this network. Defaults to
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Transformer-based text encoder network.""" """Transformer-based text encoder network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
...@@ -40,7 +40,7 @@ class TransformerEncoder(network.Network): ...@@ -40,7 +40,7 @@ class TransformerEncoder(network.Network):
in "BERT: Pre-training of Deep Bidirectional Transformers for Language in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding". Understanding".
Attributes: Arguments:
vocab_size: The size of the token vocabulary. vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers. hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers. num_layers: The number of transformer layers.
......
...@@ -21,7 +21,9 @@ customization of freeze_bn_delay. ...@@ -21,7 +21,9 @@ customization of freeze_bn_delay.
""" """
import re import re
import tensorflow as tf import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import input_to_ops from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.contrib.quantize.python import quant_ops from tensorflow.contrib.quantize.python import quant_ops
...@@ -72,17 +74,18 @@ def build(graph_rewriter_config, ...@@ -72,17 +74,18 @@ def build(graph_rewriter_config,
# Quantize the graph by inserting quantize ops for weights and activations # Quantize the graph by inserting quantize ops for weights and activations
if is_training: if is_training:
tf.contrib.quantize.experimental_create_training_graph( contrib_quantize.experimental_create_training_graph(
input_graph=graph, input_graph=graph,
quant_delay=graph_rewriter_config.quantization.delay, quant_delay=graph_rewriter_config.quantization.delay,
freeze_bn_delay=graph_rewriter_config.quantization.delay) freeze_bn_delay=graph_rewriter_config.quantization.delay)
else: else:
tf.contrib.quantize.experimental_create_eval_graph( contrib_quantize.experimental_create_eval_graph(
input_graph=graph, input_graph=graph,
quant_delay=graph_rewriter_config.quantization.delay quant_delay=graph_rewriter_config.quantization.delay
if not is_export else 0) if not is_export else 0)
tf.contrib.layers.summarize_collection('quant_vars') contrib_layers.summarize_collection('quant_vars')
return graph_rewrite_fn return graph_rewrite_fn
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
"""Tests for graph_rewriter_builder.""" """Tests for graph_rewriter_builder."""
import mock import mock
import tensorflow as tf import tensorflow.compat.v1 as tf
from tensorflow.contrib import layers as contrib_layers
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from lstm_object_detection.builders import graph_rewriter_builder from lstm_object_detection.builders import graph_rewriter_builder
...@@ -27,9 +29,9 @@ class QuantizationBuilderTest(tf.test.TestCase): ...@@ -27,9 +29,9 @@ class QuantizationBuilderTest(tf.test.TestCase):
def testQuantizationBuilderSetsUpCorrectTrainArguments(self): def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
with mock.patch.object( with mock.patch.object(
tf.contrib.quantize, contrib_quantize,
'experimental_create_training_graph') as mock_quant_fn: 'experimental_create_training_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers, with mock.patch.object(contrib_layers,
'summarize_collection') as mock_summarize_col: 'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10 graph_rewriter_proto.quantization.delay = 10
...@@ -44,9 +46,9 @@ class QuantizationBuilderTest(tf.test.TestCase): ...@@ -44,9 +46,9 @@ class QuantizationBuilderTest(tf.test.TestCase):
mock_summarize_col.assert_called_with('quant_vars') mock_summarize_col.assert_called_with('quant_vars')
def testQuantizationBuilderSetsUpCorrectEvalArguments(self): def testQuantizationBuilderSetsUpCorrectEvalArguments(self):
with mock.patch.object(tf.contrib.quantize, with mock.patch.object(contrib_quantize,
'experimental_create_eval_graph') as mock_quant_fn: 'experimental_create_eval_graph') as mock_quant_fn:
with mock.patch.object(tf.contrib.layers, with mock.patch.object(contrib_layers,
'summarize_collection') as mock_summarize_col: 'summarize_collection') as mock_summarize_col:
graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter() graph_rewriter_proto = graph_rewriter_pb2.GraphRewriter()
graph_rewriter_proto.quantization.delay = 10 graph_rewriter_proto.quantization.delay = 10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment