Unverified Commit 11dc461f authored by thunderfyc's avatar thunderfyc Committed by GitHub
Browse files

Rename sequence_projection to seq_flow_lite (#9448)

* Rename sequence_projection to seq_flow_lite

* Rename sequence_projection to seq_flow_lite
parent 63665121
...@@ -9,6 +9,10 @@ build --action_env=PYTHON_BIN_PATH=/usr/bin/python3 ...@@ -9,6 +9,10 @@ build --action_env=PYTHON_BIN_PATH=/usr/bin/python3
build --repo_env=PYTHON_BIN_PATH=/usr/bin/python3 build --repo_env=PYTHON_BIN_PATH=/usr/bin/python3
build --python_path=/usr/bin/python3 build --python_path=/usr/bin/python3
# Enable using platform specific build settings
build --enable_platform_specific_config
# Flag to enable remote config. Required starting from TF 2.2.
common --experimental_repo_remote_exec common --experimental_repo_remote_exec
build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain build:manylinux2010 --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010:toolchain
...@@ -24,5 +28,55 @@ build --linkopt="-lrt -lm" ...@@ -24,5 +28,55 @@ build --linkopt="-lrt -lm"
# of defines when using tf's headers. In particular in refcount.h. # of defines when using tf's headers. In particular in refcount.h.
build --cxxopt="-DNDEBUG" build --cxxopt="-DNDEBUG"
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --spawn_strategy=standalone
build -c opt
# Adding "--cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0" creates parity with TF
# compilation options. It also addresses memory use due to
# copy-on-write semantics of std::strings of the older ABI.
build --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=0
# Make Bazel print out all options from rc files.
build --announce_rc
# Other build flags.
build --define=grpc_no_ares=true
# See https://github.com/bazelbuild/bazel/issues/7362 for information on what
# --incompatible_remove_legacy_whole_archive flag does.
# This flag is set to true in Bazel 1.0 and newer versions. We tried to migrate
# Tensorflow to the default, however test coverage wasn't enough to catch the
# errors.
# There is ongoing work on Bazel team's side to provide support for transitive
# shared libraries. As part of migrating to transitive shared libraries, we
# hope to provide a better mechanism for control over symbol exporting, and
# then tackle this issue again.
#
# TODO: Remove this line once TF doesn't depend on Bazel wrapping all library
# archives in -whole_archive -no_whole_archive.
build --noincompatible_remove_legacy_whole_archive
# These are bazel 2.0's incompatible flags. Tensorflow needs to use bazel 2.0.0
# to use cc_shared_library, as part of the Tensorflow Build Improvements RFC:
# https://github.com/tensorflow/community/pull/179
build --noincompatible_prohibit_aapt1
# Build TF with C++ 17 features.
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
build --enable_platform_specific_config
# Options from ./configure # Options from ./configure
try-import %workspace%/.reverb.bazelrc try-import %workspace%/.tf_configure.bazelrc
# Put user-specific options in .bazelrc.user
try-import %workspace%/.bazelrc.user
...@@ -7,41 +7,43 @@ package( ...@@ -7,41 +7,43 @@ package(
) )
py_library( py_library(
name = "common_layer", name = "metric_functions",
srcs = ["common_layer.py"], srcs = ["metric_functions.py"],
srcs_version = "PY3", srcs_version = "PY3",
) )
py_library( py_library(
name = "prado_model", name = "input_fn_reader",
srcs = ["prado_model.py"], srcs = ["input_fn_reader.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":common_layer", "//layers:projection_layers",
"//tf_ops:sequence_string_projection_op_py",
], ],
) )
py_library( py_binary(
name = "metric_functions", name = "trainer",
srcs = ["metric_functions.py"], srcs = ["trainer.py"],
srcs_version = "PY3", python_version = "PY3",
)
py_library(
name = "input_fn_reader",
srcs = ["input_fn_reader.py"],
srcs_version = "PY3", srcs_version = "PY3",
deps = [
":input_fn_reader",
":metric_functions",
"//models:prado",
],
) )
py_binary( py_binary(
name = "runner", name = "export_to_tflite",
srcs = ["runner.py"], srcs = ["export_to_tflite.py"],
python_version = "PY3", python_version = "PY3",
srcs_version = "PY3", srcs_version = "PY3",
deps = [ deps = [
":input_fn_reader", ":input_fn_reader",
":metric_functions", ":metric_functions",
":prado_model", "//layers:base_layers",
"//layers:projection_layers",
"//models:prado",
"//utils:tflite_utils",
], ],
) )
...@@ -35,7 +35,7 @@ models computes them on the fly. ...@@ -35,7 +35,7 @@ models computes them on the fly.
Train a PRADO model on civil comments dataset Train a PRADO model on civil comments dataset
```shell ```shell
bazel run -c opt prado:runner -- \ bazel run -c opt prado:trainer -- \
--config_path=$(pwd)/prado/civil_comments_prado.txt \ --config_path=$(pwd)/prado/civil_comments_prado.txt \
--runner_mode=train --logtostderr --output_dir=/tmp/prado --runner_mode=train --logtostderr --output_dir=/tmp/prado
``` ```
...@@ -51,7 +51,7 @@ bazel run -c opt sgnn:train -- --logtostderr --output_dir=/tmp/sgnn ...@@ -51,7 +51,7 @@ bazel run -c opt sgnn:train -- --logtostderr --output_dir=/tmp/sgnn
Evaluate PRADO model: Evaluate PRADO model:
```shell ```shell
bazel run -c opt prado:runner -- \ bazel run -c opt prado:trainer -- \
--config_path=$(pwd)/prado/civil_comments_prado.txt \ --config_path=$(pwd)/prado/civil_comments_prado.txt \
--runner_mode=eval --output_dir= --logtostderr --runner_mode=eval --output_dir= --logtostderr
``` ```
......
workspace(name = "tensorflow_models_sequence_projection") workspace(name = "tensorflow_models_seq_flow_lite")
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@//third_party/py:python_configure.bzl", "python_configure") load("@//third_party/py:python_configure.bzl", "python_configure")
......
# Wraps move_ops.sh with the generated sequence_string_projection op files in
# its runfiles so the script can copy them into the source tree.
sh_binary(
    name = "move_ops",
    srcs = ["move_ops.sh"],
    data = [
        "//tf_ops:sequence_string_projection_op_py",
    ],
)
#!/bin/bash
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Copies the bazel-generated custom-op shared library and its Python wrapper
# out of the runfiles tree into the workspace's tf_ops/ directory so they can
# be packaged by setup.py. Expected to run via `bazel run //colab:move_ops`,
# which sets BUILD_WORKSPACE_DIRECTORY and starts us in the runfiles root.
RUNFILES_DIR=$(pwd)
cp -f "${RUNFILES_DIR}/tf_ops/libsequence_string_projection_op_py_gen_op.so" \
  "${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
cp -f "${RUNFILES_DIR}/tf_ops/sequence_string_projection_op.py" \
  "${BUILD_WORKSPACE_DIRECTORY}/tf_ops"
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
import subprocess
from distutils import spawn
from distutils.command import build

import setuptools
from setuptools import find_packages
from setuptools import setup
class _BuildCommand(build.build):
  """Build command that runs the bazel op build before the standard steps."""

  # Prepend bazel_build (always enabled) to distutils' default build
  # sub-commands so the .so exists before package data is collected.
  sub_commands = [
      ('bazel_build', lambda self: True),
  ] + build.build.sub_commands
class _BazelBuildCommand(setuptools.Command):
def initialize_options(self):
pass
def finalize_options(self):
self._bazel_cmd = spawn.find_executable('bazel')
def run(self):
subprocess.check_call(
[self._bazel_cmd, 'run', '-c', 'opt', '//colab:move_ops'],
cwd=os.path.dirname(os.path.realpath(__file__)))
# Package metadata: ships the tf_ops package and any bazel-built .so files as
# package data; the custom 'build'/'bazel_build' commands above make sure the
# shared library is built and staged before packaging.
setup(
    name='seq_flow_lite',
    version='0.1',
    packages=['tf_ops'],
    package_data={'': ['*.so']},
    cmdclass={
        'build': _BuildCommand,
        'bazel_build': _BazelBuildCommand,
    },
    description='Test')
#!/bin/bash
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Stages the pip packaging layout: work relative to this script's directory,
# move setup.py up to the workspace root, and make tf_ops an importable
# Python package.
cd "$(dirname "$0")"
mv setup.py ..
touch ../tf_ops/__init__.py
...@@ -5,20 +5,17 @@ ...@@ -5,20 +5,17 @@
"quantize": true, "quantize": true,
"max_seq_len": 128, "max_seq_len": 128,
"max_seq_len_inference": 128, "max_seq_len_inference": 128,
"exclude_nonalphaspace_unicodes": false,
"split_on_space": true, "split_on_space": true,
"embedding_regularizer_scale": 35e-3, "embedding_regularizer_scale": 35e-3,
"embedding_size": 64, "embedding_size": 64,
"heads": [0, 64, 64, 0, 0], "bigram_channels": 64,
"trigram_channels": 64,
"feature_size": 512, "feature_size": 512,
"network_regularizer_scale": 1e-4, "network_regularizer_scale": 1e-4,
"keep_prob": 0.5, "keep_prob": 0.5,
"word_novelty_bits": 0, "distortion_probability": 0.25
"doc_size_levels": 0,
"add_eos_tag": false,
"pre_logits_fc_layers": [],
"text_distortion_probability": 0.25
}, },
"name": "models.prado",
"batch_size": 1024, "batch_size": 1024,
"save_checkpoints_steps": 100, "save_checkpoints_steps": 100,
"train_steps": 100000, "train_steps": 100000,
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
"quantize": true, "quantize": true,
"max_seq_len": 128, "max_seq_len": 128,
"max_seq_len_inference": 128, "max_seq_len_inference": 128,
"exclude_nonalphaspace_unicodes": false,
"split_on_space": true, "split_on_space": true,
"embedding_regularizer_scale": 35e-3, "embedding_regularizer_scale": 35e-3,
"embedding_size": 64, "embedding_size": 64,
...@@ -13,12 +12,9 @@ ...@@ -13,12 +12,9 @@
"feature_size": 512, "feature_size": 512,
"network_regularizer_scale": 1e-4, "network_regularizer_scale": 1e-4,
"keep_prob": 0.5, "keep_prob": 0.5,
"word_novelty_bits": 0, "distortion_probability": 0.0
"doc_size_levels": 0,
"add_eos_tag": false,
"pre_logits_fc_layers": [],
"text_distortion_probability": 0.0
}, },
"name": "models.prado",
"batch_size": 1024, "batch_size": 1024,
"save_checkpoints_steps": 100, "save_checkpoints_steps": 100,
"train_steps": 100000, "train_steps": 100000,
......
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""A tool to export TFLite model."""
import importlib
import json
import os
from absl import app
from absl import flags
import tensorflow.compat.v1 as tf
from layers import base_layers # import seq_flow_lite module
from layers import projection_layers # import seq_flow_lite module
from utils import tflite_utils # import seq_flow_lite module
FLAGS = flags.FLAGS
# Directory holding runner_config.txt and the checkpoint to convert; the
# exported tflite.fb is written back into the same directory.
flags.DEFINE_string("output_dir", None, "The output or model directory.")
def load_runner_config():
  """Reads runner_config.txt from FLAGS.output_dir and parses it as JSON."""
  path = os.path.join(FLAGS.output_dir, "runner_config.txt")
  with tf.gfile.Open(path, "r") as stream:
    return json.loads(stream.read())
def main(_):
  """Builds the model graph in TFLITE mode, restores the latest checkpoint
  from FLAGS.output_dir and writes the converted flatbuffer to tflite.fb."""
  config = load_runner_config()
  model_config = config["model_config"]
  # Model module path is resolved relative to the (empty) base dir.
  model_module = importlib.import_module(config["name"])
  with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
    input_text = tf.placeholder(tf.string, shape=[1], name="Input")
    projection_layer = projection_layers.ProjectionLayer(
        model_config, base_layers.TFLITE)
    encoder = model_module.Encoder(model_config, base_layers.TFLITE)
    projection, seq_length = projection_layer(input_text)
    logits = encoder(projection, seq_length)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint(FLAGS.output_dir))
    flatbuffer = tflite_utils.generate_tflite(sess, graph, [input_text],
                                              [logits])
    out_path = os.path.join(FLAGS.output_dir, "tflite.fb")
    with tf.gfile.Open(out_path, "wb") as f:
      f.write(flatbuffer)
# absl's app.run parses command-line flags before dispatching to main.
if __name__ == "__main__":
  app.run(main)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Methods related to input datasets and readers."""
import functools
import sys
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from layers import projection_layers # import seq_flow_lite module
from utils import misc_utils # import seq_flow_lite module
def imdb_reviews(features, _):
  """Extracts (text, label) from an imdb_reviews example; config is unused."""
  text = features["text"]
  label = features["label"]
  return text, label
def civil_comments(features, runner_config):
  """Extracts (text, labels) from a civil_comments example.

  Per-label scores are binarized to {0.0, 1.0} by rounding: floor(x + 0.5).
  """
  label_names = runner_config["model_config"]["labels"]
  columns = [features[name] for name in label_names]
  labels = tf.floor(tf.stack(columns, axis=1) + 0.5)
  return features["text"], labels
def goemotions(features, runner_config):
  """Extracts (comment_text, float32 multi-label tensor) from goemotions."""
  label_names = runner_config["model_config"]["labels"]
  stacked = tf.stack([features[name] for name in label_names], axis=1)
  return features["comment_text"], tf.cast(stacked, tf.float32)
def create_input_fn(runner_config, mode, drop_remainder):
  """Returns an input function to use in the instantiation of tf.estimator.*."""

  def _post_processor(features, batch_size):
    """Post process the data to a form expected by model_fn."""
    # The dataset name doubles as the name of its processor function defined
    # in this module (imdb_reviews / civil_comments / goemotions).
    data_processor = getattr(sys.modules[__name__], runner_config["dataset"])
    text, label = data_processor(features, runner_config)
    model_config = runner_config["model_config"]
    if "max_seq_len" in model_config:
      max_seq_len = model_config["max_seq_len"]
      logging.info("Truncating text to have at most %d tokens", max_seq_len)
      text = misc_utils.random_substr(text, max_seq_len)
    text = tf.reshape(text, [batch_size])
    num_classes = len(model_config["labels"])
    label = tf.reshape(label, [batch_size, num_classes])
    # Project the raw text into the dense representation the model consumes.
    prxlayer = projection_layers.ProjectionLayer(model_config, mode)
    projection, seq_length = prxlayer(text)
    return {"projection": projection, "seq_length": seq_length, "label": label}

  def _input_fn(params):
    """Method to be used for reading the data."""
    assert mode != tf.estimator.ModeKeys.PREDICT
    split = "train" if mode == tf.estimator.ModeKeys.TRAIN else "test"
    ds = tfds.load(runner_config["dataset"], split=split)
    ds = ds.batch(params["batch_size"], drop_remainder=drop_remainder)
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    # NOTE(review): shuffle follows batch(), so whole batches are shuffled
    # rather than individual examples — confirm this is intended.
    ds = ds.shuffle(buffer_size=100)
    # Single pass for eval; repeat indefinitely otherwise (training).
    ds = ds.repeat(count=1 if mode == tf.estimator.ModeKeys.EVAL else None)
    ds = ds.map(
        functools.partial(_post_processor, batch_size=params["batch_size"]),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)
    return ds

  return _input_fn
# BUILD rules for the layers package. py_strict_library is aliased to the
# plain py_library for the open-source build.
py_strict_library = py_library

licenses(["notice"])

package(
    default_visibility = ["//:friends"],  # sequence projection
)

py_strict_library(
    name = "base_layers",
    srcs = ["base_layers.py"],
    srcs_version = "PY3",
    deps = [
        # package tensorflow
    ],
)

py_strict_library(
    name = "quantization_layers",
    srcs = ["quantization_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layers",
        # package tensorflow
    ],
)

py_strict_library(
    name = "normalization_layers",
    srcs = ["normalization_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layers",
        ":quantization_layers",
        # package tensorflow
        # "//tf_ops:tf_custom_ops" # sequence projection
        "//tf_ops:tf_custom_ops_py",  # sequence projection
    ],
)

py_strict_library(
    name = "dense_layers",
    srcs = ["dense_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layers",
        ":normalization_layers",
        ":quantization_layers",
        # package tensorflow
    ],
)

py_strict_library(
    name = "conv_layers",
    srcs = ["conv_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layers",
        ":normalization_layers",
        ":quantization_layers",
        # package tensorflow
    ],
)

py_strict_library(
    name = "projection_layers",
    srcs = ["projection_layers.py"],
    srcs_version = "PY3",
    deps = [
        ":base_layers",
        # package absl/logging
        # package tensorflow
        # "//tf_ops:sequence_string_projection_op" # sequence projection
        "//tf_ops:sequence_string_projection_op_py",  # sequence projection
        "//tf_ops:sequence_string_projection_op_v2",  # sequence projection
        "//tf_ops:sequence_string_projection_op_v2_py",  # sequence projection
    ],
)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for building models trained with quantization."""
import tensorflow as tf
# Execution modes for model construction. Note PREDICT maps to the string
# "infer" (not "predict"); TFLITE builds a graph intended for TFLite export.
TRAIN = "train"
EVAL = "eval"
PREDICT = "infer"
TFLITE = "tflite"
# All valid values for the `mode` argument of Parameters.
_MODE = [TRAIN, EVAL, PREDICT, TFLITE]
class Parameters:
  """A class that encapsulates parameters."""

  def __init__(self,
               mode,
               quantize=True,
               regularizer_scale=0.0,
               invalid_logit=-1e6,
               initializer=None):
    """Creates model-wide parameters.

    Args:
      mode: one of the mode constants of this module
        (TRAIN / EVAL / PREDICT / TFLITE).
      quantize: bool, whether fake-quantization ops are inserted.
      regularizer_scale: L2 regularization scale; 0.0 disables reg losses.
      invalid_logit: logit value used for invalid/padded positions.
      initializer: optional weight initializer; GlorotUniform when None.
    """
    assert isinstance(quantize, bool)
    self.quantize = quantize
    assert mode in _MODE
    self.mode = mode
    self.regularizer_scale = regularizer_scale
    self.invalid_logit = invalid_logit
    self.initializer = initializer
class BaseLayer(tf.keras.layers.Layer):
  """Base class for encoders."""

  def __init__(self, parameters, **kwargs):
    """Creates the layer; `parameters` must be a Parameters instance."""
    assert isinstance(parameters, Parameters)
    self.parameters = parameters
    super(BaseLayer, self).__init__(**kwargs)

  def _assert_rank_and_type(self, tensor, rank, dtype=tf.float32):
    """Asserts that `tensor` has the given static rank and dtype."""
    assert len(tensor.get_shape().as_list()) == rank
    assert tensor.dtype == dtype

  def add_qweight(self, shape, num_bits=8):
    """Return a quantized weight variable for the given shape."""
    if self.parameters.initializer is not None:
      initializer = self.parameters.initializer
    else:
      initializer = tf.keras.initializers.GlorotUniform()
    weight = self.add_weight(
        "weight", shape, initializer=initializer, trainable=True)
    self.add_reg_loss(weight)
    return self._weight_quantization(weight, num_bits=num_bits)

  def _weight_quantization(self, tensor, num_bits=8):
    """Quantize weights when enabled."""
    # For infer mode, toco computes the min/max from the weights offline to
    # quantize it. During train/eval this is computed from the current value
    # in the session by the graph itself.
    if self.parameters.quantize and self.parameters.mode in [TRAIN, EVAL]:
      # Toco expects 0.0 to be part of the quantization range.
      batch_min = tf.minimum(tf.reduce_min(tensor), 0.0)
      batch_max = tf.maximum(tf.reduce_max(tensor), 0.0)
      return tf.quantization.fake_quant_with_min_max_vars(
          tensor, batch_min, batch_max, num_bits=num_bits)
    else:
      return tensor

  def add_bias(self, shape):
    """Adds a zero-initialized, regularized, trainable bias variable."""
    weight = self.add_weight(
        "bias",
        shape,
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.add_reg_loss(weight)
    return weight

  def add_reg_loss(self, weight):
    """Adds an L2 regularization loss for `weight` when the scale is > 0."""
    if self.parameters.regularizer_scale > 0.0:
      reg_scale = tf.convert_to_tensor(self.parameters.regularizer_scale)
      reg_loss = tf.nn.l2_loss(weight) * reg_scale
      self.add_loss(reg_loss)

  def assign_moving_average(self, var, update, ema_decay):
    # NOTE(review): the new `update` is weighted by ema_decay and the old
    # value by (1 - ema_decay). With ema_decay near 1 this tracks the update
    # almost immediately — the reverse of the conventional EMA weighting;
    # confirm against callers before changing.
    return var.assign(var.read_value() * (1 - ema_decay) + (ema_decay) * update)

  def qrange_sigmoid(self, tensor):
    """Fake-quantizes `tensor` to the sigmoid range [0, 1] when quantizing."""
    if self.parameters.quantize:
      return tf.quantization.fake_quant_with_min_max_args(tensor, 0.0, 1.0)
    return tensor

  def qrange_tanh(self, tensor):
    """Fake-quantizes `tensor` to the tanh range [-1, 1] when quantizing."""
    if self.parameters.quantize:
      return tf.quantization.fake_quant_with_min_max_args(tensor, -1.0, 1.0)
    return tensor

  def quantized_tanh(self, tensor):
    """tanh followed by range quantization."""
    return self.qrange_tanh(tf.tanh(tensor))

  def quantized_sigmoid(self, tensor):
    """sigmoid followed by range quantization."""
    return self.qrange_sigmoid(tf.sigmoid(tensor))

  def get_batch_dimension(self, tensor):
    # Static batch size when known at graph-build time, otherwise the
    # dynamic tf.shape value.
    return tensor.get_shape().as_list()[0] or tf.shape(tensor)[0]
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Base layer for convolution."""
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
class EncoderQConvolution(base_layers.BaseLayer):
  """Quantized encoder convolution layers."""

  def __init__(self,
               filters,
               ksize,
               stride=1,
               padding="SAME",
               dilations=None,
               activation=tf.keras.layers.ReLU(),
               bias=True,
               rank=4,
               **kwargs):
    """Creates the layer.

    Args:
      filters: int, number of output channels.
      ksize: int, or [height, width] list (list form requires rank == 4).
      stride: int, or [height, width] list of strides.
      padding: conv2d padding scheme, e.g. "SAME" or "VALID".
      dilations: optional int or [height, width] list of dilation rates.
      activation: callable applied after normalization; falsy to skip.
      bias: whether a bias term is added before normalization.
      rank: rank of the input tensor, 3 or 4. Rank-3 [batch, time, channels]
        inputs are convolved as height-1 rank-4 images.
      **kwargs: forwarded to the quantization/normalization helpers and Layer.
    """
    self.out_filters = filters
    assert rank >= 3 and rank <= 4
    self.rank = rank
    self.ksize = self._unpack(ksize)
    self.strides = self._unpack(stride)
    # conv2d expects [1, h, w, 1] dilations; None disables dilation.
    self.dilations = [1] + self._unpack(dilations) + [1] if dilations else None
    self.activation = activation
    self.bias = bias
    self.padding = padding
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
    self._create_normalizer(**kwargs)
    super(EncoderQConvolution, self).__init__(**kwargs)

  def _unpack(self, value):
    """Normalizes an int-or-pair argument into a [height, width] list."""
    if not isinstance(value, list):
      assert isinstance(value, int)
      # Rank-3 (sequence) inputs use a height-1 kernel/stride.
      return [1 if self.rank == 3 else value, value]
    else:
      assert len(value) == 2 and self.rank == 4
      assert isinstance(value[0], int) and isinstance(value[1], int)
      return value

  def build(self, input_shapes):
    """Creates the quantized filter weights and optional bias."""
    assert len(input_shapes) == self.rank
    self.in_filters = input_shapes[-1]
    shape = self.ksize + [self.in_filters, self.out_filters]
    self.filters = self.add_qweight(shape=shape)
    if self.bias:
      self.b = self.add_bias(shape=[self.out_filters])

  def _create_normalizer(self, **kwargs):
    # Overridden in EncoderQConvolutionVarLen with masked normalization.
    self.normalization = normalization_layers.BatchNormalization(**kwargs)

  def _conv_r4(self, inputs, normalize_method):
    # conv2d -> (bias) -> normalize -> (activation) -> output quantization.
    outputs = tf.nn.conv2d(
        inputs,
        self.filters,
        strides=self.strides,
        padding=self.padding,
        dilations=self.dilations)
    if self.bias:
      outputs = tf.nn.bias_add(outputs, self.b)
    outputs = normalize_method(outputs)
    if self.activation:
      outputs = self.activation(outputs)
    return self.qoutput(outputs)

  def _conv_r3(self, inputs, normalize_method):
    # Treat [batch, time, channels] as a height-1 image for conv2d, then
    # squeeze the height dimension back out.
    bsz = self.get_batch_dimension(inputs)
    inputs_r4 = tf.reshape(inputs, [bsz, 1, -1, self.in_filters])
    outputs = self._conv_r4(inputs_r4, normalize_method)
    return tf.reshape(outputs, [bsz, -1, self.out_filters])

  def call(self, inputs):
    def normalize_method(tensor):
      return self.normalization(tensor)
    return self._do_call(inputs, normalize_method)

  def _do_call(self, inputs, normalize_method):
    # Dispatch on the declared input rank.
    if self.rank == 3:
      return self._conv_r3(inputs, normalize_method)
    return self._conv_r4(inputs, normalize_method)

  def quantize_using_output_range(self, tensor):
    """Quantizes `tensor` using this layer's learned output range."""
    return self.qoutput.quantize_using_range(tensor)
class EncoderQConvolutionVarLen(EncoderQConvolution):
  """Convolution on variable length sequence."""

  def _create_normalizer(self, **kwargs):
    # Batch statistics are computed only over valid (unmasked) positions.
    self.normalization = normalization_layers.VarLenBatchNormalization(
        rank=4, **kwargs)

  def call(self, inputs, mask, inverse_normalizer):
    # `mask` marks valid positions; `inverse_normalizer` is presumably
    # 1 / sum(mask) (see VarLenBatchNormalization._reduce) — confirm with
    # callers.
    def normalize_method(tensor):
      return self.normalization(tensor, mask, inverse_normalizer)
    return self._do_call(inputs, normalize_method)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Basic dense layers."""
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from layers import normalization_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
class BaseQDense(base_layers.BaseLayer):
  """Quantized encoder dense layers."""

  def __init__(self,
               units,
               activation=tf.keras.layers.ReLU(),
               bias=True,
               rank=2,
               **kwargs):
    """Creates the layer.

    Args:
      units: int, number of output units.
      activation: callable applied after normalization; falsy to skip.
      bias: whether a bias term is added before normalization.
      rank: rank of the input tensor, 2 to 4. Rank-4 inputs must have one
        singleton spatial dimension (checked in build).
      **kwargs: forwarded to the quantization/normalization helpers and Layer.
    """
    self.units = units
    self.rank = rank
    assert rank >= 2 and rank <= 4
    self.activation = activation
    self.bias = bias
    self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
    self._create_normalizer(**kwargs)
    super(BaseQDense, self).__init__(**kwargs)

  def build(self, input_shapes):
    """Creates the quantized weight matrix and optional bias."""
    assert len(input_shapes) == self.rank
    if self.rank == 4:
      assert input_shapes[1] == 1 or input_shapes[2] == 1
    self.in_units = input_shapes[-1]
    shape = [self.in_units, self.units]
    self.w = self.add_qweight(shape=shape)
    if self.bias:
      self.b = self.add_bias(shape=[self.units])

  def _create_normalizer(self, **kwargs):
    # Overridden in BaseQDenseVarLen with masked normalization.
    self.normalization = normalization_layers.BatchNormalization(**kwargs)

  def _dense_r2(self, inputs, normalize_method):
    # matmul -> (bias) -> normalize -> (activation) -> output quantization.
    outputs = tf.matmul(inputs, self.w)
    if self.bias:
      outputs = tf.nn.bias_add(outputs, self.b)
    outputs = normalize_method(outputs)
    if self.activation:
      outputs = self.activation(outputs)
    return self.qoutput(outputs)

  def _dense_r34(self, inputs, normalize_method):
    # Flatten leading dimensions to rank 2, apply the dense op, then restore
    # the original layout (keeping the singleton axis for rank-4 inputs).
    bsz = self.get_batch_dimension(inputs)
    outputs = tf.reshape(inputs, [-1, self.in_units])
    outputs = self._dense_r2(outputs, normalize_method)
    if self.rank == 3:
      return tf.reshape(outputs, [bsz, -1, self.units])
    elif inputs.get_shape().as_list()[1] == 1:
      return tf.reshape(outputs, [bsz, 1, -1, self.units])
    else:
      return tf.reshape(outputs, [bsz, -1, 1, self.units])

  def call(self, inputs):
    def normalize_method(tensor):
      return self.normalization(tensor)
    return self._do_call(inputs, normalize_method)

  def _do_call(self, inputs, normalize_method):
    # Dispatch on the declared input rank.
    if self.rank == 2:
      return self._dense_r2(inputs, normalize_method)
    return self._dense_r34(inputs, normalize_method)

  def quantize_using_output_range(self, tensor):
    """Quantizes `tensor` using this layer's learned output range."""
    return self.qoutput.quantize_using_range(tensor)
class BaseQDenseVarLen(BaseQDense):
  """Dense on variable length sequence."""

  def _create_normalizer(self, **kwargs):
    # Batch statistics are computed only over valid (unmasked) timesteps.
    self.normalization = normalization_layers.VarLenBatchNormalization(
        rank=2, **kwargs)

  def call(self, inputs, mask, inverse_normalizer):
    def normalize_method(tensor):
      # Normalization runs on the flattened rank-2 [batch * time, units]
      # tensor, so the mask is reshaped to match.
      maskr2 = tf.reshape(mask, [-1, 1])
      return self.normalization(tensor, maskr2, inverse_normalizer)
    return self._do_call(inputs, normalize_method)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for normalization."""
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from layers import quantization_layers # import seq_flow_lite module
from tf_ops import tf_custom_ops_py # import seq_flow_lite module
class BatchNormalization(base_layers.BaseLayer):
  """A class that applies batch normalization to the input tensor."""

  def __init__(self, ema_decay=0.999, **kwargs):
    # ema_decay controls the moving-average update in assign_moving_average
    # (see BaseLayer for the exact weighting).
    self.ema_decay = ema_decay
    super(BatchNormalization, self).__init__(**kwargs)

  def build(self, input_shapes):
    """Creates offset/scale and the moving mean/variance variables."""
    # Normalize over all axes except the last (channel) axis.
    self.reduce_dims = list(range(len(input_shapes) - 1))
    shape = [input_shapes[-1]]
    self.offset = self.add_weight(
        "offset",
        shape=shape,
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.scale = self.add_weight(
        "scale",
        shape=shape,
        initializer=tf.keras.initializers.Ones(),
        trainable=True)
    # Moving averages used in place of batch stats at eval/inference time.
    self.mva_mean = self.add_weight(
        "mva_mean",
        shape=shape,
        initializer=tf.keras.initializers.Zeros(),
        trainable=False)
    self.mva_var = self.add_weight(
        "mva_variance",
        shape=shape,
        initializer=tf.keras.initializers.Ones(),
        trainable=False)

  def call(self, inputs):
    # Batch moments are only needed (and only computed) during training.
    mean_mom, var_mom = None, None
    if self.parameters.mode == base_layers.TRAIN:
      mean_mom, var_mom = tf.nn.moments(inputs, self.reduce_dims)
    return self._batch_norm(inputs, mean_mom, var_mom)

  def _batch_norm(self, inputs, mean_mom, var_mom):
    if self.parameters.mode == base_layers.TRAIN:
      # During training compute summary stats, update them to moving average
      # variables and use the summary stats for batch normalization.
      with tf.control_dependencies([
          self.assign_moving_average(self.mva_mean, mean_mom, self.ema_decay),
          self.assign_moving_average(self.mva_var, var_mom, self.ema_decay)
      ]):
        tensor = tf.nn.batch_normalization(inputs, mean_mom, var_mom,
                                           self.offset, self.scale, 1e-9)
    else:
      # During eval/inference use the moving average variable for batch
      # normalization. The variables would be frozen to constants before
      # saving graph.
      tensor = tf.nn.batch_normalization(inputs, self.mva_mean, self.mva_var,
                                         self.offset, self.scale, 1e-9)
    return tensor
class VarLenBatchNormalization(BatchNormalization):
  """A class that applies batch normalization to the input tensor."""

  def __init__(self, rank=2, **kwargs):
    # Only rank-2 (flattened sequence) and rank-4 (image-like) inputs are
    # supported; see BaseQDenseVarLen / EncoderQConvolutionVarLen.
    self.rank = rank
    assert rank == 2 or rank == 4
    super(VarLenBatchNormalization, self).__init__(**kwargs)

  def _reduce(self, tensor, multiplier):
    # Masked mean: sum over all non-channel axes scaled by the caller's
    # inverse_normalizer (presumably 1/number of valid positions — confirm).
    return tf.reduce_sum(tensor, axis=self.reduce_dims) * multiplier

  def call(self, inputs, mask, inverse_normalizer):
    if self.parameters.mode == base_layers.TRAIN:
      self._assert_rank_and_type(inputs, self.rank)
      self._assert_rank_and_type(mask, self.rank)
      # Zero out padded positions so they don't contribute to the moments.
      inputs = mask * inputs
      mean_mom = self._reduce(inputs, inverse_normalizer)
      var_mom = self._reduce(inputs * inputs, inverse_normalizer)
      return mask * self._batch_norm(inputs, mean_mom, var_mom)
    elif self.parameters.mode == base_layers.EVAL:
      return mask * self._batch_norm(inputs, None, None)
    # PREDICT/TFLITE: no masking applied to the normalized output.
    return self._batch_norm(inputs, None, None)
class LayerNormalization(base_layers.BaseLayer):
  """A class that applies layer normalization to the input tensor."""

  def __init__(self, axes=None, **kwargs):
    """Initializes the layer.

    Args:
      axes: Axes to normalize over; negative values are resolved against the
        input rank in `build`. Defaults to [-1].
      **kwargs: Forwarded to the activation quantizer and the base layer.
    """
    # Copy the caller's list: `build` rewrites negative axes in place, and
    # aliasing the argument would silently mutate the caller's list.
    self.axes = list(axes) if axes else [-1]
    self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
    super(LayerNormalization, self).__init__(**kwargs)

  def build(self, input_shape):
    """Resolves axes against the input rank and creates parameters.

    Raises:
      ValueError: If any resolved axis is the batch axis (0) or out of range.
    """
    self.rank = len(input_shape)
    for i, axis in enumerate(self.axes):
      if axis < 0:
        self.axes[i] += self.rank
      # Explicit check instead of `assert`, which is stripped under -O.
      # Axis 0 (batch) may not be normalized over.
      if not 0 < self.axes[i] < self.rank:
        raise ValueError(
            "Invalid axis %d for input of rank %d" % (axis, self.rank))
    self.offset = self.add_weight(
        "offset",
        shape=[1],
        initializer=tf.keras.initializers.Zeros(),
        trainable=True)
    self.scale = self.add_weight(
        "scale",
        shape=[1],
        initializer=tf.keras.initializers.Ones(),
        trainable=True)

  def call(self, tensor):
    """Quantizes then layer-normalizes `tensor`."""
    tensor = self.qactivation(tensor)
    if self.parameters.mode != base_layers.TFLITE:
      mean, variance = tf.nn.moments(tensor, self.axes, keepdims=True)
      # If all the values in the tensor are same, variance will be 0. Adding a
      # small epsilon to variance ensures that we get 0 as the normalized result
      # instead of NaN in the resulting tensor.
      tensor = (tensor - mean) / tf.sqrt(variance + 1e-6)
      return tensor * self.scale + self.offset
    else:
      # TFLite uses the custom fused op instead of composed TF ops.
      return tf_custom_ops_py.layer_norm_v2(
          tensor, self.scale, self.offset, axes=self.axes)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow projection creator for PRADO model."""
from absl import logging
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from tf_ops import sequence_string_projection_op as ssp # import seq_flow_lite module
from tf_ops import sequence_string_projection_op_v2 as sspv2 # import seq_flow_lite module
class ProjectionLayer(base_layers.BaseLayer):
  """Base class for encoders."""

  def __init__(self, model_config, mode):
    """Create projection."""

    def _get_params(varname, default_value=None):
      # Read from the config when present, otherwise fall back to the
      # default, and log which one was used.
      if varname in model_config:
        value, suffix = model_config[varname], ""
      else:
        value, suffix = default_value, " (default)"
      logging.info("%s = %s%s", varname, value, suffix)
      setattr(self, varname, value)

    self.mode = mode
    _get_params("feature_size")
    _get_params("max_seq_len", 0)
    _get_params("add_eos_tag", False)
    _get_params("add_bos_tag", False)
    _get_params("split_on_space", True)
    _get_params("token_separators", "")
    _get_params("vocabulary", "")
    _get_params("quantize")
    _get_params("word_novelty_bits", 0)
    _get_params("doc_size_levels", 0)
    # Distortion is a training-only regularizer; zero elsewhere.
    self.distortion_probability = 0.0
    if mode == base_layers.TRAIN:
      _get_params("distortion_probability", 0.0)
    parameters = base_layers.Parameters(mode, self.quantize)
    super(ProjectionLayer, self).__init__(parameters=parameters)

  def call(self, inputs):
    """Computes the projection for a batch of input strings."""
    projection, _, seq_length = ssp.sequence_string_projection(
        input=inputs,
        feature_size=self.feature_size,
        max_splits=self.max_seq_len - 1,
        distortion_probability=self.distortion_probability,
        split_on_space=self.split_on_space,
        token_separators=self.token_separators,
        word_novelty_bits=self.word_novelty_bits,
        doc_size_levels=self.doc_size_levels,
        add_eos_tag=self.add_eos_tag,
        add_bos_tag=self.add_bos_tag,
        vocabulary=self.vocabulary)
    inference_modes = [base_layers.PREDICT, base_layers.TFLITE]
    if self.mode in inference_modes:
      return self.qrange_tanh(projection), seq_length
    if self.max_seq_len > 0:
      # Pad the time dimension out to max_seq_len and pin the static shape.
      pad_len = self.max_seq_len - tf.shape(projection)[1]
      projection = tf.pad(projection, [[0, 0], [0, pad_len], [0, 0]])
      batch_size = inputs.get_shape().as_list()[0]
      projection = tf.reshape(
          projection, [batch_size, self.max_seq_len, self.feature_size])
    return projection, seq_length
class ProjectionLayerPreSegmented(base_layers.BaseLayer):
  """Base class for encoders."""

  def __init__(self, model_config, mode):
    """Create projection."""

    def _get_params(varname, default_value=None):
      # Read from the config when present, otherwise fall back to the
      # default, and log which one was used.
      if varname in model_config:
        value, suffix = model_config[varname], ""
      else:
        value, suffix = default_value, " (default)"
      logging.info("%s = %s%s", varname, value, suffix)
      setattr(self, varname, value)

    self.mode = mode
    _get_params("feature_size")
    _get_params("add_eos_tag", False)
    _get_params("add_bos_tag", False)
    _get_params("vocabulary", "")
    _get_params("quantize")
    # Distortion is a training-only regularizer; zero elsewhere.
    self.distortion_probability = 0.0
    if mode == base_layers.TRAIN:
      _get_params("distortion_probability", 0.0)
    parameters = base_layers.Parameters(mode, self.quantize)
    super(ProjectionLayerPreSegmented, self).__init__(parameters=parameters)

  def call(self, inputs, sequence_length):
    """Computes the projection for pre-tokenized inputs."""
    projection = sspv2.sequence_string_projection_v2(
        input=inputs,
        sequence_length=sequence_length,
        feature_size=self.feature_size,
        distortion_probability=self.distortion_probability,
        add_eos_tag=self.add_eos_tag,
        add_bos_tag=self.add_bos_tag,
        vocabulary=self.vocabulary)
    inference_modes = (base_layers.PREDICT, base_layers.TFLITE)
    if self.mode in inference_modes:
      projection = self.qrange_tanh(projection)
    return projection
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Lint as: python3
"""Layers for quantization."""
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
class ActivationQuantization(base_layers.BaseLayer):
  """A class that applies quantization to an activation tensor."""

  def __init__(self, ema_decay=0.99, num_bits=8, **kwargs):
    """Initializes the layer.

    Args:
      ema_decay: Decay factor for the moving averages of the range.
      num_bits: Number of bits used for fake quantization.
      **kwargs: Forwarded to the base layer constructor.
    """
    self.ema_decay = ema_decay
    self.num_bits = num_bits
    super(ActivationQuantization, self).__init__(**kwargs)
    if self.parameters.quantize:
      # Range variables tracked as moving averages during training.
      self.min_var = self.add_weight(
          "min", initializer=tf.keras.initializers.Zeros(), trainable=False)
      self.max_var = self.add_weight(
          "max", initializer=tf.keras.initializers.Ones(), trainable=False)

  def call(self, inputs):
    """Fake-quantizes `inputs`; a no-op when quantization is disabled."""
    if not self.parameters.quantize:
      return inputs
    if self.parameters.mode != base_layers.TRAIN:
      return tf.quantization.fake_quant_with_min_max_vars(
          inputs, self.min_var, self.max_var, num_bits=self.num_bits)
    # Toco expects 0.0 to be part of the quantization range.
    batch_min = tf.minimum(tf.reduce_min(inputs), 0.0)
    batch_max = tf.maximum(tf.reduce_max(inputs), 0.0)
    updates = [
        self.assign_moving_average(self.min_var, batch_min, self.ema_decay),
        self.assign_moving_average(self.max_var, batch_max, self.ema_decay),
    ]
    with tf.control_dependencies(updates):
      return tf.quantization.fake_quant_with_min_max_vars(
          inputs, batch_min, batch_max, num_bits=self.num_bits)

  def quantize_using_range(self, inputs):
    """Fake-quantizes `inputs` with the stored range variables."""
    if not self.parameters.quantize:
      return inputs
    return tf.quantization.fake_quant_with_min_max_vars(
        inputs, self.min_var, self.max_var, num_bits=self.num_bits)
class ConcatQuantization(ActivationQuantization):
  """Quantizes a list of tensors and their concatenation with one range."""

  def __init__(self, axis=2, **kwargs):
    """Initializes the layer.

    Args:
      axis: Axis along which input tensors are concatenated.
      **kwargs: Forwarded to ActivationQuantization.
    """
    self.axis = axis
    super(ConcatQuantization, self).__init__(**kwargs)

  def reduce_list(self, tensor_list, functor):
    """Applies `functor` across every tensor in `tensor_list` and 0.0."""
    # Toco expects 0.0 to be part of the quantization range.
    per_tensor = [functor(tensor) for tensor in tensor_list]
    per_tensor.append(tf.constant(0.0))
    return functor(tf.stack(per_tensor))

  def call(self, tensors):
    """Concatenates `tensors`, fake-quantizing with a shared range."""
    if not self.parameters.quantize:
      return tf.concat(tensors, axis=self.axis)
    if self.parameters.mode == base_layers.TRAIN:
      # Toco expects 0.0 to be part of the quantization range.
      min_var = self.assign_moving_average(
          self.min_var, self.reduce_list(tensors, tf.reduce_min),
          self.ema_decay)
      max_var = self.assign_moving_average(
          self.max_var, self.reduce_list(tensors, tf.reduce_max),
          self.ema_decay)
    else:
      min_var, max_var = self.min_var, self.max_var
    quantized = [
        tf.quantization.fake_quant_with_min_max_vars(
            tensor, min_var, max_var, num_bits=self.num_bits)
        for tensor in tensors
    ]
    merged = tf.concat(quantized, axis=self.axis)
    return tf.quantization.fake_quant_with_min_max_vars(
        merged, min_var, max_var, num_bits=self.num_bits)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment