"mmdet3d/vscode:/vscode.git/clone" did not exist on "2ebce2889a19b891b4f2873eb3b100c6a09bdb49"
Unverified commit a81f8590 authored by karun, committed by GitHub

Adding transformer based bytestream models (#10734)


Co-authored-by: Arun Kandoor <akandoor@google.com>
parent 82a26070
@@ -116,3 +116,19 @@ py_strict_library(
"//layers:quantization_layers",
],
)
py_strict_library(
name = "transformer_layers",
srcs = ["transformer_layers.py"],
srcs_version = "PY3",
deps = [
":embedding_layers",
# package tensorflow
"//layers:base_layers",
"//layers:dense_layers",
"//layers:normalization_layers",
"//layers:quantization_layers",
"//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
],
)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for Transformer encoder."""
# pylint: disable=arguments-renamed
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from tf_ops import tf_custom_ops_py
class SelfAttention(base_layers.BaseLayer):
"""Self attention encoder (not suitable for causal attention)."""
def __init__(self,
model_dimension,
num_heads,
attention_dropout_rate=0.0,
**kwargs):
self.model_dimension = model_dimension
self.num_heads = num_heads
self.filters = model_dimension // num_heads
self.dense_layers = [
dense_layers.BaseQDenseVarLen(
units=self.filters, activation=None, **kwargs)
for i in range(num_heads * 3)
]
self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
self.attention_dropout_rate = attention_dropout_rate
self.qconcat = quantization_layers.ConcatQuantization(axis=2, **kwargs)
super(SelfAttention, self).__init__(**kwargs)
def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
batch_size = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
mask_rank2 = tf.reshape(mask, [-1, 1])
tensors = [
layer(inputs_rank2, mask_rank2, inverse_normalizer)
for layer in self.dense_layers
]
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
tensors = [
tf.reshape(tensor, [batch_size, -1, self.filters])
for tensor in tensors
]
context = []
if attn_mask is None:
attn_mask = tf.matmul(mask, tf.transpose(mask, [0, 2, 1]))
if (self.attention_dropout_rate > 0.0 and
self.parameters.mode == base_layers.TRAIN):
attn_mask *= self.random_drop_to_zero(attn_mask,
self.attention_dropout_rate)
invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
for _ in range(self.num_heads):
keys = tensors.pop()
values = tensors.pop()
queries = tensors.pop()
# Attention is not scaled dot product, batch normalization compensates
# for it.
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
queries = tf.transpose(queries, [0, 2, 1])
attn_logits = self.qactivation(tf.matmul(keys, queries))
attn_logits_masked = attn_logits * attn_mask + invalid_mask
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
context.append(tf.matmul(attention, values))
else:
queries = tf.transpose(queries)
attn_logits_masked = self.qactivation(tf.matmul(keys, queries))
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
ctx = tf.matmul(attention, values)
ctx = tf.reshape(ctx, [1, -1, self.filters])
context.append(ctx)
return self.qconcat(context)
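# --- Illustrative sketch, not part of the original change: how the additive
# attention mask above behaves. The outer product of a rank-3 token mask with
# itself gives a [batch, time, time] mask; invalid query/key pairs are pushed
# to a large negative logit (here assumed to be -1e6, standing in for
# `parameters.invalid_logit`) before the softmax.
import tensorflow as tf

mask = tf.constant([[[1.], [1.], [0.]]])  # 2 valid tokens, 1 padding position
attn_mask = tf.matmul(mask, tf.transpose(mask, [0, 2, 1]))  # [1, 3, 3]
invalid_logit = -1e6  # assumed value for illustration
attn_logits = tf.random.normal([1, 3, 3])
masked = attn_logits * attn_mask + (1 - attn_mask) * invalid_logit
probs = tf.nn.softmax(masked)  # padded key positions get ~0 weight for valid queries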
class SelfAttentionV2(base_layers.BaseLayer):
"""Self attention encoder (not suitable for causal attention)."""
def __init__(self,
model_dimension,
num_heads,
attention_dropout_rate=0.0,
**kwargs):
self.model_dimension = model_dimension
self.num_heads = num_heads
self.filters = model_dimension // num_heads
self.dense_layers = dense_layers.BaseQDenseVarLen(
units=model_dimension * 3, activation=None, **kwargs)
self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
self.attention_dropout_rate = attention_dropout_rate
self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
super(SelfAttentionV2, self).__init__(**kwargs)
def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
bsz = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
mask_rank2 = tf.reshape(mask, [-1, 1])
tensors = self.dense_layers(inputs_rank2, mask_rank2, inverse_normalizer)
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
tensors = tf.reshape(tensors, [bsz, -1, 3, self.num_heads, self.filters])
tensors = tf.unstack(tensors, axis=2)
else:
tensors = tf.split(tensors, self.num_heads * 3, axis=1)
if attn_mask is None:
attn_mask = tf.matmul(mask, mask, transpose_b=True)
if (self.attention_dropout_rate > 0.0 and
self.parameters.mode == base_layers.TRAIN):
attn_mask *= self.random_drop_to_zero(attn_mask,
self.attention_dropout_rate)
attn_mask = tf.expand_dims(attn_mask, axis=1)
invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
queries = tf.transpose(tensors[0], [0, 2, 1, 3])
keys = tf.transpose(tensors[1], [0, 2, 1, 3])
values = tf.transpose(tensors[2], [0, 2, 1, 3])
attn_logits = self.qactivation(tf.matmul(queries, keys, transpose_b=True))
attn_logits_masked = attn_logits * attn_mask + invalid_mask
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
result = tf.matmul(attention, values)
result = tf.transpose(result, [0, 2, 1, 3])
result = tf.reshape(result, [bsz, -1, self.model_dimension])
return self.qconcat([result])
else:
context = []
for idx in range(self.num_heads):
queries = tensors[idx]
keys = tensors[idx + self.num_heads]
values = tensors[idx + self.num_heads * 2]
# Attention is not scaled dot product, batch normalization compensates
# for it.
attn_logits_masked = self.qactivation(
tf.matmul(queries, keys, transpose_b=True))
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
context.append(tf.matmul(attention, values))
result = self.qconcat(context)
return tf.reshape(result, [1, -1, self.model_dimension])
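# --- Illustrative sketch, not part of the original change: the fused dense
# projection above emits Q, K and V for every head in a single tensor of width
# 3 * model_dimension; one reshape plus an unstack along axis 2 recovers the
# per-head tensors used in the batched matmuls.
import tensorflow as tf

bsz, time_steps, num_heads, filters = 2, 4, 2, 3
model_dimension = num_heads * filters
fused = tf.random.normal([bsz * time_steps, 3 * model_dimension])
fused = tf.reshape(fused, [bsz, -1, 3, num_heads, filters])
q, k, v = tf.unstack(fused, axis=2)  # each is [bsz, time, num_heads, filters]
q = tf.transpose(q, [0, 2, 1, 3])  # [bsz, num_heads, time, filters], as in the TRAIN path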
class TransformerEncoder(base_layers.BaseLayer):
"""Transformer Encoder."""
def __init__(self,
model_dimension,
num_heads,
intermediate_size,
initializer_stddev=0.02,
activation_dropout_rate=0.0,
attention_dropout_rate=0.0,
**kwargs):
super(TransformerEncoder, self).__init__(**kwargs)
self.model_dimension = model_dimension
self.parameters.initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_stddev)
self.self_attn = SelfAttentionV2(
model_dimension,
num_heads,
attention_dropout_rate=attention_dropout_rate,
parameters=self.parameters)
self.prx = dense_layers.BaseQDenseVarLen(
model_dimension, activation=None, parameters=self.parameters)
self.upprx = dense_layers.BaseQDenseVarLen(
intermediate_size, parameters=self.parameters)
self.downprx = dense_layers.BaseQDenseVarLen(
model_dimension, activation=None, parameters=self.parameters)
self.activation_dropout_rate = activation_dropout_rate
self.ln1 = normalization_layers.LayerNormalization(**kwargs)
self.ln2 = normalization_layers.LayerNormalization(**kwargs)
self.q1 = quantization_layers.ActivationQuantization(**kwargs)
self.q2 = quantization_layers.ActivationQuantization(**kwargs)
def call(self, inputs, mask, inverse_normalizer, attn_mask=None):
batch_size = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
mask_rank2 = tf.reshape(mask, [-1, 1])
assert inputs.get_shape().as_list()[-1] == self.model_dimension
tensor = self.self_attn(inputs, mask, inverse_normalizer, attn_mask)
inputs = tf.reshape(inputs, [-1, self.model_dimension])
tensor = tf.reshape(tensor, [-1, self.model_dimension])
tensor = self.prx(tensor, mask_rank2, inverse_normalizer)
if (self.parameters.mode == base_layers.TRAIN and
self.activation_dropout_rate > 0.0):
tensor = tf.nn.dropout(tensor, rate=self.activation_dropout_rate)
inputs_plus_selfattn = self.q1(self.ln1(inputs + tensor))
ffn_up = self.upprx(inputs_plus_selfattn, mask_rank2, inverse_normalizer)
ffn_down = self.downprx(ffn_up, mask_rank2, inverse_normalizer)
if (self.parameters.mode == base_layers.TRAIN and
self.activation_dropout_rate > 0.0):
ffn_down = tf.nn.dropout(ffn_down, rate=self.activation_dropout_rate)
inputs_plus_ffn = self.q2(self.ln2(inputs_plus_selfattn + ffn_down))
return tf.reshape(inputs_plus_ffn, [batch_size, -1, self.model_dimension])
class TransformerEncoderStack(base_layers.BaseLayer):
"""Transformer Encoder."""
def __init__(self, num_layers, max_time_step, vocabulary_size, embedding_size,
model_dimension, num_heads, intermediate_size, **kwargs):
self.max_time_step = max_time_step
self.vocabulary_size = vocabulary_size
self.embedding_size = embedding_size
activation_dropout_rate = kwargs.pop('activation_dropout_rate', 0.0)
attention_dropout_rate = kwargs.pop('attention_dropout_rate', 0.0)
self.layers = []
for _ in range(num_layers):
self.layers.append(
TransformerEncoder(
model_dimension=model_dimension,
num_heads=num_heads,
intermediate_size=intermediate_size,
activation_dropout_rate=activation_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
**kwargs))
self.embedding = embedding_layers.EmbeddingLayer(
shape=[self.vocabulary_size, self.embedding_size], **kwargs)
self.positional_embedding = embedding_layers.EmbeddingLayer(
shape=[self.max_time_step, self.embedding_size], **kwargs)
self.ln = normalization_layers.LayerNormalization(**kwargs)
self.qact = quantization_layers.ActivationQuantization(**kwargs)
super(TransformerEncoderStack, self).__init__(**kwargs)
def call(self, input_indices, sequence_length):
mask_rank2 = tf.sequence_mask(
sequence_length, tf.shape(input_indices)[1], dtype=tf.float32)
mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask_rank3))
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
sequence_length = tf.reduce_sum(input_indices + 1 - input_indices)
pos_indices = tf.range(sequence_length, dtype=tf.int32)
pos_indices = tf.reshape(pos_indices, [1, -1])
else:
pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
pos_indices = tf.cast(pos_indices, dtype=tf.int32)
input_values = self.embedding(input_indices)
pos_values = self.positional_embedding(pos_indices)
inputs = self.qact(self.ln(input_values + pos_values))
attn_mask = tf.matmul(mask_rank3, tf.transpose(mask_rank3, [0, 2, 1]))
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
inputs = inputs * mask_rank3
for layer in self.layers:
outputs = layer(inputs, mask_rank3, inverse_normalizer, attn_mask)
inputs = outputs
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
outputs = outputs * mask_rank3
return outputs
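# --- Illustrative sketch, not part of the original change: positional indices
# are derived from the padding mask with an exclusive cumulative sum, so valid
# tokens receive 0, 1, 2, ... and padded positions simply repeat the last index
# (they are zeroed out by the mask downstream).
import tensorflow as tf

sequence_length = tf.constant([4, 2])
mask_rank2 = tf.sequence_mask(sequence_length, 4, dtype=tf.float32)  # [2, 4]
pos_indices = tf.cast(tf.cumsum(mask_rank2, axis=1, exclusive=True), tf.int32)
# pos_indices -> [[0, 1, 2, 3], [0, 1, 2, 2]]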
class TransformerEncoderStackWithInputEmbedding(TransformerEncoderStack):
"""Transformer Encoder."""
def call(self, inputs, sequence_length):
mask_rank2 = tf.sequence_mask(
sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask_rank3))
attn_mask = tf.matmul(mask_rank3, tf.transpose(mask_rank3, [0, 2, 1]))
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
inputs = inputs * mask_rank3
for layer in self.layers:
outputs = layer(inputs, mask_rank3, inverse_normalizer, attn_mask)
inputs = outputs
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
outputs = outputs * mask_rank3
return outputs
class FunnelAttention(base_layers.BaseLayer):
"""Self attention encoder (not suitable for causal attention)."""
def __init__(self,
model_dimension,
num_heads,
attention_dropout_rate=0.0,
**kwargs):
self.model_dimension = model_dimension
self.num_heads = num_heads
self.filters = model_dimension // num_heads
self.q_dense_layer = dense_layers.BaseQDenseVarLen(
units=model_dimension, activation=None, **kwargs)
self.kv_dense_layer = dense_layers.BaseQDenseVarLen(
units=model_dimension * 2, activation=None, **kwargs)
self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
self.attention_dropout_rate = attention_dropout_rate
self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
super(FunnelAttention, self).__init__(**kwargs)
def call(self, inputs, mask, inverse_normalizer, memory, memory_mask,
memory_inverse_normalizer, attn_mask):
bsz = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
self._assert_rank_and_type(memory, 3)
self._assert_rank_and_type(memory_mask, 3)
assert memory.get_shape().as_list()[-1] == self.model_dimension
inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
mask_rank2 = tf.reshape(mask, [-1, 1])
q_tensor = self.q_dense_layer(inputs_rank2, mask_rank2, inverse_normalizer)
memory_rank2 = tf.reshape(memory, [-1, self.model_dimension])
memory_mask_rank2 = tf.reshape(memory_mask, [-1, 1])
kv_tensors = self.kv_dense_layer(memory_rank2, memory_mask_rank2,
inverse_normalizer)
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
q_tensor = tf.reshape(q_tensor, [bsz, -1, self.num_heads, self.filters])
kv_tensors = tf.reshape(kv_tensors,
[bsz, -1, 2, self.num_heads, self.filters])
kv_tensors = tf.unstack(kv_tensors, axis=2)
else:
q_tensor = tf.split(q_tensor, self.num_heads, axis=1)
kv_tensors = tf.split(kv_tensors, self.num_heads * 2, axis=1)
attn_mask = tf.expand_dims(attn_mask, axis=1)
invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
queries = tf.transpose(q_tensor, [0, 2, 1, 3])
keys = tf.transpose(kv_tensors[0], [0, 2, 1, 3])
values = tf.transpose(kv_tensors[1], [0, 2, 1, 3])
attn_logits = self.qactivation(tf.matmul(queries, keys, transpose_b=True))
attn_logits_masked = attn_logits * attn_mask + invalid_mask
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
result = tf.matmul(attention, values)
result = tf.transpose(result, [0, 2, 1, 3])
result = tf.reshape(result, [bsz, -1, self.model_dimension])
return self.qconcat([result])
else:
context = []
for idx in range(self.num_heads):
queries = q_tensor[idx]
keys = kv_tensors[idx]
values = kv_tensors[idx + self.num_heads]
# Attention is not scaled dot product, batch normalization compensates
# for it.
attn_logits_masked = self.qactivation(
tf.matmul(queries, keys, transpose_b=True))
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
context.append(tf.matmul(attention, values))
result = self.qconcat(context)
return tf.reshape(result, [1, -1, self.model_dimension])
class FunnelTransformerEncoder(base_layers.BaseLayer):
"""Transformer Encoder."""
def __init__(self,
model_dimension,
num_heads,
intermediate_size,
initializer_stddev=0.02,
activation_dropout_rate=0.0,
attention_dropout_rate=0.0,
**kwargs):
super(FunnelTransformerEncoder, self).__init__(**kwargs)
self.model_dimension = model_dimension
self.parameters.initializer = tf.keras.initializers.TruncatedNormal(
stddev=initializer_stddev)
self.self_attn = FunnelAttention(
model_dimension,
num_heads,
attention_dropout_rate=attention_dropout_rate,
parameters=self.parameters)
self.prx = dense_layers.BaseQDenseVarLen(
model_dimension, activation=None, parameters=self.parameters)
self.upprx = dense_layers.BaseQDenseVarLen(
intermediate_size, parameters=self.parameters)
self.downprx = dense_layers.BaseQDenseVarLen(
model_dimension, activation=None, parameters=self.parameters)
self.activation_dropout_rate = activation_dropout_rate
self.ln1 = normalization_layers.LayerNormalization(**kwargs)
self.ln2 = normalization_layers.LayerNormalization(**kwargs)
self.q1 = quantization_layers.ActivationQuantization(**kwargs)
self.q2 = quantization_layers.ActivationQuantization(**kwargs)
def call(self, inputs, mask, inverse_normalizer, memory, memory_mask,
memory_inverse_normalizer, attn_mask):
batch_size = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
mask_rank2 = tf.reshape(mask, [-1, 1])
assert inputs.get_shape().as_list()[-1] == self.model_dimension
tensor = self.self_attn(inputs, mask, inverse_normalizer, memory,
memory_mask, memory_inverse_normalizer, attn_mask)
inputs = tf.reshape(inputs, [-1, self.model_dimension])
tensor = tf.reshape(tensor, [-1, self.model_dimension])
tensor = self.prx(tensor, mask_rank2, inverse_normalizer)
if (self.parameters.mode == base_layers.TRAIN and
self.activation_dropout_rate > 0.0):
tensor = tf.nn.dropout(tensor, rate=self.activation_dropout_rate)
inputs_plus_selfattn = self.q1(self.ln1(inputs + tensor))
ffn_up = self.upprx(inputs_plus_selfattn, mask_rank2, inverse_normalizer)
ffn_down = self.downprx(ffn_up, mask_rank2, inverse_normalizer)
if (self.parameters.mode == base_layers.TRAIN and
self.activation_dropout_rate > 0.0):
ffn_down = tf.nn.dropout(ffn_down, rate=self.activation_dropout_rate)
inputs_plus_ffn = self.q2(self.ln2(inputs_plus_selfattn + ffn_down))
return tf.reshape(inputs_plus_ffn, [batch_size, -1, self.model_dimension])
class FunnelTransformerEncoderStack(base_layers.BaseLayer):
"""Transformer Encoder."""
def __init__(self, num_layers, max_time_step, vocabulary_size, embedding_size,
model_dimension, num_heads, intermediate_size, **kwargs):
self.max_time_step = max_time_step
self.pool_windows = kwargs.pop('pool_windows', [])
assert len(self.pool_windows) == num_layers
self.vocabulary_size = vocabulary_size
activation_dropout_rate = kwargs.pop('activation_dropout_rate', 0.0)
attention_dropout_rate = kwargs.pop('attention_dropout_rate', 0.0)
self.layers = []
for _ in range(num_layers):
self.layers.append(
FunnelTransformerEncoder(
model_dimension=model_dimension,
num_heads=num_heads,
intermediate_size=intermediate_size,
activation_dropout_rate=activation_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
**kwargs))
super(FunnelTransformerEncoderStack, self).__init__(**kwargs)
def call(self, inputs, sequence_length):
mask_rank2 = tf.sequence_mask(
sequence_length, tf.shape(inputs)[1], dtype=tf.float32)
mask_rank3 = tf.expand_dims(mask_rank2, axis=2)
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
inputs = inputs * mask_rank3
pooled_inputs = inputs
pooled_mask = mask_rank3
pooled_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(pooled_mask))
memory = pooled_inputs
memory_mask = pooled_mask
memory_inverse_normalizer = pooled_inverse_normalizer
for i, layer in enumerate(self.layers):
if self.pool_windows[i] > 1:
pooled_inputs = tf.nn.avg_pool(
pooled_inputs, [self.pool_windows[i]],
strides=[self.pool_windows[i]],
padding='SAME')
pooled_mask = pooled_mask[:, ::self.pool_windows[i], :]
pooled_inverse_normalizer = tf.math.reciprocal(
tf.reduce_sum(pooled_mask))
attn_mask = tf.matmul(pooled_mask, memory_mask, transpose_b=True)
pooled_outputs = layer(pooled_inputs, pooled_mask,
pooled_inverse_normalizer, memory, memory_mask,
memory_inverse_normalizer, attn_mask)
pooled_inputs = pooled_outputs
pooled_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(pooled_mask))
memory = pooled_inputs
memory_mask = pooled_mask
memory_inverse_normalizer = pooled_inverse_normalizer
if self.parameters.mode not in [base_layers.PREDICT, base_layers.TFLITE]:
pooled_outputs = pooled_outputs * pooled_mask
return pooled_outputs, pooled_mask
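# --- Illustrative sketch, not part of the original change: after each funnel
# layer the query sequence is pooled while keys and values still come from the
# previous, longer memory. Striding the rank-3 mask and taking its outer
# product with the memory mask yields the rectangular
# [batch, pooled_time, memory_time] attention mask used above.
import tensorflow as tf

pool_window = 2
memory_mask = tf.constant([[[1.], [1.], [1.], [0.]]])  # [1, 4, 1], 3 valid tokens
pooled_mask = memory_mask[:, ::pool_window, :]  # [1, 2, 1]
attn_mask = tf.matmul(pooled_mask, memory_mask, transpose_b=True)  # [1, 2, 4]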
class DecoderMultiheadAttention(base_layers.BaseLayer):
"""Multihead attention for decoder."""
def __init__(self,
model_dimension,
num_heads,
attention_dropout_rate=0.0,
cached_kv=False,
**kwargs):
self.model_dimension = model_dimension
self.num_heads = num_heads
self.filters = model_dimension // num_heads
self.cached_kv = cached_kv
self.q_dense_layers = dense_layers.BaseQDense(
units=model_dimension,
activation=None,
normalize=False,
bias=False,
**kwargs)
self.kv_dense_layers = dense_layers.BaseQDenseVarLen(
units=model_dimension * 2, activation=None, **kwargs)
self.qactivation = quantization_layers.ActivationQuantization(**kwargs)
self.attention_dropout_rate = attention_dropout_rate
self.qconcat = quantization_layers.ConcatQuantization(axis=1, **kwargs)
super(DecoderMultiheadAttention, self).__init__(**kwargs)
def call(self,
inputs,
input_mask,
input_inverse_normalizer,
memory=None,
memory_mask=None,
memory_inverse_normalizer=None,
attn_mask=None):
bsz = self.get_batch_dimension(inputs)
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(input_mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
inputs_rank2 = tf.reshape(inputs, [-1, self.model_dimension])
q_tensor = self.q_dense_layers(inputs_rank2)
if memory is not None:
self._assert_rank_and_type(memory, 2)
self._assert_rank_and_type(memory_mask, 2)
if self.cached_kv:
# Keys and Values are cached and reused at each layer.
assert memory.get_shape().as_list()[1] == 2 * self.model_dimension
kv_tensors = memory
else:
kv_tensors = self.kv_dense_layers(memory, memory_mask,
memory_inverse_normalizer)
else:
kv_tensors = self.kv_dense_layers(inputs_rank2)
if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
q_tensor = tf.reshape(q_tensor, [bsz, -1, self.num_heads, self.filters])
kv_tensors = tf.reshape(kv_tensors,
[bsz, -1, 2, self.num_heads, self.filters])
kv_tensors = tf.unstack(kv_tensors, axis=2)
else:
q_tensor = tf.split(q_tensor, self.num_heads, axis=1)
kv_tensors = tf.split(kv_tensors, self.num_heads * 2, axis=1)
if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
assert attn_mask is not None
if (self.attention_dropout_rate > 0.0 and
self.parameters.mode == base_layers.TRAIN):
attn_mask *= self.random_drop_to_zero(attn_mask,
self.attention_dropout_rate)
attn_mask = tf.expand_dims(attn_mask, 1)
invalid_mask = (1 - attn_mask) * self.parameters.invalid_logit
queries = tf.transpose(q_tensor, [0, 2, 1, 3])
keys = tf.transpose(kv_tensors[0], [0, 2, 1, 3])
values = tf.transpose(kv_tensors[1], [0, 2, 1, 3])
attn_logits = self.qactivation(tf.matmul(queries, keys, transpose_b=True))
attn_logits_masked = attn_logits * attn_mask + invalid_mask
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
result = tf.matmul(attention, values)
result = tf.transpose(result, [0, 2, 1, 3])
result = tf.reshape(result, [bsz, -1, self.model_dimension])
return self.qconcat([result])
else:
# We need to invoke the keras layer before calling APIs that it provides
# such as quantize_using_range.
self.qconcat(None)
context = []
for head in range(self.num_heads):
queries = q_tensor[head]
if self.parameters.mode == base_layers.PREDICT:
# PREDICT mode assumes callers tile and merge beam size with batch
# size. Hence extracting the first entry in the tile to compute
# attention.
keys = tf.split(kv_tensors[head], bsz, axis=0)
keys = keys[0]
values = tf.split(kv_tensors[head + self.num_heads], bsz, axis=0)
values = values[0]
else:
keys = kv_tensors[head]
values = kv_tensors[head + self.num_heads]
attn_logits_masked = self.qactivation(
tf.matmul(queries, keys, transpose_b=True))
attention = tf.nn.softmax(attn_logits_masked)
attention = self.qrange_sigmoid(attention, tf_only=True)
context.append(
self.qconcat.quantize_using_range(tf.matmul(attention, values)))
# Concatenating heads along axis 1.
result = self.qconcat.quantize_using_range(tf.concat(context, axis=1))
return tf.reshape(result, [-1, 1, self.model_dimension])
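# --- Illustrative sketch, not part of the original change: with `cached_kv`
# the encoder memory is already projected to keys and values, i.e. a rank-2
# tensor of width 2 * model_dimension. In the TFLITE/PREDICT path it is split
# along the feature axis into num_heads key slices followed by num_heads value
# slices, matching the indexing in the per-head loop above.
import tensorflow as tf

num_heads, filters = 2, 3
model_dimension = num_heads * filters
cached_memory = tf.random.normal([5, 2 * model_dimension])  # [memory_time, 2 * dim]
parts = tf.split(cached_memory, num_heads * 2, axis=1)  # 2 * num_heads slices of width `filters`
keys_head0, values_head0 = parts[0], parts[0 + num_heads]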
class DecoderUniformAttention(base_layers.BaseLayer):
"""Decoder uniform attention."""
def __init__(self,
model_dimension,
max_time_step,
attention_dropout_rate=0.0,
beam_size=1,
**kwargs):
self.model_dimension = model_dimension
self.max_time_step = max_time_step
self.beam_size = beam_size
self.causal_mask = tf.expand_dims(
tf.linalg.band_part(tf.ones([max_time_step, max_time_step]), -1, 0), 0)
self.dense_layers = dense_layers.BaseQDenseVarLen(
units=model_dimension,
activation=None,
normalize=False,
bias=False,
rank=3,
**kwargs)
self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
super(DecoderUniformAttention, self).__init__(**kwargs)
def get_uniform_attention(self, attn_mask=None):
"""Generates uniform attention matrix using `causal_mask`."""
mask = tf.math.divide_no_nan(
self.causal_mask,
tf.reduce_sum(self.causal_mask, axis=-1, keepdims=True))
if attn_mask is not None:
self._assert_rank_and_type(attn_mask, 3)
mask = mask * attn_mask
return mask
def call(self,
inputs,
mask,
inverse_normalizer,
step=None,
beam_indices=None,
cache=None,
attn_mask=None):
self._assert_rank_and_type(inputs, 3)
self._assert_rank_and_type(mask, 3)
assert inputs.get_shape().as_list()[-1] == self.model_dimension
layer_out = self.dense_layers(inputs, mask, inverse_normalizer)
# TFLite mode is handled with a custom op.
if self.parameters.mode == base_layers.TFLITE:
assert beam_indices is not None
assert step is not None
layer_out = tf_custom_ops_py.uniform_causal_attn(
layer_out, step, beam_indices, self.model_dimension, self.beam_size)
else:
# Cache is used for TF Predict and Eval modes.
if cache is None:
attention_matrix = self.get_uniform_attention(attn_mask)
layer_out = tf.matmul(attention_matrix, layer_out)
else:
assert self.parameters.mode in [base_layers.PREDICT, base_layers.EVAL]
assert step is not None
cache['uniform_avg'] = layer_out + cache['uniform_avg']
layer_out = cache['uniform_avg'] / tf.cast(step, dtype=tf.float32)
return self.qoutput(layer_out)
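# --- Illustrative sketch, not part of the original change: the uniform causal
# attention above is a row-normalized lower-triangular matrix, so position t
# attends to positions 0..t with equal weight 1/(t+1), i.e. a running average
# over the decoded prefix.
import tensorflow as tf

max_time_step = 4
causal = tf.linalg.band_part(tf.ones([max_time_step, max_time_step]), -1, 0)
uniform = tf.math.divide_no_nan(causal, tf.reduce_sum(causal, axis=-1, keepdims=True))
# uniform[2] -> [1/3, 1/3, 1/3, 0]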
@@ -54,3 +54,52 @@ py_library(
"//tf_ops:tf_custom_ops_py",
],
)
py_library(
name = "charformer",
srcs = ["charformer.py"],
srcs_version = "PY3",
deps = [
":transformer_encoder",
# package tensorflow
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:normalization_layers",
"//layers:quantization_layers",
# "//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
],
)
py_library(
name = "transformer_encoder",
srcs = ["transformer_encoder.py"],
srcs_version = "PY3",
deps = [
# package absl/logging
# package tensorflow
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:transformer_layers",
# "//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
],
)
py_library(
name = "transformer_uniform_attn_decoder",
srcs = ["transformer_uniform_attn_decoder.py"],
srcs_version = "PY3",
deps = [
# package absl/logging
# package tensorflow
# tensor2tensor/utils:beam_search",
"//layers:base_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:transformer_layers",
"//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
],
)
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Charformer based model for in-training tokenization."""
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import misc_layers
from layers import normalization_layers
from layers import quantization_layers
from models import transformer_encoder
class Encoder(tf.keras.layers.Layer):
"""Encoder with GBST and Transformer layers."""
def __init__(self, config, mode, **kwargs):
super(Encoder, self).__init__(**kwargs)
def _get_params(varname, default_value=None):
value = config[varname] if varname in config else default_value
default = "" if varname in config else " (default)"
logging.info("%s = %s%s", varname, value, default)
setattr(self, varname, value)
_get_params("labels", [])
_get_params("regularizer_scale")
_get_params("quantize")
_get_params("feature_size")
_get_params("bottleneck_size")
self.max_seq_len = config.get("max_seq_len", 128)
self.gbst_max_token_len = config.get("gbst_max_token_len", 128)
# Including 3 additional special token ids (0=padding, 1=EOS, 2=UNK).
self.vocabulary_size = config.get("vocabulary_size", 259)
self.parameters = base_layers.Parameters(
mode, quantize=self.quantize, regularizer_scale=self.regularizer_scale)
self.embedding = embedding_layers.EmbeddingLayer(
shape=[self.vocabulary_size, self.feature_size],
parameters=self.parameters)
self.gbst_downsample_rate = config.get("gbst_downsample_rate", 1)
self.positional_embedding = embedding_layers.EmbeddingLayer(
shape=[self.gbst_max_token_len, self.feature_size],
parameters=self.parameters)
self.ln = normalization_layers.LayerNormalization(
parameters=self.parameters)
self.qact = quantization_layers.ActivationQuantization(
parameters=self.parameters)
self.bottleneck_layer = None
gbst_size = self.feature_size
if self.bottleneck_size != self.feature_size:
self.bottleneck_layer = dense_layers.BaseQDenseVarLen(
self.bottleneck_size,
rank=3,
normalize=False,
activation=None,
parameters=self.parameters)
gbst_size = self.bottleneck_size
self.gbst_max_subword_block_width = config.get(
"gbst_max_subword_block_width", 5)
self.gbst_conv_kernel_size = config.get("gbst_conv_kernel_size", 5)
self.gbst_block_mixing_mode = config.get("gbst_block_mixing_mode", None)
self.gbst_layer = misc_layers.GBSTLayerV2(
feature_size=gbst_size,
max_seq_len=self.gbst_max_token_len,
downsample_rate=self.gbst_downsample_rate,
max_subword_block_width=self.gbst_max_subword_block_width,
conv_kernel_size=self.gbst_conv_kernel_size,
block_mixing_mode=self.gbst_block_mixing_mode,
parameters=self.parameters)
self.pool_windows = config.get("pool_windows", None)
if self.pool_windows:
self.transformer_encoder_layer = transformer_encoder.FunnelTransformerModel(
config, mode)
else:
self.transformer_encoder_layer = transformer_encoder.ModelWithEmbeddings(
config, mode)
self.attention_pool = misc_layers.AttentionPooling(
parameters=self.parameters)
self.num_classes = len(self.labels)
if self.num_classes:
self.final_fc = dense_layers.BaseQDense(
units=self.num_classes,
rank=2,
parameters=self.parameters,
activation=None)
def call(self, token_ids, seq_length):
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
mask_rank2 = tf.ones(tf.shape(token_ids), dtype=tf.int32)
seq_length = tf.reduce_sum(mask_rank2, axis=1)
pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
pos_indices = tf.cast(pos_indices, dtype=tf.int32)
pos_indices = tf.reshape(pos_indices, [1, -1])
else:
mask_rank2 = tf.sequence_mask(
seq_length, tf.shape(token_ids)[1], dtype=tf.float32)
pos_indices = tf.cumsum(mask_rank2, axis=1, exclusive=True)
pos_indices = tf.cast(pos_indices, dtype=tf.int32)
input_values = self.embedding(token_ids)
pos_values = self.positional_embedding(pos_indices)
input_embeds = self.qact(self.ln(input_values + pos_values))
if self.bottleneck_layer is not None:
maskr3 = tf.expand_dims(mask_rank2, axis=2)
maskr3 = tf.cast(maskr3, tf.float32)
bottleneck_output = self.bottleneck_layer(input_embeds, maskr3)
else:
bottleneck_output = input_embeds
gbst_output = self.gbst_layer(bottleneck_output, seq_length)
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
mask_rank2 = tf.ones(tf.shape(gbst_output)[:-1], dtype=tf.float32)
seq_length = tf.reduce_sum(mask_rank2, axis=1)
else:
seq_length = seq_length / self.gbst_downsample_rate
if self.pool_windows:
outputs, mask = self.transformer_encoder_layer(gbst_output,
seq_length)
inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
pre_logits = self.attention_pool(outputs, mask, inverse_normalizer)
else:
outputs = self.transformer_encoder_layer(gbst_output, seq_length)
mask = tf.sequence_mask(
seq_length, tf.shape(outputs)[1], dtype=tf.float32)
inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask))
maskr3 = tf.expand_dims(mask, axis=2)
pre_logits = self.attention_pool(outputs, maskr3, inverse_normalizer)
if self.num_classes:
return self.final_fc(pre_logits)
else:
return pre_logits
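# --- Illustrative sketch, not part of the original change: a hypothetical
# config for the Encoder above. The keys mirror the lookups in __init__; all
# values are made up for illustration.
example_config = {
    "labels": ["negative", "positive"],
    "regularizer_scale": 1e-5,
    "quantize": True,
    "feature_size": 128,
    "bottleneck_size": 64,
    "max_seq_len": 128,
    "gbst_max_token_len": 128,
    "vocabulary_size": 259,  # 256 byte values + padding/EOS/UNK
    "gbst_downsample_rate": 2,
    "gbst_max_subword_block_width": 5,
    "gbst_conv_kernel_size": 5,
    "gbst_block_mixing_mode": None,
    "pool_windows": None,
    # Keys consumed by the wrapped transformer encoder model (model_dimension,
    # num_heads, num_layers, intermediate_size, ...) would go here as well.
}
# encoder = Encoder(example_config, base_layers.TRAIN)  # assumed usage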
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of pQRNN model."""
# pylint: disable=arguments-renamed
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import transformer_layers
class Model(tf.keras.layers.Layer):
"""Quantized transformer encoder."""
def __init__(self, config, mode):
def _get_params(varname, default_value=None):
value = config[varname] if varname in config else default_value
default = "" if varname in config else " (default)"
logging.info("%s = %s%s", varname, value, default)
setattr(self, varname, value)
_get_params("intermediate_size")
_get_params("max_time_step")
_get_params("embedding_size")
_get_params("vocabulary_size")
_get_params("num_layers")
_get_params("labels")
_get_params("regularizer_scale")
_get_params("num_heads")
_get_params("model_dimension")
_get_params("quantize")
_get_params("activation_dropout_rate", 0.0)
_get_params("attention_dropout_rate", 0.0)
self.parameters = base_layers.Parameters(mode, self.quantize,
self.regularizer_scale)
super(Model, self).__init__()
def build(self, input_shape):
self.transformer = transformer_layers.TransformerEncoderStack(
parameters=self.parameters,
num_layers=self.num_layers,
intermediate_size=self.intermediate_size,
embedding_size=self.embedding_size,
max_time_step=self.max_time_step,
num_heads=self.num_heads,
model_dimension=self.model_dimension,
vocabulary_size=self.vocabulary_size,
activation_dropout_rate=self.activation_dropout_rate,
attention_dropout_rate=self.attention_dropout_rate)
def call(self, indices, sequence_length):
return self.transformer(indices, sequence_length)
class ModelWithEmbeddings(Model):
"""Quantized transformer encoder which takes embeddings instead of indices."""
def build(self, input_shape):
self.transformer_with_input_embedding = transformer_layers.TransformerEncoderStackWithInputEmbedding(
parameters=self.parameters,
num_layers=self.num_layers,
intermediate_size=self.intermediate_size,
embedding_size=self.embedding_size,
max_time_step=self.max_time_step,
num_heads=self.num_heads,
model_dimension=self.model_dimension,
vocabulary_size=self.vocabulary_size,
activation_dropout_rate=self.activation_dropout_rate,
attention_dropout_rate=self.attention_dropout_rate)
def call(self, embeddings, sequence_length):
return self.transformer_with_input_embedding(embeddings, sequence_length)
class FunnelTransformerModel(Model):
"""Quantized transformer encoder which takes embeddings instead of indices."""
def __init__(self, config, mode):
self.pool_windows = config.get("pool_windows", None)
super(FunnelTransformerModel, self).__init__(config, mode)
def build(self, input_shape):
self.funnel_transformer = transformer_layers.FunnelTransformerEncoderStack(
parameters=self.parameters,
num_layers=self.num_layers,
intermediate_size=self.intermediate_size,
embedding_size=self.embedding_size,
max_time_step=self.max_time_step,
num_heads=self.num_heads,
model_dimension=self.model_dimension,
vocabulary_size=self.vocabulary_size,
activation_dropout_rate=self.activation_dropout_rate,
attention_dropout_rate=self.attention_dropout_rate,
pool_windows=self.pool_windows)
def call(self, embeddings, sequence_length):
return self.funnel_transformer(embeddings, sequence_length)
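# --- Illustrative sketch, not part of the original change: a hypothetical
# config for the encoder Model classes above. The keys mirror _get_params; all
# values are made up for illustration.
example_config = {
    "intermediate_size": 256,
    "max_time_step": 128,
    "embedding_size": 64,
    "vocabulary_size": 259,
    "num_layers": 2,
    "labels": [],
    "regularizer_scale": 1e-5,
    "num_heads": 4,
    "model_dimension": 64,
    "quantize": True,
    "activation_dropout_rate": 0.1,
    "attention_dropout_rate": 0.1,
}
# model = Model(example_config, base_layers.TRAIN)  # assumed usage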
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Transformer decoder model."""
import math
from absl import logging
from tensor2tensor.utils import beam_search
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import normalization_layers
from layers import quantization_layers
from layers import transformer_layers
class TransformerUniformAttnDecoder(base_layers.BaseLayer):
"""Transformer Uniform Attention Decoder."""
def __init__(self,
model_dimension,
max_time_step,
num_heads,
intermediate_size,
activation_dropout_rate=0.0,
attention_dropout_rate=0.0,
beam_size=1,
cached_kv=False,
**kwargs):
self.model_dimension = model_dimension
self.decoder_uniform_attn = transformer_layers.DecoderUniformAttention(
model_dimension,
max_time_step,
attention_dropout_rate=attention_dropout_rate,
beam_size=beam_size,
**kwargs)
self.multihead_cross_attn = transformer_layers.DecoderMultiheadAttention(
model_dimension,
num_heads,
cached_kv=cached_kv,
attention_dropout_rate=attention_dropout_rate,
**kwargs)
self.prx = dense_layers.BaseQDense(
model_dimension, activation=None, normalize=False, bias=False, **kwargs)
self.upprx = dense_layers.BaseQDense(
intermediate_size, normalize=False, **kwargs)
self.downprx = dense_layers.BaseQDense(
model_dimension, activation=None, normalize=False, **kwargs)
self.activation_dropout_rate = activation_dropout_rate
self.ln1 = normalization_layers.LayerNormalization(**kwargs)
self.ln2 = normalization_layers.LayerNormalization(**kwargs)
self.q0 = quantization_layers.ActivationQuantization(**kwargs)
self.q1 = quantization_layers.ActivationQuantization(**kwargs)
self.q2 = quantization_layers.ActivationQuantization(**kwargs)
super(TransformerUniformAttnDecoder, self).__init__(**kwargs)
def call(self,
dec_inputs,
dec_mask,
dec_inverse_normalizer,
enc_output,
enc_mask,
enc_inverse_normalizer,
cross_attn_mask=None,
step=None,
selected_beams=None,
cache=None):
batch_size = self.get_batch_dimension(dec_inputs)
self._assert_rank_and_type(dec_inputs, 3)
self._assert_rank_and_type(dec_mask, 3)
assert dec_inputs.get_shape().as_list()[-1] == self.model_dimension
self_attn_output = self.decoder_uniform_attn(
dec_inputs,
dec_mask,
dec_inverse_normalizer,
step=step,
beam_indices=selected_beams,
cache=cache)
cross_attn_output = self.multihead_cross_attn(dec_inputs, dec_mask,
dec_inverse_normalizer,
enc_output, enc_mask,
enc_inverse_normalizer,
cross_attn_mask)
layer_out = self.q0(cross_attn_output + self_attn_output)
layer_out = tf.reshape(layer_out, [-1, self.model_dimension])
layer_out = self.prx(layer_out)
if self.parameters.mode == base_layers.TRAIN:
layer_out = tf.nn.dropout(layer_out, rate=self.activation_dropout_rate)
dec_inputs = tf.reshape(dec_inputs, [-1, self.model_dimension])
dec_inputs_updated = self.q1(self.ln1(dec_inputs + layer_out))
# Feed forward network.
layer_out = self.upprx(dec_inputs_updated)
layer_out = self.downprx(layer_out)
if self.parameters.mode == base_layers.TRAIN:
layer_out = tf.nn.dropout(layer_out, rate=self.activation_dropout_rate)
outputs = self.q2(self.ln2(dec_inputs_updated + layer_out))
return tf.reshape(outputs, [batch_size, -1, self.model_dimension])
class TransformerUniformAttnDecoderStack(base_layers.BaseLayer):
"""TransformerUniformAttnDecoderStack Decoder."""
def __init__(self,
num_layers,
max_time_step,
vocabulary_size,
embedding_size,
model_dimension,
num_heads,
intermediate_size,
beam_size=1,
activation_dropout_rate=0.1,
attention_dropout_rate=0.0,
cached_kv=False,
**kwargs):
super(TransformerUniformAttnDecoderStack, self).__init__(**kwargs)
self.max_time_step = max_time_step
self.vocabulary_size = vocabulary_size
self.embedding_size = embedding_size
self.activation_dropout_rate = activation_dropout_rate
self.layers = []
for _ in range(num_layers):
self.layers.append(
TransformerUniformAttnDecoder(
model_dimension=model_dimension,
max_time_step=max_time_step,
num_heads=num_heads,
intermediate_size=intermediate_size,
beam_size=beam_size,
cached_kv=cached_kv,
activation_dropout_rate=activation_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
**kwargs))
def call(self,
dec_inputs,
dec_mask,
enc_output,
enc_mask,
step=None,
selected_beams=None,
cache=None):
self._assert_rank_and_type(dec_mask, 2)
self._assert_rank_and_type(enc_mask, 2)
dec_mask_rank3 = tf.expand_dims(dec_mask, axis=2)
dec_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(dec_mask_rank3))
enc_mask_rank3 = tf.expand_dims(enc_mask, 1)
enc_inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(enc_mask_rank3))
cross_attn_mask = enc_mask_rank3
layer_in = dec_inputs
if self.parameters.mode == base_layers.TRAIN:
layer_in = tf.nn.dropout(layer_in, rate=self.activation_dropout_rate)
enc_output_feature_dim = enc_output.get_shape().as_list()[2]
enc_output = tf.reshape(enc_output, [-1, enc_output_feature_dim])
for i, layer in enumerate(self.layers):
layer_cache = cache["layer_%d" % i] if cache is not None else None
layer_in = layer(
layer_in,
dec_mask_rank3,
dec_inverse_normalizer,
enc_output,
enc_mask,
enc_inverse_normalizer,
cross_attn_mask,
step=step,
selected_beams=selected_beams,
cache=layer_cache)
return layer_in
class Model(tf.keras.layers.Layer):
"""Quantized transformer decoder."""
def __init__(self, config, mode):
super(Model, self).__init__()
def _get_params(varname, default_value=None):
value = config[varname] if varname in config else default_value
default = "" if varname in config else " (default)"
logging.info("%s = %s%s", varname, value, default)
setattr(self, varname, value)
_get_params("intermediate_size")
_get_params("max_dec_time_step")
_get_params("max_enc_time_step")
_get_params("embedding_size")
_get_params("vocabulary_size")
_get_params("num_layers")
_get_params("labels")
_get_params("regularizer_scale")
_get_params("num_heads")
_get_params("model_dimension")
_get_params("beam_size", 1)
_get_params("quantize", True)
_get_params("cached_kv", False)
_get_params("attention_dropout_rate", 0.0)
_get_params("activation_dropout_rate", 0.0)
# If set, a separate dense layer is used to generate the logits instead of
# re-using the input embedding table.
_get_params("use_output_layer", False)
self.parameters = base_layers.Parameters(mode, self.quantize,
self.regularizer_scale)
# Activation/Normalization enabled on input bottleneck as there is no
# temporal information.
self.input_bottleneck = dense_layers.BaseQDenseVarLen(
self.model_dimension, rank=3, parameters=self.parameters)
self.output_bottleneck = dense_layers.BaseQDense(
self.embedding_size,
normalize=False,
activation=None,
bias=False,
parameters=self.parameters)
self.embedding = embedding_layers.EmbeddingFullyConnected(
shape=[self.vocabulary_size, self.embedding_size],
initializer=tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3)),
parameters=self.parameters)
if self.use_output_layer:
self.output_layer = dense_layers.BaseQDense(
self.vocabulary_size,
activation=None,
normalize=False,
bias=False,
parameters=self.parameters)
self.positional_embedding = embedding_layers.EmbeddingLayer(
shape=[self.max_dec_time_step, self.model_dimension],
initializer=tf.random_uniform_initializer(-math.sqrt(3), math.sqrt(3)),
parameters=self.parameters)
self.ln = normalization_layers.LayerNormalization(
parameters=self.parameters)
self.qact = quantization_layers.ActivationQuantization(
parameters=self.parameters)
# Scales the weights for computing logits.
self.logits_fc_weights_scale_factor = None
self.logits_fc_bias = self.add_weight(
"logits_fc_bias",
shape=[self.vocabulary_size],
initializer=tf.constant_initializer(0),
dtype="float32")
# Optional bias which can be used to mask logits output.
self.output_bias = None
self.transformer_uniform_attn_decoder = TransformerUniformAttnDecoderStack(
parameters=self.parameters,
num_layers=self.num_layers,
intermediate_size=self.intermediate_size,
embedding_size=self.embedding_size,
max_time_step=self.max_dec_time_step,
num_heads=self.num_heads,
model_dimension=self.model_dimension,
vocabulary_size=self.vocabulary_size,
beam_size=self.beam_size,
cached_kv=self.cached_kv,
attention_dropout_rate=self.attention_dropout_rate,
activation_dropout_rate=self.activation_dropout_rate)
# Beam search output.
self.finished_seq = None
self.finished_scores = None
def call(self,
decode_ids,
decode_ids_mask,
enc_output,
enc_mask,
start_ids=None,
eos_id=None,
pad_id=None,
input_id=None,
time_step=None,
selected_beams=None):
if self.parameters.mode == base_layers.TRAIN:
inputs = self.training_inputs(decode_ids, decode_ids_mask)
layer_out = self.transformer_uniform_attn_decoder(inputs, decode_ids_mask,
enc_output, enc_mask)
logits, predicted_ids = self.model_outputs(layer_out)
elif self.parameters.mode in [base_layers.EVAL, base_layers.PREDICT]:
logits, predicted_ids = self.decode_beam_search(start_ids, eos_id, pad_id,
enc_output, enc_mask)
elif self.parameters.mode == base_layers.TFLITE:
input_values = self.embedding(input_id)
# time_step starts from 1.
pos_values = self.positional_embedding(time_step - 1)
pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
inputs = self.qact(self.ln(input_values + pos_values))
layer_out = self.transformer_uniform_attn_decoder(
inputs,
input_mask,
enc_output,
enc_mask,
step=time_step,
selected_beams=selected_beams)
logits, predicted_ids = self.model_outputs(layer_out)
else:
assert "Invalid mode."
return logits, predicted_ids
def training_inputs(self, input_ids, input_mask):
input_values = self.embedding(input_ids)
if self.embedding_size != self.model_dimension:
input_values = self.input_bottleneck(input_values, input_mask)
pos_indices = tf.cumsum(input_mask, axis=1, exclusive=True)
pos_indices = tf.cast(pos_indices, dtype=tf.int32)
pos_values = self.positional_embedding(pos_indices)
inputs = self.qact(self.ln(input_values + pos_values))
return inputs
def model_outputs(self, layer_in):
bsz = layer_in.get_shape().as_list()[0] or tf.shape(layer_in)[0]
layer_out = tf.reshape(layer_in, [-1, self.model_dimension])
if self.use_output_layer:
logits = self.output_layer(layer_out)
else:
if self.model_dimension != self.embedding_size:
layer_out = self.output_bottleneck(layer_out)
logits = self.embedding.fully_connected(
layer_out,
bias=self.logits_fc_bias,
weights_scale_factor=self.logits_fc_weights_scale_factor)
logits = tf.reshape(logits, [bsz, -1, self.vocabulary_size])
# Optional bias to mask out logits before applying argmax.
if self.output_bias is not None:
logits += self.output_bias
predicted_ids = tf.argmax(logits, axis=2, output_type=tf.int64)
return logits, predicted_ids
def decode_beam_search(self,
start_ids,
eos_id,
pad_id,
enc_output,
enc_mask,
scope="model"):
batch_size = tf.shape(start_ids)[0]
cache = { # pylint: disable=g-complex-comprehension
"layer_%d" % layer: {
"uniform_avg": tf.zeros([batch_size, 1, self.model_dimension]),
} for layer in range(self.num_layers)
}
cache["logits"] = tf.zeros([batch_size, 0, self.vocabulary_size])
pos_indices = tf.range(self.max_dec_time_step, dtype=tf.int32)
pos_indices = tf.reshape(pos_indices, [1, -1])
pos_values = self.positional_embedding(pos_indices)
def beam_search_tile(output, tile_pattern, final_shape):
x = tf.tile(output, tile_pattern)
x = tf.reshape(x, final_shape)
return x
enc_output_feature_dim = enc_output.get_shape().as_list()[2]
enc_output = beam_search_tile(
enc_output, [1, self.beam_size, 1],
[batch_size * self.beam_size, -1, enc_output_feature_dim])
enc_mask = beam_search_tile(enc_mask, [1, self.beam_size],
[batch_size * self.beam_size, -1])
def symbols_to_logits_fn(ids, step, cache):
"""Looks up ids to logits."""
logging.info("Running symbols to logits. ids=%s, step=%s, cache=%s", ids,
step, cache)
curr_id = ids[:, -1:]
with tf.name_scope(scope):
curr_embed = self.embedding(curr_id)
input_mask = tf.ones(tf.shape(curr_embed)[:-1], dtype=tf.float32)
if self.embedding_size != self.model_dimension:
curr_embed = self.input_bottleneck(curr_embed, input_mask)
inputs = self.qact(
self.ln(curr_embed + pos_values[:, step:step + 1, :]))
layer_out = self.transformer_uniform_attn_decoder(
inputs,
input_mask,
enc_output,
enc_mask,
step=step + 1,
cache=cache)
next_logits, _ = self.model_outputs(layer_out)
cache["logits"] = tf.concat([cache["logits"], next_logits], axis=1)
return next_logits, cache
self.finished_seq, self.finished_scores, states = beam_search.beam_search(
symbols_to_logits_fn,
initial_ids=start_ids,
beam_size=self.beam_size,
decode_length=self.max_dec_time_step,
vocab_size=self.vocabulary_size,
alpha=0.6,
eos_id=eos_id,
states=cache)
beam_ids = self.finished_seq[:, 0, 1:]
beam_ids = tf.pad(
beam_ids, [[0, 0], [0, self.max_dec_time_step - tf.shape(beam_ids)[1]]],
constant_values=pad_id)
logits = states["logits"][:, 0, :, :]
logits = tf.pad(
logits,
[[0, 0], [0, self.max_dec_time_step - tf.shape(logits)[1]], [0, 0]],
constant_values=self.parameters.invalid_logit)
return logits, beam_ids
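# --- Illustrative sketch, not part of the original change: beam_search_tile
# above repeats the encoder output once per beam and folds the beam dimension
# into the batch dimension, which is the layout the tensor2tensor beam search
# expects for its symbols_to_logits_fn.
import tensorflow as tf

batch_size, beam_size, enc_time, enc_dim = 2, 3, 4, 8
enc_output = tf.random.normal([batch_size, enc_time, enc_dim])
tiled = tf.tile(enc_output, [1, beam_size, 1])
tiled = tf.reshape(tiled, [batch_size * beam_size, enc_time, enc_dim])
# tiled[0] == tiled[1] == tiled[2]: the three beams of batch element 0 share its encoding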
class ModelEvalWithGTLogitsAndPredictions(Model):
"""Model with EVAL mode logits and predictions based on ground truth inputs at each step."""
def call(self,
decode_ids,
decode_ids_mask,
enc_output,
enc_mask,
start_ids=None,
eos_id=None,
pad_id=None,
input_id=None,
time_step=None,
selected_beams=None):
if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
inputs = self.training_inputs(decode_ids, decode_ids_mask)
layer_out = self.transformer_uniform_attn_decoder(inputs, decode_ids_mask,
enc_output, enc_mask)
logits, predicted_ids = self.model_outputs(layer_out)
elif self.parameters.mode == base_layers.PREDICT:
logits, predicted_ids = self.decode_beam_search(
start_ids,
eos_id,
pad_id,
enc_output,
enc_mask,
scope="model_eval_with_gt_logits_and_predictions")
elif self.parameters.mode == base_layers.TFLITE:
input_values = self.embedding(input_id)
# time_step starts from 1.
pos_values = self.positional_embedding(time_step - 1)
pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
inputs = self.qact(self.ln(input_values + pos_values))
layer_out = self.transformer_uniform_attn_decoder(
inputs,
input_mask,
enc_output,
enc_mask,
step=time_step,
selected_beams=selected_beams)
logits, predicted_ids = self.model_outputs(layer_out)
else:
assert "Invalid mode."
return logits, predicted_ids
class ModelEvalWithGTLogits(Model):
"""Model with EVAL mode logits computed based on ground truth input at each step."""
def call(self,
decode_ids,
decode_ids_mask,
enc_output,
enc_mask,
start_ids=None,
eos_id=None,
pad_id=None,
input_id=None,
time_step=None,
selected_beams=None):
logits = None
if self.parameters.mode in [base_layers.TRAIN, base_layers.EVAL]:
inputs = self.training_inputs(decode_ids, decode_ids_mask)
layer_out = self.transformer_uniform_attn_decoder(inputs, decode_ids_mask,
enc_output, enc_mask)
logits, predicted_ids = self.model_outputs(layer_out)
if self.parameters.mode in [base_layers.EVAL, base_layers.PREDICT]:
# EVAL mode predictions are based on beam search path.
_, predicted_ids = self.decode_beam_search(
start_ids,
eos_id,
pad_id,
enc_output,
enc_mask,
scope="model_eval_with_gt_logits")
if self.parameters.mode == base_layers.TFLITE:
input_values = self.embedding(input_id)
# time_step starts from 1.
pos_values = self.positional_embedding(time_step - 1)
pos_values = tf.reshape(pos_values, [-1, 1, self.embedding_size])
input_mask = tf.ones(tf.shape(input_values)[:-1], dtype=tf.float32)
inputs = self.qact(self.ln(input_values + pos_values))
layer_out = self.transformer_uniform_attn_decoder(
inputs,
input_mask,
enc_output,
enc_mask,
step=time_step,
selected_beams=selected_beams)
logits, predicted_ids = self.model_outputs(layer_out)
return logits, predicted_ids
@@ -93,3 +93,33 @@ REGISTER_OP("PoolingOp")
.Doc(R"doc(
Dummy pooling op.
)doc");
class UniformCausalAttnOp : public tensorflow::OpKernel {
public:
explicit UniformCausalAttnOp(tensorflow::OpKernelConstruction* context)
: tensorflow::OpKernel(context) {}
void Compute(tensorflow::OpKernelContext* ctx) override {}
};
REGISTER_KERNEL_BUILDER(
Name("UniformCausalAttn").Device(::tensorflow::DEVICE_CPU),
UniformCausalAttnOp);
REGISTER_OP("UniformCausalAttn")
.Input("input: float32")
.Input("time_step: int32")
.Input("selected_beams: int32")
.Attr("feature_size: int")
.Attr("beam_size: int")
.Output("output: float32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
auto batch_size = c->Dim(c->input(0), 0);
int32 feature_size;
TF_RETURN_IF_ERROR(c->GetAttr("feature_size", &feature_size));
c->set_output(0, c->MakeShape({batch_size, 1, feature_size}));
return tensorflow::Status::OK();
})
.Doc(R"doc(
Dummy uniform causal attn op.
)doc";