Unverified commit 18dbf036, authored by Sehoon Kim, committed by GitHub

Squeezeformer Initial Commit

Co-authored-by: Albert Shaw <ashaw596@gmail.com>
Co-authored-by: Nicholas Lee <caldragon18456@berkeley.edu>
Co-authored-by: ani <aninrusimha@berkeley.edu>
Co-authored-by: dragon18456 <nicholas_lee@berkeley.edu>
parent 5d6f1ae4
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from src.utils import shape_util, math_util
logger = tf.get_logger()
class Conv2dSubsampling(tf.keras.layers.Layer):
    def __init__(
        self,
        filters: int,
        strides: int = 2,
        kernel_size: int = 3,
        ds: bool = False,
        name="Conv2dSubsampling",
        **kwargs,
    ):
        super(Conv2dSubsampling, self).__init__(name=name, **kwargs)
        self.strides = strides
        self.kernel_size = kernel_size
        assert self.strides == 2 and self.kernel_size == 3  # fixed to these values for simplicity
        # Uniform-init bounds scaled by fan-in: conv1 has 1 input channel,
        # so (k^2 * 1)^-0.5 == kernel_size^-1 for k == 3
        conv1_max = kernel_size ** -1
        conv2_max = (kernel_size ** 2 * filters) ** -0.5
        self.conv1 = tf.keras.layers.Conv2D(
            filters=filters, kernel_size=kernel_size,
            strides=strides, padding="valid", name=f"{name}_1",
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-conv1_max, maxval=conv1_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-conv1_max, maxval=conv1_max),
        )
        self.ds = ds
        if not ds:
            logger.info("Subsampling with full conv")
            self.conv2 = tf.keras.layers.Conv2D(
                filters=filters, kernel_size=kernel_size,
                strides=strides, padding="valid", name=f"{name}_2",
                kernel_initializer=tf.keras.initializers.RandomUniform(minval=-conv2_max, maxval=conv2_max),
                bias_initializer=tf.keras.initializers.RandomUniform(minval=-conv2_max, maxval=conv2_max),
            )
            # Overall time reduction is the product of the two stride-2 convs (4x)
            self.time_reduction_factor = self.conv1.strides[0] * self.conv2.strides[0]
        else:
            logger.info("Subsampling with DS conv")
            dw_max = (kernel_size ** 2) ** -0.5
            pw_max = filters ** -0.5
            self.dw_conv = tf.keras.layers.DepthwiseConv2D(
                kernel_size=(kernel_size, kernel_size), strides=strides,
                padding="valid", name=f"{name}_2_dw",
                depth_multiplier=1,
                depthwise_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
                bias_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
            )
            self.pw_conv = tf.keras.layers.Conv2D(
                filters=filters, kernel_size=1, strides=1,
                padding="valid", name=f"{name}_2_pw",
                kernel_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
                bias_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
            )
            self.time_reduction_factor = self.conv1.strides[0] * self.dw_conv.strides[0]

    def call(self, inputs, training=False, **kwargs):
        _, L, H, _ = shape_util.shape_list(inputs)
        assert H == 80  # expects 80-dim log-mel features
        # Right-pad time and freq so the "valid" stride-2 conv halves both dims exactly
        outputs = tf.pad(inputs, [[0, 0], [0, 1], [0, 1], [0, 0]])
        outputs = self.conv1(outputs, training=training)
        outputs = tf.nn.relu(outputs)
        outputs = tf.pad(outputs, [[0, 0], [0, 1], [0, 1], [0, 0]])
        if not self.ds:
            outputs = self.conv2(outputs, training=training)
        else:
            outputs = self.dw_conv(outputs, training=training)
            outputs = self.pw_conv(outputs, training=training)
        outputs = tf.nn.relu(outputs)
        _, L, H, _ = shape_util.shape_list(outputs)
        assert H == 20  # 80 mel bins halved twice: 80 -> 40 -> 20
        return math_util.merge_two_last_dims(outputs)

    def get_config(self):
        conf = super(Conv2dSubsampling, self).get_config()
        conf.update(self.conv1.get_config())
        if not self.ds:
            conf.update(self.conv2.get_config())
        else:
            conf.update(self.dw_conv.get_config())
            conf.update(self.pw_conv.get_config())
        return conf
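
if __name__ == "__main__":
    # Usage sketch (not part of the original commit), assuming the src.utils
    # package is importable and filters=144 as an illustrative width.
    demo_inputs = tf.random.normal([2, 100, 80, 1])  # [batch, time, freq, channel]
    layer = Conv2dSubsampling(filters=144)
    demo_outputs = layer(demo_inputs)
    # Two stride-2 convs: time 100 -> 50 -> 25, freq 80 -> 40 -> 20,
    # then the last two dims merge: (2, 25, 20 * 144)
    print(demo_outputs.shape)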
import tensorflow as tf
from ...utils import shape_util
class TimeReductionLayer(tf.keras.layers.Layer):
    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size=5,
        stride=2,
        dropout=0.0,  # accepted for config compatibility; currently unused
        name="time_reduction",
        **kwargs,
    ):
        super(TimeReductionLayer, self).__init__(name=name, **kwargs)
        self.stride = stride
        self.kernel_size = kernel_size
        # Uniform-init bounds scaled by the fan-in of each conv
        dw_max = kernel_size ** -0.5
        pw_max = input_dim ** -0.5
        self.dw_conv = tf.keras.layers.DepthwiseConv2D(
            kernel_size=(kernel_size, 1), strides=self.stride,
            padding="valid", name=f"{name}_dw_conv",
            depth_multiplier=1,
            depthwise_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-dw_max, maxval=dw_max),
        )
        # self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
        self.pw_conv = tf.keras.layers.Conv2D(
            filters=output_dim, kernel_size=1, strides=1,
            padding="valid", name=f"{name}_pw_conv_2",
            kernel_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
            bias_initializer=tf.keras.initializers.RandomUniform(minval=-pw_max, maxval=pw_max),
        )

    def call(self, inputs, training=False, mask=None, pad_mask=None, **kwargs):
        B, T, E = shape_util.shape_list(inputs)
        outputs = tf.reshape(inputs, [B, T, 1, E])
        # Zero out padded frames before convolving
        _pad_mask = tf.expand_dims(tf.expand_dims(pad_mask, -1), -1)
        outputs = outputs * tf.cast(_pad_mask, "float32")
        # Right-pad so the strided depthwise conv covers the trailing frames
        padding = max(0, self.kernel_size - self.stride)
        outputs = tf.pad(outputs, [[0, 0], [0, padding], [0, 0], [0, 0]])
        outputs = self.dw_conv(outputs, training=training)
        outputs = self.pw_conv(outputs, training=training)
        B, T, _, E = shape_util.shape_list(outputs)
        outputs = tf.reshape(outputs, [B, T, E])
        # Subsample the attention and padding masks to the reduced frame rate
        mask = mask[:, ::self.stride, ::self.stride]
        pad_mask = pad_mask[:, ::self.stride]
        # Pad the outputs so their length matches the strided pad_mask
        _, L = shape_util.shape_list(pad_mask)
        outputs = tf.pad(outputs, [[0, 0], [0, L - T], [0, 0]])
        return outputs, mask, pad_mask
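
if __name__ == "__main__":
    # Usage sketch (not part of the original commit), assuming the utils
    # package is importable. Halves the frame rate of a [B, T, E] sequence.
    x = tf.random.normal([2, 10, 16])
    masks = tf.ones([2, 10, 10], dtype=tf.bool)  # self-attention mask
    pads = tf.ones([2, 10], dtype=tf.bool)       # True on real (non-padded) frames
    layer = TimeReductionLayer(input_dim=16, output_dim=16)
    y, masks, pads = layer(x, mask=masks, pad_mask=pads)
    print(y.shape)  # (2, 5, 16): T goes 10 -> 5 with stride 2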
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm import tqdm
import tensorflow as tf
from .metric_util import wer, cer
from ..metrics.error_rates import ErrorRate
from .file_util import read_file
logger = tf.get_logger()
def evaluate_results(filepath: str):
    logger.info(f"Evaluating result from {filepath} ...")
    metrics = {
        "greedy_wer": ErrorRate(wer, name="greedy_wer", dtype=tf.float32),
        "greedy_cer": ErrorRate(cer, name="greedy_cer", dtype=tf.float32),
        "beamsearch_wer": ErrorRate(wer, name="beamsearch_wer", dtype=tf.float32),
        "beamsearch_cer": ErrorRate(cer, name="beamsearch_cer", dtype=tf.float32),
    }
    with read_file(filepath) as path:
        with open(path, "r", encoding="utf-8") as openfile:
            lines = openfile.read().splitlines()
    lines = lines[1:]  # skip header
    # Each line has 5 tab-separated fields; the last three are the
    # groundtruth, greedy and beamsearch transcripts
    for eachline in tqdm(lines):
        _, _, groundtruth, greedy, beamsearch = eachline.split("\t")
        groundtruth = tf.convert_to_tensor([groundtruth], dtype=tf.string)
        greedy = tf.convert_to_tensor([greedy], dtype=tf.string)
        beamsearch = tf.convert_to_tensor([beamsearch], dtype=tf.string)
        metrics["greedy_wer"].update_state(decode=greedy, target=groundtruth)
        metrics["greedy_cer"].update_state(decode=greedy, target=groundtruth)
        metrics["beamsearch_wer"].update_state(decode=beamsearch, target=groundtruth)
        metrics["beamsearch_cer"].update_state(decode=beamsearch, target=groundtruth)
    for key, value in metrics.items():
        logger.info(f"{key}: {value.result().numpy()}")
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tf.data.Dataset does not work well for namedtuple so we are using dict
import tensorflow as tf
def create_inputs(
    inputs: tf.Tensor,
    inputs_length: tf.Tensor,
    predictions: tf.Tensor = None,
    predictions_length: tf.Tensor = None,
) -> dict:
    data = {
        "inputs": inputs,
        "inputs_length": inputs_length,
    }
    if predictions is not None:
        data["predictions"] = predictions
    if predictions_length is not None:
        data["predictions_length"] = predictions_length
    return data


def create_logits(logits: tf.Tensor, logits_length: tf.Tensor) -> dict:
    return {
        "logits": logits,
        "logits_length": logits_length,
    }


def create_labels(labels: tf.Tensor, labels_length: tf.Tensor) -> dict:
    return {
        "labels": labels,
        "labels_length": labels_length,
    }
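
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): bundle a feature batch
    # and its lengths into the dict form the tf.data pipelines expect.
    batch = create_inputs(
        inputs=tf.zeros([2, 100, 80, 1]),
        inputs_length=tf.constant([100, 87]),
    )
    print(batch.keys())  # dict_keys(['inputs', 'inputs_length'])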
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Union, List
import warnings
import tensorflow as tf
logger = tf.get_logger()
def setup_environment():
    """Set up the tensorflow running environment"""
    warnings.simplefilter("ignore")
    logger.setLevel(logging.INFO)
    return logger


def setup_devices(devices: List[int], cpu: bool = False):
    """Set visible devices

    Args:
        devices (list): list of visible devices' indices
        cpu (bool): run on CPU only if True
    """
    if cpu:
        cpus = tf.config.list_physical_devices("CPU")
        tf.config.set_visible_devices(cpus, "CPU")
        tf.config.set_visible_devices([], "GPU")
        logger.info(f"Run on {len(cpus)} Physical CPUs")
    else:
        gpus = tf.config.list_physical_devices("GPU")
        if gpus:
            visible_gpus = [gpus[i] for i in devices]
            tf.config.set_visible_devices(visible_gpus, "GPU")
            logger.info(f"Run on {len(visible_gpus)} Physical GPUs")


def setup_tpu(tpu_address=None):
    if tpu_address is None:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    else:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://" + tpu_address)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    logger.info(f"All TPUs: {tf.config.list_logical_devices('TPU')}")
    return tf.distribute.experimental.TPUStrategy(resolver)


def setup_strategy(devices: List[int], tpu_address: str = None):
    """Set up the distribution strategy for training

    Args:
        devices (list): list of visible devices' indices
        tpu_address (str): an optional custom tpu address

    Returns:
        tf.distribute.Strategy: TPUStrategy for training on tpus or MirroredStrategy for training on gpus
    """
    try:
        return setup_tpu(tpu_address)
    except (ValueError, tf.errors.NotFoundError) as e:
        logger.warning(e)  # no TPU found; fall back to GPU/CPU strategy
    setup_devices(devices)
    return tf.distribute.MirroredStrategy()


def has_devices(devices: Union[List[str], str]):
    if isinstance(devices, list):
        return all(len(tf.config.list_logical_devices(d)) != 0 for d in devices)
    return len(tf.config.list_logical_devices(devices)) != 0
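
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): pick a strategy,
    # falling back to MirroredStrategy when no TPU can be resolved.
    strategy = setup_strategy(devices=[0])
    print(strategy.num_replicas_in_sync)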
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def float_feature(list_of_floats):
    return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))


def int64_feature(list_of_ints):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))


def bytestring_feature(list_of_bytestrings):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))
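
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): wrap the helpers into a
    # tf.train.Example as done when serializing TFRecords. The feature keys
    # below are illustrative, not the repo's actual schema.
    example = tf.train.Example(features=tf.train.Features(feature={
        "audio": bytestring_feature([b"raw-bytes"]),
        "duration": float_feature([1.5]),
        "label": int64_feature([3, 7, 2]),
    }))
    print(len(example.SerializeToString()))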
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import yaml
import tempfile
import contextlib
from typing import Union, List
import tensorflow as tf
def load_yaml(path):
    # Fix yaml numbers https://stackoverflow.com/a/30462009/11037553
    loader = yaml.SafeLoader
    loader.add_implicit_resolver(
        u'tag:yaml.org,2002:float',
        re.compile(u'''^(?:
         [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
        |[-+]?\\.(?:inf|Inf|INF)
        |\\.(?:nan|NaN|NAN))$''', re.X),
        list(u'-+0123456789.'))
    with open(path, "r", encoding="utf-8") as file:
        return yaml.load(file, Loader=loader)


def is_hdf5_filepath(filepath: str) -> bool:
    return filepath.endswith('.h5') or filepath.endswith('.keras') or filepath.endswith('.hdf5')


def is_cloud_path(path: str) -> bool:
    """Check if the path is on cloud storage (which requires tf.io.gfile)

    Args:
        path (str): Path to directory or file

    Returns:
        bool: True if path is on cloud, False otherwise
    """
    return bool(re.match(r"^[a-z]+://", path))


def preprocess_paths(paths: Union[List[str], str], isdir: bool = False) -> Union[List[str], str]:
    """Expand local paths to absolute form and create their directories

    Args:
        paths (Union[List[str], str]): A path or list of paths

    Returns:
        Union[List[str], str]: The processed path or list of paths; None if the input is not a path
    """
    if isinstance(paths, list):
        paths = [path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) for path in paths]
        for path in paths:
            dirpath = path if isdir else os.path.dirname(path)
            if not tf.io.gfile.exists(dirpath):
                tf.io.gfile.makedirs(dirpath)
        return paths
    if isinstance(paths, str):
        paths = paths if is_cloud_path(paths) else os.path.abspath(os.path.expanduser(paths))
        dirpath = paths if isdir else os.path.dirname(paths)
        if not tf.io.gfile.exists(dirpath):
            tf.io.gfile.makedirs(dirpath)
        return paths
    return None


@contextlib.contextmanager
def save_file(filepath: str):
    # For HDF5 files on cloud storage, write to a local temp file first,
    # then upload on exit; otherwise yield the path unchanged.
    if is_cloud_path(filepath) and is_hdf5_filepath(filepath):
        _, ext = os.path.splitext(filepath)
        with tempfile.NamedTemporaryFile(suffix=ext) as tmp:
            yield tmp.name
            tf.io.gfile.copy(tmp.name, filepath, overwrite=True)
    else:
        yield filepath


@contextlib.contextmanager
def read_file(filepath: str):
    # For HDF5 files on cloud storage, download to a local temp file first;
    # otherwise yield the path unchanged.
    if is_cloud_path(filepath) and is_hdf5_filepath(filepath):
        _, ext = os.path.splitext(filepath)
        with tempfile.NamedTemporaryFile(suffix=ext) as tmp:
            tf.io.gfile.copy(filepath, tmp.name, overwrite=True)
            yield tmp.name
    else:
        yield filepath
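
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): for a local path the
    # context yields it unchanged; for a cloud HDF5 path (e.g. a hypothetical
    # gs://bucket/model.h5) it would yield a local temp file and copy the
    # result up when the block exits.
    local = preprocess_paths("./checkpoints/model.h5")
    with save_file(local) as path:
        print("write weights to", path)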
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def get_rnn(rnn_type: str):
    assert rnn_type in ["lstm", "gru", "rnn"]
    if rnn_type == "lstm":
        return tf.keras.layers.LSTM
    if rnn_type == "gru":
        return tf.keras.layers.GRU
    return tf.keras.layers.SimpleRNN


def get_conv(conv_type):
    assert conv_type in ["conv1d", "conv2d"]
    if conv_type == "conv1d":
        return tf.keras.layers.Conv1D
    return tf.keras.layers.Conv2D
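
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): resolve a layer class
    # from a config string, then instantiate it as usual.
    rnn_cls = get_rnn("lstm")
    layer = rnn_cls(units=32, return_sequences=True)
    print(layer(tf.zeros([1, 10, 8])).shape)  # (1, 10, 32)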
import wandb
import tensorflow as tf
import numpy as np
from numpy import linalg as la
from . import env_util
logger = env_util.setup_environment()
class StepLossMetric(tf.keras.metrics.Metric):
    """Tracks the most recent step loss (not a running average)."""

    def __init__(self, name='step_loss', **kwargs):
        super(StepLossMetric, self).__init__(name=name, **kwargs)
        self.loss = tf.zeros(())

    def update_state(self, loss):
        self.loss = loss

    def result(self):
        return self.loss

    def reset_states(self):
        self.loss = tf.zeros(())


class LoggingCallback(tf.keras.callbacks.Callback):
    def __init__(
        self,
        optimizer,
        model,
    ):
        super(LoggingCallback, self).__init__()
        self.optimizer = optimizer
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        # Log the current learning rate (assumes the optimizer uses a schedule
        # callable with the iteration count) to both the logger and wandb
        iterations = self.optimizer.iterations
        lr = self.optimizer.learning_rate(iterations)
        logger.info(f"[LR Logger] Epoch: {epoch}, lr: {lr}")
        wandb.log({"epoch": epoch, "lr": lr, "iterations": iterations.numpy()})
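
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): attach the callback to a
    # model whose optimizer uses a learning-rate schedule. Assumes wandb.init()
    # has already been called; `model`, `optimizer` and `dataset` are
    # placeholders for the caller's own objects.
    # model.fit(dataset, epochs=10, callbacks=[LoggingCallback(optimizer, model)])
    pass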
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Union

import numpy as np
import tensorflow as tf

from . import shape_util


def log10(x):
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator


def get_num_batches(nsamples, batch_size, drop_remainders=True):
    if nsamples is None or batch_size is None:
        return None
    if drop_remainders:
        return math.floor(float(nsamples) / float(batch_size))
    return math.ceil(float(nsamples) / float(batch_size))


def nan_to_zero(input_tensor):
    return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor)


def bytes_to_string(array: np.ndarray, encoding: str = "utf-8"):
    if array is None:
        return None
    return [transcript.decode(encoding) for transcript in array]


def get_reduced_length(length, reduction_factor):
    return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32)


def count_non_blank(tensor: tf.Tensor, blank: Union[int, tf.Tensor] = 0, axis=None):
    return tf.reduce_sum(tf.where(tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor)), axis=axis)


def merge_two_last_dims(x):
    b, _, f, c = shape_util.shape_list(x)
    return tf.reshape(x, shape=[b, -1, f * c])


def merge_repeated(yseqs, blank=0):
    result = tf.reshape(yseqs[0], [1])
    U = shape_util.shape_list(yseqs)[0]
    i = tf.constant(1, dtype=tf.int32)

    def _cond(i, result, yseqs, U):
        return tf.less(i, U)

    def _body(i, result, yseqs, U):
        # Append yseqs[i] only when it differs from the last kept token;
        # tf.cond keeps the loop body graph-compatible
        result = tf.cond(
            tf.not_equal(yseqs[i], result[-1]),
            true_fn=lambda: tf.concat([result, [yseqs[i]]], axis=-1),
            false_fn=lambda: result,
        )
        return i + 1, result, yseqs, U

    _, result, _, _ = tf.while_loop(
        _cond,
        _body,
        loop_vars=[i, result, yseqs, U],
        shape_invariants=(
            tf.TensorShape([]),
            tf.TensorShape([None]),
            tf.TensorShape([None]),
            tf.TensorShape([]),
        )
    )
    # Left-pad with blanks so the output keeps the original length U
    return tf.pad(result, [[U - shape_util.shape_list(result)[0], 0]], constant_values=blank)


def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor:
    with tf.name_scope("find_max_length_prediction_tfarray"):
        index = tf.constant(0, dtype=tf.int32)
        total = tfarray.size()
        max_length = tf.constant(0, dtype=tf.int32)

        def condition(index, _):
            return tf.less(index, total)

        def body(index, max_length):
            prediction = tfarray.read(index)
            length = tf.shape(prediction)[0]
            max_length = tf.where(tf.greater(length, max_length), length, max_length)
            return index + 1, max_length

        index, max_length = tf.while_loop(condition, body, loop_vars=[index, max_length], swap_memory=False)
        return max_length


def pad_prediction_tfarray(tfarray: tf.TensorArray, blank: Union[int, tf.Tensor]) -> tf.TensorArray:
    with tf.name_scope("pad_prediction_tfarray"):
        index = tf.constant(0, dtype=tf.int32)
        total = tfarray.size()
        max_length = find_max_length_prediction_tfarray(tfarray) + 1

        def condition(index, _):
            return tf.less(index, total)

        def body(index, tfarray):
            prediction = tfarray.read(index)
            prediction = tf.pad(
                prediction, paddings=[[0, max_length - tf.shape(prediction)[0]]],
                mode="CONSTANT", constant_values=blank,
            )
            tfarray = tfarray.write(index, prediction)
            return index + 1, tfarray

        index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False)
        return tfarray
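
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): collapse repeated tokens
    # as done when post-processing greedy CTC-style outputs.
    seq = tf.constant([1, 1, 2, 0, 0, 3], dtype=tf.int32)
    print(merge_repeated(seq).numpy())  # [0 0 1 2 0 3]: deduped, left-padded with blanks
    # Reduced lengths after 4x time subsampling:
    print(get_reduced_length(tf.constant([100, 87]), 4).numpy())  # [25 22]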
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from nltk.metrics import distance
import tensorflow as tf
from . import math_util
def execute_wer(decode, target):
    decode = math_util.bytes_to_string(decode)
    target = math_util.bytes_to_string(target)
    dis = 0.0
    length = 0.0
    for dec, tar in zip(decode, target):
        # Map each distinct word to a single character so that word-level edit
        # distance can be computed as a character-level edit distance
        words = set(dec.split() + tar.split())
        word2char = dict(zip(words, range(len(words))))
        new_decode = [chr(word2char[w]) for w in dec.split()]
        new_target = [chr(word2char[w]) for w in tar.split()]
        dis += distance.edit_distance(''.join(new_decode), ''.join(new_target))
        length += len(tar.split())
    return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)


def wer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Word Error Rate

    Args:
        decode (tf.Tensor): tensor of prediction texts
        target (tf.Tensor): tensor of groundtruth texts

    Returns:
        tuple: a tuple of tf.Tensor of (total edit distance, total number of words)
    """
    return tf.numpy_function(execute_wer, inp=[decode, target], Tout=[tf.float32, tf.float32])


def execute_cer(decode, target):
    decode = math_util.bytes_to_string(decode)
    target = math_util.bytes_to_string(target)
    dis = 0
    length = 0
    for dec, tar in zip(decode, target):
        dis += distance.edit_distance(dec, tar)
        length += len(tar)
    return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)


def cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Character Error Rate

    Args:
        decode (tf.Tensor): tensor of prediction texts
        target (tf.Tensor): tensor of groundtruth texts

    Returns:
        tuple: a tuple of tf.Tensor of (total edit distance, total number of characters)
    """
    return tf.numpy_function(execute_cer, inp=[decode, target], Tout=[tf.float32, tf.float32])


def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """TensorFlow Character Error Rate

    Args:
        decode (tf.Tensor): tensor of shape [B]
        target (tf.Tensor): tensor of shape [B]

    Returns:
        tuple: a tuple of tf.Tensor of (total edit distance, total number of characters)
    """
    decode = tf.strings.bytes_split(decode)  # [B, N]
    target = tf.strings.bytes_split(target)  # [B, M]
    distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False)  # [B]
    lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32)  # [B]
    return tf.reduce_sum(distances), tf.reduce_sum(lengths)
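
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): the functions return
    # totals, not rates; divide distance by length to get the final WER/CER.
    hyp = tf.constant(["hello wrld"])
    ref = tf.constant(["hello world"])
    dis, length = wer(hyp, ref)
    print((dis / length).numpy())  # 0.5: one of two words is wrong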
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
def shape_list(x, out_type=tf.int32):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x, out_type=out_type)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_shape_invariants(tensor):
    shapes = shape_list(tensor)
    return tf.TensorShape([i if isinstance(i, int) else None for i in shapes])


def get_float_spec(tensor):
    shape = get_shape_invariants(tensor)
    return tf.TensorSpec(shape, dtype=tf.float32)
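
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): statically known dims
    # come back as Python ints, unknown dims as scalar tensors.
    @tf.function(input_signature=[tf.TensorSpec([None, 80], tf.float32)])
    def fn(x):
        t, f = shape_list(x)
        tf.print("dynamic time:", t, "static freq:", f)
        return x

    fn(tf.zeros([7, 80]))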
import tensorflow as tf
from tensorflow.python.keras import backend
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import ops
from tensorflow.python.eager import context
from tensorflow.python.util import nest
from tensorflow.python.ops import variables
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import values as ds_values


def _minimum_control_deps(outputs):
    """Returns the minimum control dependencies to ensure step succeeded."""
    if context.executing_eagerly():
        return []  # Control dependencies not needed.
    outputs = nest.flatten(outputs, expand_composites=True)
    for out in outputs:
        # Variables can't be control dependencies.
        if not isinstance(out, variables.Variable):
            return [out]  # Return first Tensor or Op from outputs.
    return []  # No viable Tensor or Op to use for control deps.


def reduce_per_replica(values, strategy):
    """Reduce PerReplica objects.

    Args:
        values: Structure of `PerReplica` objects or `Tensor`s. `Tensor`s are
            returned as-is.
        strategy: `tf.distribute.Strategy` object.

    Returns:
        Structure of `Tensor`s.
    """

    def _reduce(v):
        """Reduce a single `PerReplica` object."""
        if _collective_all_reduce_multi_worker(strategy):
            # NOTE: _multi_worker_concat comes from the upstream Keras code
            # this file adapts and is not defined in this snippet
            return _multi_worker_concat(v, strategy)
        if not isinstance(v, ds_values.PerReplica):
            return v
        if _is_tpu_multi_host(strategy):
            return _tpu_multi_host_concat(v, strategy)
        else:
            return concat(strategy.unwrap(v))

    return nest.map_structure(_reduce, values)


def concat(tensors, axis=0):
    """Concats `tensor`s along `axis`; scalars are summed instead."""
    if len(tensors[0].shape) == 0:
        return tf.math.add_n(tensors)
    if isinstance(tensors[0], sparse_tensor.SparseTensor):
        return sparse_ops.sparse_concat_v2(axis=axis, sp_inputs=tensors)
    return array_ops.concat(tensors, axis=axis)


def _collective_all_reduce_multi_worker(strategy):
    return (isinstance(strategy,
                       collective_all_reduce_strategy.CollectiveAllReduceStrategy)
            ) and strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


def _is_scalar(x):
    return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0


def write_scalar_summaries(logs, step):
    for name, value in logs.items():
        if _is_scalar(value):
            summary_ops_v2.scalar('batch_' + name, value, step=step)


def _is_tpu_multi_host(strategy):
    return (backend.is_tpu_strategy(strategy) and
            strategy.extended.num_hosts > 1)


def _tpu_multi_host_concat(v, strategy):
    """Correctly order TPU PerReplica objects."""
    replicas = strategy.unwrap(v)
    # When distributed datasets are created from Tensors / NumPy,
    # TPUStrategy.experimental_distribute_dataset shards data in
    # (Replica, Host) order, and TPUStrategy.unwrap returns it in
    # (Host, Replica) order.
    # TODO(b/150317897): Figure out long-term plan here.
    num_replicas_per_host = strategy.extended.num_replicas_per_host
    ordered_replicas = []
    for replica_id in range(num_replicas_per_host):
        ordered_replicas += replicas[replica_id::num_replicas_per_host]
    return concat(ordered_replicas)
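
if __name__ == "__main__":
    # Usage sketch (not part of the original commit): with a single-device
    # MirroredStrategy, per-replica results are plain tensors and pass
    # through reduce_per_replica unchanged.
    strategy = tf.distribute.MirroredStrategy()
    result = strategy.run(lambda: tf.constant([1.0, 2.0]))
    print(reduce_per_replica(result, strategy))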