Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
......@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
def _generalized_kernel(x, projection_matrix, is_query, f, h,
data_normalizer_fn=None):
def _generalized_kernel(x, projection_matrix, f, h):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args:
x: The feature being transformed with shape [B, T, N, H].
projection_matrix: The matrix with shape [M, H] that we project x onto, where
M is the number of projections.
is_query: Whether the transform is applied to a query or a key. The transform
is symmetric, so this argument is not used.
f: A non-linear function applied on x or projected x.
h: A multiplier which is a function of x, applied after x is projected and
transformed. Only applied if projection_matrix is not None.
data_normalizer_fn: A function which takes x and returns a scalar that
normalizes the data.
Returns:
Transformed feature.
"""
# No asymmetric operations.
del is_query
if data_normalizer_fn is not None:
x = data_normalizer_fn(x)
if projection_matrix is None:
return h(x) * f(x)
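For orientation, here is a minimal sketch of the transform's assumed full behavior. The projected branch is collapsed in this diff, so the `BTNH,MH->BTNM` einsum and the 1/sqrt(M) normalizer below are assumptions modeled on the Performer construction, not the verbatim implementation:

```python
import tensorflow as tf

def _generalized_kernel_sketch(x, projection_matrix, f, h):
  """Assumed behavior of the (collapsed) projected branch above."""
  # Unprojected case, exactly as shown in the diff.
  if projection_matrix is None:
    return h(x) * f(x)
  # Assumed projected case: map [B, T, N, H] features onto M random
  # directions, then apply the multiplier h and the nonlinearity f,
  # normalizing by sqrt(M).
  m = tf.cast(tf.shape(projection_matrix)[0], tf.float32)
  x_projected = tf.einsum("BTNH,MH->BTNM", x, projection_matrix)
  return h(x) * f(x_projected) / tf.math.sqrt(m)
```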
......@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(
-0.5 * tf.math.reduce_sum(
tf.math.square(x), axis=-1, keepdims=True)),
data_normalizer_fn=lambda x: x /
(tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
tf.math.square(x), axis=-1, keepdims=True)),),
"expmod":
functools.partial(
_generalized_kernel,
# Avoid exp explosion by shifting.
f=lambda x: tf.math.exp(
x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(
-0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
data_normalizer_fn=lambda x: x /
(tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
"l2":
functools.partial(
_generalized_kernel,
f=lambda x: x,
h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
data_normalizer_fn=lambda x: x),
"identity": lambda x, projection_matrix, is_query: x
f=lambda x: tf.math.exp(x - tf.math.reduce_max(
x, axis=[1, 2, 3], keepdims=True)),
h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
tf.cast(tf.shape(x)[-1], tf.float32))),
),
"identity":
functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
}
# pylint: enable=g-long-lambda
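The transforms can be exercised directly. A toy check, assuming the shapes documented above, that the "exp" transform yields the strictly positive features that Lemma 1 of the Performer paper relies on:

```python
import tensorflow as tf

# Toy check (a sketch). A plain Gaussian matrix stands in for
# create_projection_matrix here to keep the example self-contained.
x = tf.random.normal([2, 4, 1, 8])    # [B, T, N, H]
w = tf.random.normal([16, 8])         # [M, H] stand-in random projection
phi = _TRANSFORM_MAP["exp"](x, w)
assert bool(tf.reduce_all(phi > 0.0))  # exp-based f and h are positive
```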
......@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Rethinking Attention with Performers
(https://arxiv.org/abs/2009.14794)
- exp (Lemma 1, positive), relu, l2
- exp (Lemma 1, positive), relu
- random/deterministic projection
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
......@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
redraw=False,
is_short_seq=False,
begin_kernel=0,
scale=None,
**kwargs):
r"""Constructor of KernelAttention.
Args:
feature_transform: A non-linear transform of the keys and queries.
Possible transforms are "elu", "relu", "square", "exp", "expmod",
"l2", "identity". If <is_short_seq> = True, it is recommended to choose
feature_transform as "l2".
"identity".
num_random_features: Number of random features to be used for projection.
If num_random_features <= 0, no projection is used before the transform.
seed: The seed to begin drawing random features. Once the seed is set, the
......@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
(default option).
begin_kernel: Apply kernel_attention after this sequence id and apply
softmax attention before this.
scale: The value to scale the dot product as described in `Attention Is
All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
**kwargs: The same arguments `MultiHeadAttention` layer.
"""
if feature_transform not in _TRANSFORM_MAP:
......@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
# 1. inference
# 2. no redraw
self._seed = seed
super().__init__(**kwargs)
if scale is None:
self._scale = 1.0 / math.sqrt(float(self._key_dim))
else:
self._scale = scale
self._projection_matrix = None
if num_random_features > 0:
self._projection_matrix = create_projection_matrix(
......@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Returns:
attention_output: Multi-headed outputs of attention computation.
"""
projection_matrix = None
if self._num_random_features > 0:
if self._redraw and training:
......@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
else:
projection_matrix = self._projection_matrix
key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
if is_short_seq:
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = query * self._scale
else:
# Note: we suspect splitting the scale between key and query yields smaller
# approximation variance when random projection is used.
# For simplicity, we also split when there's no random projection.
key *= math.sqrt(self._scale)
query *= math.sqrt(self._scale)
key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
if attention_mask is not None:
key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
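The scale-splitting note can be sanity-checked on the raw dot product, where the two placements are exactly equivalent; after the nonlinear feature map they no longer are, which is why the split is only conjectured to reduce approximation variance. A sketch with toy shapes:

```python
import math
import tensorflow as tf

q = tf.random.normal([1, 3, 2, 4])   # [B, T, N, H]
k = tf.random.normal([1, 5, 2, 4])   # [B, S, N, H]
scale = 0.125
# Multiply query and key each by sqrt(scale)...
split = tf.einsum("BTNH,BSNH->BTSN",
                  q * math.sqrt(scale), k * math.sqrt(scale))
# ...which equals scaling the dot product once.
direct = scale * tf.einsum("BTNH,BSNH->BTSN", q, k)
tf.debugging.assert_near(split, direct)
```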
......@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
attention_scores = tf.nn.softmax(attention_scores, axis=2)
attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
return attention_output
else:
kv = tf.einsum("BSNH,BSND->BNDH", key, value)
denominator = 1.0 / (
tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
_NUMERIC_STABLER)
return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
attention_output = tf.einsum(
"BTNH,BNDH,BTN->BTND", query, kv, denominator)
return attention_output
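A minimal numeric sketch of the linear-attention branch with toy shapes; contracting key and value first is what makes the cost linear in the sequence length:

```python
import tensorflow as tf

B, T, S, N, H, D = 2, 5, 7, 1, 4, 4
query = tf.nn.relu(tf.random.normal([B, T, N, H]))   # transformed features
key = tf.nn.relu(tf.random.normal([B, S, N, H]))     # non-negative features
value = tf.random.normal([B, S, N, D])
kv = tf.einsum("BSNH,BSND->BNDH", key, value)        # [B, N, D, H]
denominator = 1.0 / (
    tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) + 1e-6)
out = tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
print(out.shape)  # (2, 5, 1, 4); no T x S attention matrix is materialized
```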
def _build_from_signature(self, query, value, key=None):
super()._build_from_signature(query=query, value=value, key=key)
......@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
"redraw": self._redraw,
"is_short_seq": self._is_short_seq,
"begin_kernel": self._begin_kernel,
"scale": self._scale,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -21,7 +21,7 @@ import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
_REDRAW = [True, False]
_TRAINING = [True, False]
_IS_SHORT_SEQ = [True, False]
......
......@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
"'tensorflow-text-nightly'.")
def _iterative_vectorized_fair_share(capacity: tf.Tensor,
limit: Union[int, tf.Tensor]):
"""Iterative algorithm for max min fairness algorithm.
Reference: https://en.wikipedia.org/wiki/Max-min_fairness
The idea is that, for each example with some number of segments and a limit
on the total segment length allowed, we grant each segment a fair share of
the limit. For example, if every segment has the same length, there is no
work to do. If one segment has below-average length, its unused share is
split fairly among the others. In this way, the longest segment will be the
shortest among all potential capacity assignments.
Args:
capacity: A rank-2 Tensor of #Segments x Batch.
limit: The largest permissible number of tokens in total across one example.
Returns:
A rank-2 Tensor with new segment capacity assignment such that
the total number of tokens in each example does not exceed the `limit`.
"""
# Firstly, we calculate the lower bound of the capacity assignment.
per_seg_limit = limit // capacity.shape[0]
limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
lower_bound = tf.minimum(capacity, limit_mask)
# This step grants full capacity to the examples that already satisfy the limit.
remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
remaining_cap_mat = capacity - lower_bound
new_cap = lower_bound + remaining_cap_mat * tf.cast(
tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
tf.int64)
# Process iteratively. This step is O(#segments), see analysis below.
while True:
remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
remaining_cap = capacity - new_cap
masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
tf.int64) * remaining_limit
# Total remaining segment limit is different for each example.
per_seg_limit = masked_remaining_limit // (
tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
remaining_cap_col_slots) # +1 to make sure 0/0 = 0
# Note that for each step, there is at least one more segment being
# fulfilled or the loop is finished.
# The idea is: if the remaining per-example limit > the smallest remaining
# segment demand, the smallest segment's request is fulfilled. Otherwise, all
# remaining segments are truncated and the assignment is finished.
if tf.math.reduce_sum(per_seg_limit) > 0:
remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
new_cap = new_cap + remaining_slots_mat * per_seg_limit
else:
# Leftover assignment of limit that is smaller than #slots.
new_remained_assignment_mask = tf.cast(
(tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
& (masked_remaining_slots > 0), tf.int64)
new_cap = new_cap + new_remained_assignment_mask
break
return new_cap
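A toy run of the fair-share assignment (shapes and values assumed):

```python
import tensorflow as tf

# One example in the batch, three segments of lengths 4, 1, and 5,
# and a total budget of 6 tokens.
capacity = tf.constant([[4], [1], [5]], dtype=tf.int64)  # [#segments, batch]
print(_iterative_vectorized_fair_share(capacity, limit=6))
# -> [[3], [1], [2]]: the short segment keeps its single token and the
#    remaining budget is split as evenly as possible among the longer ones.
```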
def round_robin_truncate_inputs(
inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
limit: Union[int, tf.Tensor],
) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
"""Truncates a list of batched segments to fit a per-example length limit.
Available space is assigned one token at a time in a round-robin fashion
to the inputs that still need some, until the limit is reached.
(Or equivalently: the longest input is truncated by one token until the total
length of inputs fits the limit.) Examples that fit the limit as passed in
remain unchanged.
Args:
inputs: A list of rank-2 RaggedTensors. The i-th example is given by
the i-th row in each list element, that is, `inputs[:][i, :]`.
limit: The largest permissible number of tokens in total across one example.
Returns:
A list of rank-2 RaggedTensors at corresponding indices with the inputs,
in which the rows of each RaggedTensor have been truncated such that
the total number of tokens in each example does not exceed the `limit`.
"""
if not isinstance(inputs, (list, tuple)):
return round_robin_truncate_inputs([inputs], limit)[0]
limit = tf.cast(limit, tf.int64)
if not all(rt.shape.rank == 2 for rt in inputs):
raise ValueError("All inputs must have shape [batch_size, (items)]")
if len(inputs) == 1:
return [_truncate_row_lengths(inputs[0], limit)]
elif len(inputs) == 2:
size_a, size_b = [rt.row_lengths() for rt in inputs]
# Here's a brain-twister: This does round-robin assignment of quota
# to both inputs until the limit is reached. Hint: consider separately
# the cases of zero, one, or two inputs exceeding half the limit.
floor_half = limit // 2
ceil_half = limit - floor_half
quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
return [_truncate_row_lengths(inputs[0], quota_a),
_truncate_row_lengths(inputs[1], quota_b)]
else:
# Note that we don't merge with the 2 input case because the full algorithm
# is more expensive.
capacity = tf.stack([rt.row_lengths() for rt in inputs]) # #Segments x B
new_capacity = _iterative_vectorized_fair_share(capacity, limit)
return [
_truncate_row_lengths(inputs[i], new_capacity[i])
for i in range(capacity.shape[0])
]
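A worked instance of the two-segment quota formula, with limit=5 (so floor_half=2 and ceil_half=3), matching a row of the tests further below:

```python
import tensorflow as tf

size_a = tf.constant([4], tf.int64)
size_b = tf.constant([2], tf.int64)
quota_a = tf.minimum(size_a, 3 + tf.nn.relu(2 - size_b))  # -> [3], truncated
quota_b = tf.minimum(size_b, 2 + tf.nn.relu(3 - size_a))  # -> [2], unchanged
# Totals 5: a loses exactly one token, as round-robin assignment would give.
```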
def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
new_lengths: tf.Tensor) -> tf.RaggedTensor:
"""Truncates the rows of `ragged_tensor` to the given row lengths."""
......@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
# fall back to some ad-hoc truncation.
num_special_tokens = len(inputs) + 1
if truncator == "round_robin":
trimmed_segments = round_robin_truncate_inputs(
inputs, seq_length - num_special_tokens)
trimmed_segments = text.RoundRobinTrimmer(seq_length -
num_special_tokens).trim(inputs)
elif truncator == "waterfall":
trimmed_segments = text.WaterfallTrimmer(
seq_length - num_special_tokens).trim(inputs)
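For reference, a hypothetical standalone use of the tensorflow-text trimmer that replaces the in-repo implementation (assuming tensorflow-text is installed, per the check above):

```python
import tensorflow as tf
import tensorflow_text as text

segments = [tf.ragged.constant([[1, 2, 3, 4]]), tf.ragged.constant([[5, 6]])]
trimmed = text.RoundRobinTrimmer(max_seq_length=5).trim(segments)
# -> [[1, 2, 3]] and [[5, 6]]: the budget of 5 is granted one token at a
#    time across segments, so only the longer segment is truncated.
```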
......
......@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
from official.nlp.modeling.layers import text_layers
class RoundRobinTruncatorTest(tf.test.TestCase):
def _test_input(self, start, lengths):
return tf.ragged.constant([[start + 10 * j + i
for i in range(length)]
for j, length in enumerate(lengths)],
dtype=tf.int32)
def test_single_segment(self):
# Single segment.
single_input = self._test_input(11, [4, 5, 6])
expected_single_output = tf.ragged.constant(
[[11, 12, 13, 14],
[21, 22, 23, 24, 25],
[31, 32, 33, 34, 35], # Truncated.
])
self.assertAllEqual(
expected_single_output,
text_layers.round_robin_truncate_inputs(single_input, limit=5))
# Test wrapping in a singleton list.
actual_single_list_output = text_layers.round_robin_truncate_inputs(
[single_input], limit=5)
self.assertIsInstance(actual_single_list_output, list)
self.assertAllEqual(expected_single_output, actual_single_list_output[0])
def test_two_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
expected_a = tf.ragged.constant(
[[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153], # Truncated.
[161, 162, 163], # Truncated.
])
expected_b = tf.ragged.constant(
[[211],
[221, 222, 223],
[231, 232, 233], # Truncated.
[241, 242],
[251, 252],
[261, 262], # Truncated.
])
actual_a, actual_b = text_layers.round_robin_truncate_inputs(
[input_a, input_b], limit=5)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
def test_three_segments(self):
input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
seg_limit = 8
expected_a = tf.ragged.constant([
[111],
[121, 122],
[131, 132],
[141, 142, 143],
[151, 152, 153, 154],
[161, 162, 163], # Truncated
[171]
])
expected_b = tf.ragged.constant([
[211],
[221, 222, 223],
[231, 232, 233], # Truncated
[241, 242],
[251, 252],
[261, 262, 263], # Truncated
[271, 272, 273, 274] # Truncated
])
expected_c = tf.ragged.constant([
[311],
[321, 322, 323],
[331, 332, 333], # Truncated
[341, 342],
[351, 352],
[361, 362], # Truncated
[371, 372, 373] # Truncated
])
actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
[input_a, input_b, input_c], limit=seg_limit)
self.assertAllEqual(expected_a, actual_a)
self.assertAllEqual(expected_b, actual_b)
self.assertAllEqual(expected_c, actual_c)
input_cap = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
axis=0)
per_example_usage = tf.math.reduce_sum(
tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
axis=0)
self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
# see nlp/tools/export_tfhub_lib_test.py.
......
......@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler', 'head_name') will be ignored.
"""
def __init__(self,
......@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer='glorot_uniform',
dropout_rate=0.1,
use_encoder_pooler=True,
head_name='sentence_prediction',
cls_head=None,
**kwargs):
self.num_classes = num_classes
self.head_name = head_name
self.initializer = initializer
self.use_encoder_pooler = use_encoder_pooler
......@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
name=head_name)
predictions = classifier(cls_inputs)
......@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return {
'network': self._network,
'num_classes': self.num_classes,
'head_name': self.head_name,
'initializer': self.initializer,
'use_encoder_pooler': self.use_encoder_pooler,
'cls_head': self._cls_head,
......
......@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
def _embedding_linear(self, embedding_matrix, x):
"""Uses embeddings as linear transformation weights."""
embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
x = tf.cast(x, dtype=self.compute_dtype)
batch_size = tf.shape(x)[0]
length = tf.shape(x)[1]
hidden_size = tf.shape(x)[2]
vocab_size = tf.shape(embedding_matrix)[0]
x = tf.reshape(x, [-1, hidden_size])
logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
logits = tf.matmul(x, embedding_matrix, transpose_b=True)
return tf.reshape(logits, [batch_size, length, vocab_size])
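A shape walk-through of the weight-tied projection, with toy sizes assumed:

```python
import tensorflow as tf

embedding_matrix = tf.random.normal([10, 4])   # [vocab_size, hidden_size]
x = tf.random.normal([2, 3, 4])                # [batch, length, hidden]
# Flatten, multiply against the transposed embedding table, restore shape.
logits = tf.matmul(tf.reshape(x, [-1, 4]), embedding_matrix, transpose_b=True)
logits = tf.reshape(logits, [2, 3, 10])        # [batch, length, vocab_size]
```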
......
......@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head.
head_name: Name of the classification head.
"""
def __init__(
......@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer: tf.keras.initializers.Initializer = 'random_normal',
summary_type: str = 'last',
dropout_rate: float = 0.1,
head_name: str = 'sentence_prediction',
**kwargs):
super().__init__(**kwargs)
self._network = network
......@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
'num_classes': num_classes,
'summary_type': summary_type,
'dropout_rate': dropout_rate,
'head_name': head_name,
}
if summary_type == 'last':
......@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
initializer=initializer,
dropout_rate=dropout_rate,
cls_token_idx=cls_token_idx,
name='sentence_prediction')
name=head_name)
def call(self, inputs: Mapping[str, Any]):
input_ids = inputs['input_word_ids']
......
......@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
parameter is originally added for ELECTRA model which needs to tie the
generator embeddings with the discriminator embeddings.
dict_outputs: Whether to use a dictionary as the model outputs.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set to False, the outputs of the attention and intermediate
dense layers are normalized instead.
"""
def __init__(self,
......@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
embedding_width=None,
embedding_layer=None,
dict_outputs=False,
norm_first=False,
**kwargs):
# b/164516224
......@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
initializer=initializer,
output_range=output_range,
embedding_width=embedding_width,
embedding_layer=embedding_layer)
embedding_layer=embedding_layer,
norm_first=norm_first)
self._embedding_layer_instance = embedding_layer
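A hypothetical construction of a pre-LN encoder using the new flag (other arguments left at their defaults):

```python
# Assumed argument names per the constructor above.
encoder = BertEncoder(vocab_size=30522, num_layers=4, norm_first=True)
```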
......
......@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
output_range=-1,
embedding_width=16,
dict_outputs=True,
embedding_layer=None)
embedding_layer=None,
norm_first=False)
network = bert_encoder.BertEncoder(**kwargs)
expected_config = dict(kwargs)
expected_config["activation"] = tf.keras.activations.serialize(
......
......@@ -15,7 +15,7 @@
"""Sampling module for top_k, top_p and greedy decoding."""
import abc
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
import numpy as np
import tensorflow as tf
......@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
], -1)
# Scatter sorted indices to original indexes.
indices_to_remove = scatter_values_on_batch_indices(
sorted_indices_to_remove, sorted_indices)
top_p_logits = set_tensor_by_indices_to_value(
logits, indices_to_remove, np.NINF)
indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
sorted_indices)
top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
np.NINF)
return top_p_logits
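A toy sketch of the nucleus filtering above: with top_p=0.6 and probabilities [0.7, 0.2, 0.1], the first token alone covers at least 60% of the mass, so only it survives:

```python
import tensorflow as tf

logits = tf.math.log([[0.7, 0.2, 0.1]])
filtered = sample_top_p(logits, top_p=tf.constant(0.6))
# filtered keeps logits[0][0]; the other entries are mapped to np.NINF.
```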
......@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
tensor_shape = decoding_module.shape_list(batch_indices)
broad_casted_batch_dims = tf.reshape(
tf.broadcast_to(
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1),
tensor_shape), [1, -1])
tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
[1, -1])
pair_indices = tf.transpose(
tf.concat([broad_casted_batch_dims,
tf.reshape(batch_indices, [1, -1])], 0))
return tf.scatter_nd(pair_indices,
tf.reshape(values, [-1]), tensor_shape)
return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
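An illustration with toy values; each row's values are scattered back to the column positions named by batch_indices:

```python
import tensorflow as tf

values = tf.constant([[10., 20.], [30., 40.]])
batch_indices = tf.constant([[1, 0], [0, 1]])
print(scatter_values_on_batch_indices(values, batch_indices))
# -> [[20., 10.], [30., 40.]]  (row 0 is permuted, row 1 is unchanged)
```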
def set_tensor_by_indices_to_value(input_tensor, indices, value):
......@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
Returns:
output_tensor: same shape as input_tensor.
"""
......@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def __init__(self,
symbols_to_logits_fn,
length_normalization_fn: Callable[[int, tf.DType], float],
vocab_size: int,
max_decode_length: int,
eos_id: int,
padded_decode: bool,
length_normalization_fn: Optional[Callable[[int, tf.DType],
float]] = None,
top_k=0,
top_p=1.0,
sample_temperature=0.0,
......@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self.max_decode_length = max_decode_length
self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(sample_temperature,
dtype=tf.float32)
self.sample_temperature = tf.convert_to_tensor(
sample_temperature, dtype=tf.float32)
self.enable_greedy = enable_greedy
super(SamplingModule, self).__init__(
length_normalization_fn=length_normalization_fn, dtype=dtype)
......@@ -330,12 +331,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
return state, state_shape_invariants
def _get_new_alive_state(
self,
new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
"""Gather the sequences that are still alive.
This function resets the sequences in the alive_state that are finished.
......@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module.StateKeys.ALIVE_CACHE: new_cache
}
def _get_new_finished_state(self,
state: Dict[str, Any],
new_seq: tf.Tensor,
def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
new_log_probs: tf.Tensor,
new_finished_flags: tf.Tensor,
batch_size: int) -> Dict[str, tf.Tensor]:
......@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm = self.length_normalization_fn(self.max_decode_length + 1,
self.dtype)
alive_log_probs = alive_log_probs / length_norm
seq_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank(
finished_cond, finished_scores)
seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
score_cond = decoding_module.expand_to_same_rank(finished_cond,
finished_scores)
finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
return finished_seq, finished_scores
......
......@@ -66,4 +66,5 @@ def main(_):
if __name__ == '__main__':
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
app.run(main)
......@@ -29,17 +29,16 @@ import timeit
import traceback
import typing
from absl import logging
import numpy as np
import six
from six.moves import queue
import tensorflow as tf
from absl import logging
from tensorflow.python.tpu.datasets import StreamingFilesDataset
from official.recommendation import constants as rconst
from official.recommendation import movielens
from official.recommendation import popen_helper
from official.recommendation import stat_utils
from tensorflow.python.tpu.datasets import StreamingFilesDataset
SUMMARY_TEMPLATE = """General:
{spacer}Num users: {num_users}
......@@ -119,6 +118,7 @@ class DatasetManager(object):
"""Convert NumPy arrays into a TFRecords entry."""
def create_int_feature(values):
values = np.squeeze(values)
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
feature_dict = {
......
......@@ -23,21 +23,19 @@ import os
import pickle
import time
import timeit
# pylint: disable=wrong-import-order
import typing
from typing import Dict, Text, Tuple
from absl import logging
import numpy as np
import pandas as pd
import tensorflow as tf
import typing
from typing import Dict, Text, Tuple
# pylint: enable=wrong-import-order
from official.recommendation import constants as rconst
from official.recommendation import data_pipeline
from official.recommendation import movielens
_EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
rconst.USER_MAP, rconst.ITEM_MAP)
......@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
logging.info("Writing raw data cache.")
with tf.io.gfile.GFile(cache_path, "wb") as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(data, f, protocol=4)
# TODO(robieta): MLPerf cache clear.
return data, valid_cache
......
......@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
export EXPERIMENT_NAME=my_experiment_name
export BUCKET_NAME="gs://my_dlrm_bucket"
export DATA_DIR="${BUCKET_NAME}/data"
export EMBEDDING_DIM=32
python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
......@@ -126,8 +127,8 @@ task:
global_batch_size: 16384
model:
num_dense_features: 13
bottom_mlp: [512,256,128]
embedding_dim: 128
bottom_mlp: [512,256,${EMBEDDING_DIM}]
embedding_dim: ${EMBEDDING_DIM}
top_mlp: [1024,1024,512,256,1]
interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
......@@ -135,8 +136,8 @@ task:
39979771, 25641295, 39664984, 585935, 12972, 108, 36]
trainer:
use_orbit: true
validation_interval: 90000
checkpoint_interval: 100000
validation_interval: 85352
checkpoint_interval: 85352
validation_steps: 5440
train_steps: 256054
steps_per_loop: 1000
......@@ -154,7 +155,9 @@ Training on GPUs is similar to TPU training. Only the distribution strategy
needs to be updated and the number of GPUs provided (for 4 GPUs):
```shell
python3 official/recommendation/ranking/main.py --mode=train_and_eval \
export EMBEDDING_DIM=8
python3 official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime:
distribution_strategy: 'mirrored'
......
......@@ -12,6 +12,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for running models in a distributed setting."""
# pylint: disable=wildcard-import
from official.common.distribute_utils import *
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Ranking Model configuration definition."""
from typing import Optional, List
from typing import Optional, List, Union
import dataclasses
from official.core import exp_factory
......@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
If it's an integer then all tables will have the same embedding dimension.
If it's a list then its length should match the length of `vocab_sizes`.
size_threshold: A threshold for table sizes below which a Keras
embedding layer is used, and above which a TPU embedding layer is used.
If it's -1 then only the Keras embedding layer will be used for all tables,
if it's 0 then only the TPU embedding layer will be used.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
top_mlp: The sizes of hidden layers for top MLP.
......@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
"""
num_dense_features: int = 13
vocab_sizes: List[int] = dataclasses.field(default_factory=list)
embedding_dim: int = 8
embedding_dim: Union[int, List[int]] = 8
size_threshold: int = 50_000
bottom_mlp: List[int] = dataclasses.field(default_factory=list)
top_mlp: List[int] = dataclasses.field(default_factory=list)
interaction: str = 'dot'
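A hypothetical config exercising the new fields (names and sizes invented for illustration); construction via keyword arguments follows the `default_config()` usage shown below:

```python
# Per-table embedding dimensions (list length must match vocab_sizes), plus
# a size threshold below which a table goes to a Keras embedding layer
# instead of the TPU embedding.
model = ModelConfig(
    num_dense_features=13,
    vocab_sizes=[1_000_000, 10_000, 100],
    embedding_dim=[32, 32, 32],
    size_threshold=50_000,
    bottom_mlp=[64, 32, 32],
    top_mlp=[64, 32, 1])
```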
......@@ -188,7 +195,7 @@ def default_config() -> Config:
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
embedding_dim=4,
embedding_dim=8,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 4],
top_mlp=[64, 32, 1]),
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -136,7 +136,7 @@ class CriteoTsvReader:
num_replicas = ctx.num_replicas_in_sync if ctx else 1
if params.is_training:
dataset_size = 10000 * batch_size * num_replicas
dataset_size = 1000 * batch_size * num_replicas
else:
dataset_size = 1000 * batch_size * num_replicas
dense_tensor = tf.random.uniform(
......@@ -169,6 +169,7 @@ class CriteoTsvReader:
'sparse_features': sparse_tensor_elements}, label_tensor
dataset = tf.data.Dataset.from_tensor_slices(input_elem)
dataset = dataset.cache()
if params.is_training:
dataset = dataset.repeat()
......
......@@ -17,8 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
from official.recommendation.ranking.data import data_pipeline
class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -15,7 +15,7 @@
"""Task for the Ranking model."""
import math
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import tensorflow as tf
import tensorflow_recommenders as tfrs
......@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
from official.core import base_task
from official.core import config_definitions
from official.recommendation.ranking import common
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
from official.recommendation.ranking.data import data_pipeline
RuntimeConfig = config_definitions.RuntimeConfig
def _get_tpu_embedding_feature_config(
vocab_sizes: List[int],
embedding_dim: int,
embedding_dim: Union[int, List[int]],
table_name_prefix: str = 'embedding_table'
) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
"""Returns TPU embedding feature config.
The i-th table config will have a vocab size of vocab_sizes[i] and an
embedding dimension of embedding_dim (if embedding_dim is an int) or
embedding_dim[i] (if embedding_dim is a list).
Args:
vocab_sizes: List of sizes of categories/id's in the table.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
table_name_prefix: a prefix for embedding tables.
Returns:
A dictionary of feature_name, FeatureConfig pairs.
"""
if isinstance(embedding_dim, List):
if len(vocab_sizes) != len(embedding_dim):
raise ValueError(
f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
f'length of embedding_dim: {len(embedding_dim)}')
elif isinstance(embedding_dim, int):
embedding_dim = [embedding_dim] * len(vocab_sizes)
else:
raise ValueError('embedding_dim is neither a list nor an int, got '
f'{type(embedding_dim)}')
feature_config = {}
for i, vocab_size in enumerate(vocab_sizes):
table_config = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=vocab_size,
dim=embedding_dim,
dim=embedding_dim[i],
combiner='mean',
initializer=tf.initializers.TruncatedNormal(
mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
name=table_name_prefix + '_%s' % i)
feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
table=table_config)
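A hypothetical call with per-table dimensions, checking the resulting table configs:

```python
feature_config = _get_tpu_embedding_feature_config(
    vocab_sizes=[1000, 10], embedding_dim=[16, 4])
assert feature_config['0'].table.dim == 16
assert feature_config['1'].table.dim == 4
```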
......@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
"""Task initialization.
Args:
params: the RannkingModel task configuration instance.
params: the RankingModel task configuration instance.
optimizer_config: Optimizer configuration instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved.
......@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
self.optimizer_config.embedding_optimizer)
embedding_optimizer.learning_rate = lr_callable
emb_feature_config = _get_tpu_embedding_feature_config(
vocab_sizes=self.task_config.model.vocab_sizes,
embedding_dim=self.task_config.model.embedding_dim)
feature_config = _get_tpu_embedding_feature_config(
embedding_dim=self.task_config.model.embedding_dim,
vocab_sizes=self.task_config.model.vocab_sizes)
tpu_embedding = tfrs.layers.embedding.TPUEmbedding(
emb_feature_config, embedding_optimizer)
embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
feature_config=feature_config,
optimizer=embedding_optimizer,
size_threshold=self.task_config.model.size_threshold)
if self.task_config.model.interaction == 'dot':
feature_interaction = tfrs.layers.feature_interaction.DotInteraction()
feature_interaction = tfrs.layers.feature_interaction.DotInteraction(
skip_gather=True)
elif self.task_config.model.interaction == 'cross':
feature_interaction = tf.keras.Sequential([
tf.keras.layers.Concatenate(),
......@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
f'is not supported it must be either \'dot\' or \'cross\'.')
model = tfrs.experimental.models.Ranking(
embedding_layer=tpu_embedding,
embedding_layer=embedding_layer,
bottom_stack=tfrs.layers.blocks.MLP(
units=self.task_config.model.bottom_mlp, final_activation='relu'),
feature_interaction=feature_interaction,
......@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
@property
def optimizer_config(self) -> config.OptimizationConfig:
return self._optimizer_config