"vscode:/vscode.git/clone" did not exist on "3c5330d8130ec7e03e5df28b199ab0357b559301"
Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
@@ -29,8 +29,10 @@ _NEG_INF_FP32 = -1e9
 _NEG_INF_FP16 = np.finfo(np.float16).min
 
 
-def get_position_encoding(
-    length, hidden_size, min_timescale=1.0, max_timescale=1.0e4):
+def get_position_encoding(length,
+                          hidden_size,
+                          min_timescale=1.0,
+                          max_timescale=1.0e4):
   """Return positional encoding.
 
   Calculates the position encoding as a mix of sine and cosine functions with
@@ -77,8 +79,8 @@ def get_decoder_self_attention_bias(length, dtype=tf.float32):
   """
   neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
   with tf.name_scope("decoder_self_attention_bias"):
-    valid_locs = tf.linalg.band_part(tf.ones([length, length], dtype=dtype),
-                                     -1, 0)
+    valid_locs = tf.linalg.band_part(
+        tf.ones([length, length], dtype=dtype), -1, 0)
     valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
     decoder_bias = neg_inf * (1.0 - valid_locs)
   return decoder_bias
...
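For orientation, the `get_position_encoding` change above is wrapping only; the function computes the standard sinusoidal encoding its docstring describes (sine and cosine waves over a geometric range of timescales). A minimal NumPy sketch of that computation, for illustration rather than the code in this commit:

import numpy as np

def sinusoidal_encoding_sketch(length, hidden_size,
                               min_timescale=1.0, max_timescale=1.0e4):
  # Shape [length, hidden_size]: first half sine, second half cosine.
  num_timescales = hidden_size // 2
  log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
  inv_timescales = min_timescale * np.exp(np.arange(num_timescales) * -log_increment)
  scaled_time = np.arange(length)[:, None] * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)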
@@ -40,22 +40,19 @@ class ModelUtilsTest(tf.test.TestCase):
     bias_shape = tf.shape(bias)
     flattened_bias = tf.reshape(bias, [3, 5])
-    self.assertAllEqual([[0, NEG_INF, NEG_INF, NEG_INF, 0],
-                         [0, 0, NEG_INF, NEG_INF, NEG_INF],
-                         [NEG_INF, 0, 0, NEG_INF, 0]],
-                        flattened_bias)
+    self.assertAllEqual(
+        [[0, NEG_INF, NEG_INF, NEG_INF, 0], [0, 0, NEG_INF, NEG_INF, NEG_INF],
+         [NEG_INF, 0, 0, NEG_INF, 0]], flattened_bias)
     self.assertAllEqual([3, 1, 1, 5], bias_shape)
 
   def test_get_decoder_self_attention_bias(self):
     length = 5
     bias = model_utils.get_decoder_self_attention_bias(length)
-    self.assertAllEqual([[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
-                           [0, 0, NEG_INF, NEG_INF, NEG_INF],
-                           [0, 0, 0, NEG_INF, NEG_INF],
-                           [0, 0, 0, 0, NEG_INF],
-                           [0, 0, 0, 0, 0]]]],
-                        bias)
+    self.assertAllEqual(
+        [[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
+           [0, 0, NEG_INF, NEG_INF, NEG_INF], [0, 0, 0, NEG_INF, NEG_INF],
+           [0, 0, 0, 0, NEG_INF], [0, 0, 0, 0, 0]]]], bias)
 
 
 if __name__ == "__main__":
...
@@ -31,7 +31,6 @@ from official.nlp.transformer import metrics
 from official.nlp.transformer import model_utils
 from official.nlp.transformer.utils.tokenizer import EOS_ID
 
-
 # Disable the not-callable lint error, since it claims many objects are not
 # callable when they actually are.
 # pylint: disable=not-callable
@@ -49,11 +48,12 @@ def create_model(params, is_train):
     label_smoothing = params["label_smoothing"]
     if params["enable_metrics_in_training"]:
       logits = metrics.MetricLayer(vocab_size)([logits, targets])
-    logits = tf.keras.layers.Lambda(lambda x: x, name="logits",
-                                    dtype=tf.float32)(logits)
+    logits = tf.keras.layers.Lambda(
+        lambda x: x, name="logits", dtype=tf.float32)(
+            logits)
     model = tf.keras.Model([inputs, targets], logits)
-    loss = metrics.transformer_loss(
-        logits, targets, label_smoothing, vocab_size)
+    loss = metrics.transformer_loss(logits, targets, label_smoothing,
+                                    vocab_size)
     model.add_loss(loss)
     return model
 
@@ -130,9 +130,7 @@ class Transformer(tf.keras.Model):
             "Padded decoding on CPU/GPUs is not supported.")
         decode_batch_size = int(self.params["decode_batch_size"] /
                                 self.params["num_replicas"])
-        inputs.set_shape([
-            decode_batch_size, self.params["decode_max_length"]
-        ])
+        inputs.set_shape([decode_batch_size, self.params["decode_max_length"]])
 
     # Variance scaling is used here because it seems to work in many problems.
     # Other reasonable initializers may also work just as well.
@@ -314,15 +312,13 @@ class Transformer(tf.keras.Model):
     cache = {
         "layer_%d" % layer: {
             "k":
-                tf.zeros([
-                    batch_size, init_decode_length, num_heads, dim_per_head
-                ],
-                         dtype=self.params["dtype"]),
+                tf.zeros(
+                    [batch_size, init_decode_length, num_heads, dim_per_head],
+                    dtype=self.params["dtype"]),
             "v":
-                tf.zeros([
-                    batch_size, init_decode_length, num_heads, dim_per_head
-                ],
-                         dtype=self.params["dtype"])
+                tf.zeros(
+                    [batch_size, init_decode_length, num_heads, dim_per_head],
+                    dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
     }
     # pylint: enable=g-complex-comprehension
@@ -512,15 +508,14 @@ class DecoderStack(tf.keras.layers.Layer):
     """Return the output of the decoder layer stacks.
 
     Args:
-      decoder_inputs: A tensor with shape
-        [batch_size, target_length, hidden_size].
-      encoder_outputs: A tensor with shape
-        [batch_size, input_length, hidden_size]
-      decoder_self_attention_bias: A tensor with shape
-        [1, 1, target_len, target_length], the bias for decoder self-attention
-        layer.
-      attention_bias: A tensor with shape [batch_size, 1, 1, input_length],
-        the bias for encoder-decoder attention layer.
+      decoder_inputs: A tensor with shape [batch_size, target_length,
+        hidden_size].
+      encoder_outputs: A tensor with shape [batch_size, input_length,
+        hidden_size]
+      decoder_self_attention_bias: A tensor with shape [1, 1, target_len,
+        target_length], the bias for decoder self-attention layer.
+      attention_bias: A tensor with shape [batch_size, 1, 1, input_length], the
+        bias for encoder-decoder attention layer.
       training: A bool, whether in training mode or not.
       cache: (Used for fast decoding) A nested dictionary storing previous
         decoder self-attention values. The items are:
...
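For reference, the cache rewrapped in the @@ -314,15 +312,13 @@ hunk above is a per-layer dictionary of zero-length key/value tensors used for fast decoding; the time axis grows as tokens are generated. A small standalone sketch with made-up sizes (not the commit's code):

import tensorflow as tf

batch_size, init_decode_length, num_heads, dim_per_head = 2, 0, 8, 64
num_hidden_layers = 6

# One {"k", "v"} pair per decoder layer, matching the structure shown above.
cache = {
    "layer_%d" % layer: {
        "k": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
        "v": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head]),
    } for layer in range(num_hidden_layers)
}
print(cache["layer_0"]["k"].shape)  # (2, 0, 8, 64)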
@@ -34,11 +34,12 @@ class TransformerLayersTest(tf.test.TestCase):
     dropout = 0.5
     dim_per_head = hidden_size // num_heads
     layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
-    self.assertDictEqual(layer.get_config(), {
-        "hidden_size": hidden_size,
-        "num_heads": num_heads,
-        "attention_dropout": dropout,
-    })
+    self.assertDictEqual(
+        layer.get_config(), {
+            "hidden_size": hidden_size,
+            "num_heads": num_heads,
+            "attention_dropout": dropout,
+        })
     length = 2
     x = tf.ones([1, length, hidden_size])
     bias = tf.ones([1])
@@ -47,9 +48,23 @@ class TransformerLayersTest(tf.test.TestCase):
         "v": tf.zeros([1, 0, num_heads, dim_per_head]),
     }
     y = layer(x, bias, training=True, cache=cache)
-    self.assertEqual(y.shape, (1, length, 64,))
-    self.assertEqual(cache["k"].shape, (1, length, num_heads, dim_per_head,))
-    self.assertEqual(cache["v"].shape, (1, length, num_heads, dim_per_head,))
+    self.assertEqual(y.shape, (
+        1,
+        length,
+        64,
+    ))
+    self.assertEqual(cache["k"].shape, (
+        1,
+        length,
+        num_heads,
+        dim_per_head,
+    ))
+    self.assertEqual(cache["v"].shape, (
+        1,
+        length,
+        num_heads,
+        dim_per_head,
+    ))
 
   def test_embedding_shared_weights(self):
     vocab_size = 50
@@ -63,25 +78,38 @@ class TransformerLayersTest(tf.test.TestCase):
     idx = tf.ones([1, length], dtype="int32")
     y = layer(idx)
-    self.assertEqual(y.shape, (1, length, hidden_size,))
+    self.assertEqual(y.shape, (
+        1,
+        length,
+        hidden_size,
+    ))
 
     x = tf.ones([1, length, hidden_size])
     output = layer(x, "linear")
-    self.assertEqual(output.shape, (1, length, vocab_size,))
+    self.assertEqual(output.shape, (
+        1,
+        length,
+        vocab_size,
+    ))
 
   def test_feed_forward_network(self):
     hidden_size = 64
     filter_size = 32
     relu_dropout = 0.5
     layer = ffn_layer.FeedForwardNetwork(hidden_size, filter_size, relu_dropout)
-    self.assertDictEqual(layer.get_config(), {
-        "hidden_size": hidden_size,
-        "filter_size": filter_size,
-        "relu_dropout": relu_dropout,
-    })
+    self.assertDictEqual(
+        layer.get_config(), {
+            "hidden_size": hidden_size,
+            "filter_size": filter_size,
+            "relu_dropout": relu_dropout,
+        })
     length = 2
     x = tf.ones([1, length, hidden_size])
     y = layer(x, training=True)
-    self.assertEqual(y.shape, (1, length, hidden_size,))
+    self.assertEqual(y.shape, (
+        1,
+        length,
+        hidden_size,
+    ))
 
   def test_metric_layer(self):
     vocab_size = 50
@@ -90,7 +118,11 @@ class TransformerLayersTest(tf.test.TestCase):
                                    name="logits")
     targets = tf.keras.layers.Input((None,), dtype="int64", name="targets")
     output_logits = metrics.MetricLayer(vocab_size)([logits, targets])
-    self.assertEqual(output_logits.shape.as_list(), [None, None, vocab_size,])
+    self.assertEqual(output_logits.shape.as_list(), [
+        None,
+        None,
+        vocab_size,
+    ])
 
 
 if __name__ == "__main__":
...
@@ -24,6 +24,7 @@ from __future__ import print_function
 import os
 import tempfile
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+# Import libraries
 from absl import logging
 import numpy as np
 import tensorflow as tf
...
@@ -22,6 +22,7 @@ import collections
 import re
 import sys
 import unicodedata
+
 from absl import logging
 import numpy as np
@@ -29,7 +30,6 @@ import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
 
-# pylint: disable=g-complex-comprehension
 PAD = "<pad>"
 PAD_ID = 0
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Utilities for pre-processing classification data."""
+
 from absl import logging
 
 from official.nlp.xlnet import data_utils
...
@@ -22,12 +22,12 @@ from __future__ import print_function
 import collections
 import json
 import os
 
 from absl import logging
 import numpy as np
 import tensorflow as tf
 
 
 special_symbols = {
     "<unk>": 0,
     "<s>": 1,
@@ -51,10 +51,10 @@ SEG_ID_Q = 1
 SEG_ID_CLS = 2
 SEG_ID_PAD = 3
 
 OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
     "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
-    "min_num_words"])
+    "min_num_words"
+])
 
 
 def file_based_input_fn_builder(input_file, name_to_features, batch_size,
@@ -253,20 +253,14 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
 def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
   """Turn beg and end indices into actual mask."""
   non_func_mask = tf.logical_and(
-      tf.not_equal(inputs, SEP_ID),
-      tf.not_equal(inputs, CLS_ID))
-  all_indices = tf.where(
-      non_func_mask,
-      tf.range(tgt_len, dtype=tf.int64),
-      tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+  all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64),
+                         tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
   candidate_matrix = tf.cast(
-      tf.logical_and(
-          all_indices[None, :] >= beg_indices[:, None],
-          all_indices[None, :] < end_indices[:, None]),
-      tf.float32)
+      tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
+                     all_indices[None, :] < end_indices[:, None]), tf.float32)
   cumsum_matrix = tf.reshape(
-      tf.cumsum(tf.reshape(candidate_matrix, [-1])),
-      [-1, tgt_len])
+      tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
   masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
   target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
   is_masked = tf.cast(target_mask, tf.bool)
@@ -274,8 +268,8 @@ def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
   return is_masked, target_mask
 
 
-def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
-                    max_num_words, boundary):
+def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
+                    boundary):
   """Sample whole word spans as prediction targets."""
   # Note: 1.2 is the token-to-word ratio
   mask_alpha = tgt_len / num_predict / 1.2
@@ -283,7 +277,7 @@ def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
   # Sample span lengths from a zipf distribution
   span_len_seq = np.arange(min_num_words, max_num_words + 1)
   probs = np.array([1.0 / (i + 1) for i in span_len_seq])
   probs /= np.sum(probs)
 
   logits = tf.constant(np.log(probs), dtype=tf.float32)
@@ -302,8 +296,8 @@ def _word_span_mask(inputs, tgt_len, num_predict, min_num_words,
   left_ctx_len = round_to_int(left_ctx_len)
   right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
 
-  beg_indices = (tf.cumsum(left_ctx_len) +
-                 tf.cumsum(right_offset, exclusive=True))
+  beg_indices = (
+      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
   end_indices = beg_indices + span_lens
 
   # Remove out of range indices
@@ -333,7 +327,7 @@ def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
   # Sample span lengths from a zipf distribution
   span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
   probs = np.array([1.0 / (i + 1) for i in span_len_seq])
   probs /= np.sum(probs)
 
   logits = tf.constant(np.log(probs), dtype=tf.float32)
@@ -353,8 +347,8 @@ def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
   right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
 
   # Get the actual begin and end indices
-  beg_indices = (tf.cumsum(left_ctx_len) +
-                 tf.cumsum(right_offset, exclusive=True))
+  beg_indices = (
+      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
   end_indices = beg_indices + span_lens
 
   # Remove out of range indices
@@ -387,8 +381,7 @@ def _single_token_mask(inputs, tgt_len, num_predict):
   """Sample individual tokens as prediction targets."""
   all_indices = tf.range(tgt_len, dtype=tf.int64)
   non_func_mask = tf.logical_and(
-      tf.not_equal(inputs, SEP_ID),
-      tf.not_equal(inputs, CLS_ID))
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
   non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
   masked_pos = tf.random.shuffle(non_func_indices)
@@ -404,7 +397,10 @@ def _single_token_mask(inputs, tgt_len, num_predict):
   return is_masked, target_mask
 
 
-def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
+def _online_sample_masks(inputs,
+                         tgt_len,
+                         num_predict,
+                         online_masking_config,
                          boundary=None):
   """Sample target positions to predict."""
   logging.info("Online sample with strategy: `%s`.",
@@ -422,8 +418,7 @@ def _online_sample_masks(inputs, tgt_len, num_predict, online_masking_config,
     assert boundary is not None, "word span sampling requires `boundary`"
     return _word_span_mask(inputs, tgt_len, num_predict,
                            online_masking_config.min_num_words,
-                           online_masking_config.max_num_words,
-                           boundary)
+                           online_masking_config.max_num_words, boundary)
   else:
     raise NotImplementedError
@@ -529,10 +524,11 @@ def create_pretrain_dataset(file_names,
       example["target"] = tf.reshape(target, [num_predict])
 
       ##### target mask
-      target_mask = tf.concat(
-          [tf.ones([actual_num_predict], dtype=tf.float32),
-           tf.zeros([pad_len], dtype=tf.float32)],
-          axis=0)
+      target_mask = tf.concat([
+          tf.ones([actual_num_predict], dtype=tf.float32),
+          tf.zeros([pad_len], dtype=tf.float32)
+      ],
+                              axis=0)
       example["target_mask"] = tf.reshape(target_mask, [num_predict])
     else:
       example["target"] = tf.reshape(target, [seq_len])
@@ -562,7 +558,11 @@ def create_pretrain_dataset(file_names,
   return dataset
 
 
-def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
+def format_filename(prefix,
+                    suffix,
+                    bsz_per_host,
+                    seq_len,
+                    reuse_len=None,
                     uncased=False):
   """Generates input file name pattern."""
   if reuse_len is not None and reuse_len > 0:
@@ -577,8 +577,8 @@ def format_filename(prefix, suffix, bsz_per_host, seq_len, reuse_len=None,
   else:
     case_str = "uncased."
 
-  file_name = "{}.seq-{}.{}{}{}{}".format(
-      prefix, seq_len, reuse_str, bsz_str, case_str, suffix)
+  file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
+                                          case_str, suffix)
 
   return file_name
@@ -722,9 +722,7 @@ def parse_files_to_dataset(parser,
     # even more randomness to the training pipeline.
     dataset = dataset.apply(
         tf.data.experimental.parallel_interleave(
-            tf.data.TFRecordDataset,
-            sloppy=True,
-            cycle_length=cycle_length))
+            tf.data.TFRecordDataset, sloppy=True, cycle_length=cycle_length))
 
     buffer_size = 2048
     logging.info("Perform sample-level shuffle with size %d", buffer_size)
     dataset = dataset.shuffle(buffer_size=buffer_size)
@@ -778,9 +776,8 @@ def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
   index = tf.reshape(tf.transpose(index), [-1])
 
   # non-functional tokens
-  non_func_tokens = tf.logical_not(tf.logical_or(
-      tf.equal(inputs, SEP_ID),
-      tf.equal(inputs, CLS_ID)))
+  non_func_tokens = tf.logical_not(
+      tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
   masked_tokens = tf.logical_and(is_masked, non_func_tokens)
   non_masked_or_func_tokens = tf.logical_not(masked_tokens)
...
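For reference, the span-mask hunks above keep the same sampling scheme: span lengths are drawn with zipf-like weights 1/(i+1) over [min_num_tokens, max_num_tokens], so shorter spans are more likely. A small standalone check of those probabilities, with illustrative bounds:

import numpy as np

min_num_tokens, max_num_tokens = 1, 5
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)  # shorter spans get higher probability
print(dict(zip(span_len_seq.tolist(), probs.round(3).tolist())))
# {1: 0.345, 2: 0.23, 3: 0.172, 4: 0.138, 5: 0.115}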
@@ -21,6 +21,7 @@ import collections
 import csv
 import os
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -23,6 +23,7 @@ import json
 import os
 import random
 
+# Import libraries
 from absl import app
 from absl import flags
 import absl.logging as _logging  # pylint: disable=unused-import
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
 import random
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -21,7 +21,6 @@ import unicodedata
 import six
 
-
 SPIECE_UNDERLINE = '▁'
@@ -95,8 +94,8 @@ def encode_pieces(sp_model, text, return_unicode=True, sample=False):
   new_pieces = []
   for piece in pieces:
     if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
-      cur_pieces = sp_model.EncodeAsPieces(
-          piece[:-1].replace(SPIECE_UNDERLINE, ''))
+      cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
+          SPIECE_UNDERLINE, ''))
       if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
         if len(cur_pieces[0]) == 1:
           cur_pieces = cur_pieces[1:]
...
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -22,6 +22,7 @@ from __future__ import print_function
 import functools
 import os
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -24,6 +24,7 @@ import json
 import os
 import pickle
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -38,12 +38,13 @@ def create_run_config(is_training, is_finetune, flags):
       clamp_len=flags.clamp_len)
 
   if not is_finetune:
-    kwargs.update(dict(
-        mem_len=flags.mem_len,
-        reuse_len=flags.reuse_len,
-        bi_data=flags.bi_data,
-        clamp_len=flags.clamp_len,
-        same_length=flags.same_length))
+    kwargs.update(
+        dict(
+            mem_len=flags.mem_len,
+            reuse_len=flags.reuse_len,
+            bi_data=flags.bi_data,
+            clamp_len=flags.clamp_len,
+            same_length=flags.same_length))
 
   return RunConfig(**kwargs)
@@ -80,8 +81,10 @@ class XLNetConfig(object):
     assert FLAGS is not None or json_path is not None or args_dict is not None
 
-    self.keys = ['n_layer', 'd_model', 'n_head', 'd_head', 'd_inner',
-                 'ff_activation', 'untie_r', 'n_token']
+    self.keys = [
+        'n_layer', 'd_model', 'n_head', 'd_head', 'd_inner', 'ff_activation',
+        'untie_r', 'n_token'
+    ]
 
     if FLAGS is not None:
       self.init_from_flags(FLAGS)
@@ -152,17 +155,17 @@ class RunConfig(object):
       init_method: str, the initialization scheme, either "normal" or "uniform".
      init_range: float, initialize the parameters with a uniform distribution
        in [-init_range, init_range]. Only effective when init="uniform".
-      init_std: float, initialize the parameters with a normal distribution
-        with mean 0 and stddev init_std. Only effective when init="normal".
+      init_std: float, initialize the parameters with a normal distribution with
+        mean 0 and stddev init_std. Only effective when init="normal".
       mem_len: int, the number of tokens to cache.
-      reuse_len: int, the number of tokens in the currect batch to be cached
-        and reused in the future.
-      bi_data: bool, whether to use bidirectional input pipeline.
-        Usually set to True during pretraining and False during finetuning.
-      clamp_len: int, clamp all relative distances larger than clamp_len.
-        -1 means no clamping.
-      same_length: bool, whether to use the same attention length
-        for each token.
+      reuse_len: int, the number of tokens in the currect batch to be cached and
+        reused in the future.
+      bi_data: bool, whether to use bidirectional input pipeline. Usually set to
+        True during pretraining and False during finetuning.
+      clamp_len: int, clamp all relative distances larger than clamp_len. -1
+        means no clamping.
+      same_length: bool, whether to use the same attention length for each
+        token.
       use_cls_mask: bool, whether to introduce cls mask.
     """
...
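For reference, the create_run_config rewrap above only changes how the kwargs.update(dict(...)) call is wrapped; the pretraining-only fields it adds are the ones documented in the RunConfig docstring. A sketch of the resulting keyword dictionary, with invented values standing in for the real absl FLAGS:

# Hypothetical flag values, for illustration only.
mem_len, reuse_len, bi_data, clamp_len, same_length = 384, 256, True, -1, False

kwargs = dict(clamp_len=clamp_len)  # stand-in for the common (finetune-time) options
# Pretraining-only options, mirroring the update in create_run_config.
kwargs.update(
    dict(
        mem_len=mem_len,
        reuse_len=reuse_len,
        bi_data=bi_data,
        clamp_len=clamp_len,
        same_length=same_length))
print(kwargs)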
@@ -48,5 +48,6 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
     logging.info(pos_emb)
     self.assertAllClose(pos_emb, target)
 
+
 if __name__ == "__main__":
   tf.test.main()
@@ -65,7 +65,7 @@ CACHE_INVALIDATION_SEC = 3600 * 24
 # == Data Generation ===========================================================
 # ==============================================================================
 CYCLES_TO_BUFFER = 3  # The number of train cycles worth of data to "run ahead"
-                      # of the main training loop.
+# of the main training loop.
 
 # Number of batches to run per epoch when using synthetic data. At high batch
 # sizes, we run for more batches than with real data, which is good since
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import json
 
 # pylint: disable=g-bad-import-order
+# Import libraries
 from absl import app
 from absl import flags
 import tensorflow.compat.v2 as tf
...