Unverified commit 50e86708 authored by karun, committed by GitHub

Adding Bytestream model (#10731)


Co-authored-by: Arun Kandoor <akandoor@google.com>
parent 2659c4e9
@@ -22,6 +22,7 @@ from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tftext
from layers import projection_layers # import seq_flow_lite module
from utils import misc_utils # import seq_flow_lite module
@@ -61,7 +62,24 @@ def create_input_fn(runner_config, mode, drop_remainder):
label = tf.reshape(label, [batch_size, num_classes])
prxlayer = projection_layers.ProjectionLayer(model_config, mode)
projection, seq_length = prxlayer(text)
return {"projection": projection, "seq_length": seq_length, "label": label}
gbst_max_token_len = max_seq_len
if "gbst_max_token_len" in model_config:
gbst_max_token_len = model_config["gbst_max_token_len"]
byte_int = tftext.ByteSplitter().split(text).to_tensor(
default_value=0, shape=[batch_size, gbst_max_token_len])
token_ids = tf.cast(byte_int, tf.int32)
token_len = tf.strings.length(text)
mask = tf.cast(
tf.sequence_mask(token_len, maxlen=gbst_max_token_len), tf.int32)
mask *= 3
token_ids += mask
return {
"projection": projection,
"seq_length": seq_length,
"token_ids": token_ids,
"token_len": token_len,
"label": label
}
def _input_fn(params):
"""Method to be used for reading the data."""
......
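For reference, the token ids built in the hunk above are just the raw UTF-8 byte values of the input text shifted up by 3, so that ids 0, 1 and 2 stay reserved for padding, EOS and UNK; this is why the ByteQRNN encoder later in this change uses a vocabulary of 259 (256 + 3). A minimal plain-Python sketch of the same mapping (the helper name and the example string are illustrative, not part of the change):

def bytes_to_token_ids(text, max_token_len=16):
  raw = list(text.encode("utf-8"))[:max_token_len]  # byte values in [0, 255]
  ids = [b + 3 for b in raw]                        # shift past ids 0/1/2 (pad/EOS/UNK)
  ids += [0] * (max_token_len - len(ids))           # pad with 0, like default_value=0 above
  return ids, len(raw)

token_ids, token_len = bytes_to_token_ids("hi")     # -> [107, 108, 0, ...], 2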
@@ -83,8 +83,11 @@ py_strict_library(
srcs_version = "PY3",
deps = [
# package tensorflow
":embedding_layers",
"//layers:base_layers", # sequence projection
"//layers:conv_layers",
"//layers:dense_layers", # sequence projection
"//layers:normalization_layers",
"//layers:quantization_layers", # sequence projection
],
)
@@ -102,3 +105,14 @@ py_strict_library(
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
py_strict_library(
name = "embedding_layers",
srcs = ["embedding_layers.py"],
srcs_version = "PY3",
deps = [
# package tensorflow
"//layers:base_layers",
"//layers:quantization_layers",
],
)
# Copyright 2020 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers for embedding."""
import tensorflow as tf
from layers import base_layers
from layers import quantization_layers
class EmbeddingLayer(base_layers.BaseLayer):
"""Embedding layer."""
def __init__(self,
shape,
num_bits=8,
initializer=None,
trainable=True,
**kwargs):
self.shape = shape
self.quantizer = quantization_layers.ActivationQuantization(
num_bits=num_bits, **kwargs)
super(EmbeddingLayer, self).__init__(**kwargs)
if initializer is None:
initializer = tf.keras.initializers.GlorotUniform()
self.initializer = initializer
self.trainable = trainable
def build(self, input_shapes):
self.embedding_table = self.add_weight(
name="embedding_table",
shape=self.shape,
initializer=self.initializer,
trainable=self.trainable,
dtype=tf.float32)
if self.trainable:
self.add_reg_loss(self.embedding_table)
def call(self, indices):
assert indices.dtype in [tf.int64, tf.int32]
outputs = tf.nn.embedding_lookup(self.embedding_table, indices)
return self.quantizer(outputs)
class EmbeddingFullyConnected(EmbeddingLayer):
"""Uses embedding table as weights in a fully connected op."""
def __init__(self, **kwargs):
shape = kwargs.pop("shape", None)
initializer = kwargs.pop("initializer", None)
self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
super(EmbeddingFullyConnected, self).__init__(
shape=shape, initializer=initializer, **kwargs)
def fully_connected(self, inputs, bias=None, weights_scale_factor=None):
# Requires self.embedding_table, so this can only be called after the layer has
# been built (i.e. after its call() method has run at least once).
self._assert_rank_and_type(inputs, 2)
weights = self.embedding_table
if weights_scale_factor is not None:
weights = weights * weights_scale_factor
outputs = tf.matmul(inputs, weights, transpose_b=True)
if bias is not None:
outputs = tf.nn.bias_add(outputs, bias)
return self.qoutput(outputs)
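A small usage sketch of the new embedding layer, mirroring how byteqrnn.py below constructs its byte embedding. It assumes the seq_flow_lite layers package is on the Python path and that base_layers exposes a TRAIN mode constant alongside the PREDICT/TFLITE ones referenced elsewhere in this change; the sizes are illustrative:

import tensorflow as tf
from layers import base_layers
from layers import embedding_layers

# base_layers.TRAIN is assumed to be the training-mode constant.
params = base_layers.Parameters(
    base_layers.TRAIN, quantize=True, regularizer_scale=1e-5)
# 256 byte values + 3 special ids (padding/EOS/UNK) -> vocabulary size 259.
embedding = embedding_layers.EmbeddingLayer(shape=[259, 128], parameters=params)
token_ids = tf.constant([[107, 108, 0, 0]], dtype=tf.int32)  # "hi", shifted and padded
embeds = embedding(token_ids)  # float32 tensor of shape [1, 4, 128]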
@@ -14,10 +14,13 @@
# ==============================================================================
# Lint as: python3
"""Layers for embedding."""
import math
import tensorflow as tf
from layers import base_layers # import seq_flow_lite module
from layers import conv_layers
from layers import dense_layers # import seq_flow_lite module
from layers import embedding_layers
from layers import quantization_layers # import seq_flow_lite module
@@ -92,3 +95,147 @@ class TreeInductionLayer(base_layers.BaseLayer):
# seq_dim = tf.shape(result)[1]
# result = tf.reshape(result, [1, seq_dim, seq_dim])
return result
class GBSTLayerV2(base_layers.BaseLayer):
"""Tokenization layer."""
def __init__(self,
feature_size,
max_seq_len,
downsample_rate=2,
max_subword_block_width=4,
conv_kernel_size=5,
block_mixing_mode=None,
add_block_pos_embed=False,
**kwargs):
super(GBSTLayerV2, self).__init__(**kwargs)
self.feature_size = feature_size
self.max_seq_len = max_seq_len
self.downsample_rate = downsample_rate
self.subword_blocks_width = [1, 2, 3, 4]
self.max_subword_block_width = len(self.subword_blocks_width)
self.block_mixing_mode = block_mixing_mode
self.add_block_pos_embed = add_block_pos_embed
if self.add_block_pos_embed:
self.block_pos_embedding = embedding_layers.EmbeddingLayer(
shape=[self.max_subword_block_width, self.feature_size], **kwargs)
self.conv_kernel_size = conv_kernel_size
self.conv_layer = conv_layers.EncoderQConvolution(
filters=feature_size,
ksize=conv_kernel_size,
rank=3,
padding="VALID",
activation=None,
**kwargs)
padding = [conv_kernel_size - 1, 0]
self.zero_pad = tf.keras.layers.ZeroPadding1D(padding=padding)
self.block_attn = dense_layers.BaseQDense(
units=1,
rank=3,
activation=None,
normalize=False,
quantize_output=False,
**kwargs)
self.scores_concat = quantization_layers.ConcatQuantization(
axis=3, **kwargs)
self.attn_concat = quantization_layers.ConcatQuantization(axis=0, **kwargs)
self.qact = quantization_layers.ActivationQuantization(**kwargs)
self.qact_dot = quantization_layers.ActivationQuantization(**kwargs)
self.qoutput = quantization_layers.ActivationQuantization(**kwargs)
def call(self, inputs, seq_length):
"""Performs downsampling on the character-scale input representation.
Based in principle on the GBST tokenization from Charformer
(https://arxiv.org/pdf/2106.12672.pdf).
Args:
inputs: float Tensor of shape [batch_size, seq_length, embedding_size].
seq_length: sequence length of shape [batch_size].
Returns:
<float>[batch_size, seq_length / downsample_rate, embedding_size].
Downsampled sequences.
"""
self._assert_rank_and_type(inputs, 3)
bsz = self.get_batch_dimension(inputs)
max_seq_len = self.max_seq_len
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
num_steps = tf.shape(inputs)[1]
inputs = self.zero_pad(inputs)
inputs = self.conv_layer(inputs)
all_block_scores = []
all_sequences = []
for subword_len in self.subword_blocks_width:
if self.add_block_pos_embed:
block_pos_indices = tf.range(subword_len, dtype=tf.int32)
block_pos_indices = tf.reshape(block_pos_indices, [1, -1])
block_pos_embeds = self.block_pos_embedding(block_pos_indices)
tile_len = math.ceil(max_seq_len / float(subword_len))
retiled_block_pos_embeds = tf.repeat(block_pos_embeds, tile_len, axis=1)
inputs += retiled_block_pos_embeds
# For this block size, form candidate block embeddings and scores.
# candidates shape: [batch, seq_len/subword_len, dim]
# block_scores shape: [batch, seq_len/subword_len, 1]
candidates = tf.nn.avg_pool(
inputs, [subword_len], strides=[subword_len], padding="SAME")
candidates = self.conv_layer.quantize_using_output_range(candidates)
block_scores = self.block_attn(candidates)
# Upsample it back to the original sequence length.
retiled_seq = tf.repeat(candidates, subword_len, axis=1)
retiled_block_scores = tf.repeat(block_scores, subword_len, axis=1)
# Make sure everything is the right length and add new dimension to concat
# candidate blocks on.
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
retiled_block_scores = retiled_block_scores[:, :num_steps, :]
retiled_seq = retiled_seq[:, :num_steps, :]
else:
retiled_block_scores = retiled_block_scores[:, :max_seq_len, :]
retiled_seq = retiled_seq[:, :max_seq_len, :]
retiled_seq = tf.expand_dims(retiled_seq, axis=-1)
retiled_block_scores = tf.expand_dims(retiled_block_scores, axis=-1)
all_sequences.append(retiled_seq)
all_block_scores.append(retiled_block_scores)
block_net = self.scores_concat(all_block_scores)
if self.block_mixing_mode == "score_attention":
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
block_attn_steps = []
self.attn_concat(None)
for i in range(num_steps):
block_i = tf.reshape(block_net[:, i:i + 1, :, :], [1, -1])
block_attn_steps.append(tf.matmul(block_i, block_i, transpose_b=True))
block_attn = self.attn_concat(block_attn_steps)
block_attn = tf.reshape(block_attn, [bsz, -1, 1, 1])
else:
block_attn = self.attn_concat(
[tf.matmul(block_net, block_net, transpose_b=True)])
block_attn = tf.nn.softmax(block_attn, axis=1)
block_attn = self.qrange_sigmoid(block_attn, tf_only=True)
block_net_scaled = self.qact(block_attn * block_net)
else:
block_net_scaled = block_net
candidate_embeds = self.conv_layer.quantize_using_output_range(
tf.concat(all_sequences, axis=3))
dot_product = self.qact_dot(block_net_scaled * candidate_embeds)
output = self.qoutput(tf.reduce_mean(dot_product, axis=-1, keepdims=True))
output = tf.reshape(output, [bsz, -1, self.feature_size])
# Removing pad entries for inference mode.
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
output = output[:, :num_steps, :]
# Downsample by mean pooling.
if self.downsample_rate > 1:
output = tf.nn.avg_pool(
output, (self.downsample_rate,),
strides=(self.downsample_rate,),
padding="VALID")
return output
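A shape-level sketch of GBSTLayerV2 under the same assumptions as the embedding example above (layers package importable, TRAIN mode constant assumed); batch size, lengths and feature size are illustrative:

import tensorflow as tf
from layers import base_layers
from layers import misc_layers

# base_layers.TRAIN is assumed to be the training-mode constant.
params = base_layers.Parameters(
    base_layers.TRAIN, quantize=True, regularizer_scale=1e-5)
gbst = misc_layers.GBSTLayerV2(
    feature_size=128, max_seq_len=1024, downsample_rate=2, parameters=params)
char_embeds = tf.zeros([2, 1024, 128])  # [batch, byte positions, feature_size]
seq_length = tf.constant([700, 1024])   # unpadded lengths per example
output = gbst(char_embeds, seq_length)  # expected shape: [2, 512, 128] after 2x pooling

Each block width (1 to 4) produces per-position candidate embeddings and scores, the scores are optionally mixed across widths, and the blended result is mean-pooled by downsample_rate to give the shorter output sequence.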
@@ -38,3 +38,19 @@ py_library(
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
py_library(
name = "byteqrnn",
srcs = ["byteqrnn.py"],
srcs_version = "PY3",
deps = [
# package tensorflow
"//layers:base_layers",
"//layers:dense_layers",
"//layers:embedding_layers",
"//layers:misc_layers",
"//layers:qrnn_layers",
# "//tf_ops:tf_custom_ops",
"//tf_ops:tf_custom_ops_py",
],
)
# Copyright 2022 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ByteQRNN based model for in-training tokenization.
Sample model params:
"feature_size": 128, # Embedding size for each byte
"gbst_max_token_len": 1024, # Max sequence length of bytes in GBST
"gbst_downsample_rate": 1, # Downsample factor for GBST output
"bottleneck_size": 128, # Bottleneck size before feeding to QRNN
"qrnn_state_size": 128, # QRNN layer param
"qrnn_kernel_width": 3, # QRNN layer param
"qrnn_zoneout_probability": 1e-2, # QRNN layer param
"distortion_probability": 0.25, # QRNN layer param
"number_qrnn_layers": 3, # QRNN layer param
"labels": [], # List of labels for getting num classes
"regularizer_scale": 1e-5, # L2 Regularization scale
"quantize": true, # Enable quantization
"multilabel": true, # If the output is Multilabel
"""
from absl import logging
import tensorflow as tf
from layers import base_layers
from layers import dense_layers
from layers import embedding_layers
from layers import misc_layers
from layers import qrnn_layers
class Encoder(tf.keras.layers.Layer):
"""Encoder with GBST and QRNN layers."""
def __init__(self, config, mode, **kwargs):
super(Encoder, self).__init__(**kwargs)
def _get_params(varname, default_value=None):
value = config.get(varname, default_value)
default = "" if varname in config else " (default)"
logging.info("%s = %s%s", varname, value, default)
setattr(self, varname, value)
_get_params("feature_size")
_get_params("bottleneck_size", self.feature_size)
_get_params("qrnn_state_size")
_get_params("qrnn_kernel_width", 3)
_get_params("qrnn_zoneout_probability")
_get_params("number_qrnn_layers")
_get_params("labels", [])
_get_params("regularizer_scale")
_get_params("quantize")
_get_params("gbst_max_token_len", 128)
_get_params("gbst_downsample_rate", 1)
_get_params("gbst_max_subword_block_width", 4)
_get_params("gbst_conv_kernel_size", 5)
_get_params("gbst_block_mixing_mode")
_get_params("gbst_add_block_pos_embed", False)
_get_params("attn_pool_output", True)
self.num_classes = len(config.get("labels", []))
self.parameters = base_layers.Parameters(
mode, quantize=self.quantize, regularizer_scale=self.regularizer_scale)
# 256 byte values plus 3 special token ids (0=padding, 1=EOS, 2=UNK).
self.vocabulary_size = 259
self.embedding = embedding_layers.EmbeddingLayer(
shape=[self.vocabulary_size, self.feature_size],
parameters=self.parameters)
self.bottleneck_layer = dense_layers.BaseQDenseVarLen(
units=self.bottleneck_size,
rank=3,
parameters=self.parameters)
self.gbst_layer = misc_layers.GBSTLayerV2(
feature_size=self.bottleneck_size,
max_seq_len=self.gbst_max_token_len,
downsample_rate=self.gbst_downsample_rate,
max_subword_block_width=self.gbst_max_subword_block_width,
conv_kernel_size=self.gbst_conv_kernel_size,
block_mixing_mode=self.gbst_block_mixing_mode,
add_block_pos_embed=self.gbst_add_block_pos_embed,
parameters=self.parameters)
self.qrnn_stack = qrnn_layers.QRNNBidirectionalStack(
parameters=self.parameters,
zoneout_probability=self.qrnn_zoneout_probability,
kwidth=self.qrnn_kernel_width,
state_size=self.qrnn_state_size,
num_layers=self.number_qrnn_layers)
self.attention_pool = misc_layers.AttentionPooling(
parameters=self.parameters)
if self.num_classes:
self.final_fc = dense_layers.BaseQDense(
units=self.num_classes,
rank=2,
parameters=self.parameters,
activation=None)
def call(self, token_ids, seq_length):
input_embeds = self.embedding(token_ids)
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
mask_rank2 = tf.ones(tf.shape(input_embeds)[:-1], dtype=tf.float32)
seq_length = tf.reduce_sum(mask_rank2, axis=1)
else:
mask_rank2 = tf.sequence_mask(
seq_length, tf.shape(input_embeds)[1], dtype=tf.float32)
maskr3 = tf.expand_dims(mask_rank2, axis=2)
gbst_input = self.bottleneck_layer(input_embeds, maskr3)
gbst_output = self.gbst_layer(gbst_input, seq_length)
if self.parameters.mode in [base_layers.PREDICT, base_layers.TFLITE]:
mask_rank2 = tf.ones(tf.shape(gbst_output)[:-1], dtype=tf.float32)
seq_length = tf.reduce_sum(mask_rank2, axis=1)
else:
seq_length = seq_length / self.gbst_downsample_rate
mask_rank2 = tf.sequence_mask(
seq_length, tf.shape(gbst_output)[1], dtype=tf.float32)
inverse_normalizer = tf.math.reciprocal(tf.reduce_sum(mask_rank2))
maskr3 = tf.expand_dims(mask_rank2, axis=2)
outputs = self.qrnn_stack(gbst_output, maskr3, inverse_normalizer)
if self.attn_pool_output:
pre_logits = self.attention_pool(outputs, maskr3, inverse_normalizer)
if self.num_classes:
return self.final_fc(pre_logits)
else:
return pre_logits
else:
return outputs
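An end-to-end sketch of the ByteQRNN encoder. The module path models.byteqrnn is an assumption about where this file lives, the config values simply follow the sample params in the file docstring, and the TRAIN-mode constant is the same assumption as in the earlier sketches:

import tensorflow as tf
from layers import base_layers
from models import byteqrnn  # assumed module path for this file

config = {
    "feature_size": 128,
    "gbst_max_token_len": 1024,
    "gbst_downsample_rate": 1,
    "bottleneck_size": 128,
    "qrnn_state_size": 128,
    "qrnn_kernel_width": 3,
    "qrnn_zoneout_probability": 1e-2,
    "number_qrnn_layers": 3,
    "labels": ["negative", "positive"],
    "regularizer_scale": 1e-5,
    "quantize": True,
}
# base_layers.TRAIN is assumed to be the training-mode constant.
encoder = byteqrnn.Encoder(config, base_layers.TRAIN)
token_ids = tf.fill([2, 1024], 3)        # dummy byte ids, already shifted by 3
token_len = tf.constant([700, 1024])     # unpadded byte lengths per example
logits = encoder(token_ids, token_len)   # expected shape: [2, 2] class logits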
@@ -51,7 +51,10 @@ def load_runner_config():
def create_model(model, model_config, features, mode):
"""Creates a sequence labeling model."""
keras_model = model.Encoder(model_config, mode)
logits = keras_model(features["projection"], features["seq_length"])
if "pqrnn" in model_name:
logits = keras_model(features["projection"], features["seq_length"])
else:
logits = keras_model(features["token_ids"], features["token_len"])
if mode != tf.estimator.ModeKeys.PREDICT:
if not model_config["multilabel"]:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
......