"vscode:/vscode.git/clone" did not exist on "abf75ac039420f7a4ab64a419416dd493b906742"
Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
 
-def _generalized_kernel(x, projection_matrix, is_query, f, h,
-                        data_normalizer_fn=None):
+def _generalized_kernel(x, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
 
   Args:
     x: The feature being transformed with shape [B, T, N, H].
     projection_matrix: The matrix with shape [M, H] that we project x to,
       where M is the number of projections.
-    is_query: Whether the transform is a query or key. This transform is
-      symmetric if the argument is not used.
     f: A non-linear function applied on x or projected x.
     h: A multiplier which is a function of x applied after x is projected and
      transformed. Only applied if projection_matrix is not None.
-    data_normalizer_fn: A function which takes x and returns a scalar that
-      normalizes data.
 
   Returns:
     Transformed feature.
   """
-  # No asymmetric operations.
-  del is_query
-  if data_normalizer_fn is not None:
-    x = data_normalizer_fn(x)
   if projection_matrix is None:
     return h(x) * f(x)
@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
             x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
         h=lambda x: tf.math.exp(
             -0.5 * tf.math.reduce_sum(
-                tf.math.square(x), axis=-1, keepdims=True)),
-        data_normalizer_fn=lambda x: x /
-        (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
+                tf.math.square(x), axis=-1, keepdims=True)),),
     "expmod":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
-            f=lambda x: tf.math.exp(
-                x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
-            h=lambda x: tf.math.exp(
-                -0.5 * tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))),
-            data_normalizer_fn=lambda x: x /
-            (tf.math.sqrt(tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))))),
-    "l2":
-        functools.partial(
-            _generalized_kernel,
-            f=lambda x: x,
-            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
-            data_normalizer_fn=lambda x: x),
-    "identity": lambda x, projection_matrix, is_query: x
+            f=lambda x: tf.math.exp(x - tf.math.reduce_max(
+                x, axis=[1, 2, 3], keepdims=True)),
+            h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
+                tf.cast(tf.shape(x)[-1], tf.float32))),
+        ),
+    "identity":
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
   Rethinking Attention with Performers
   (https://arxiv.org/abs/2009.14794)
-  - exp (Lemma 1, positive), relu, l2
+  - exp (Lemma 1, positive), relu
   - random/deterministic projection
 
   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                redraw=False,
                is_short_seq=False,
                begin_kernel=0,
+               scale=None,
                **kwargs):
     r"""Constructor of KernelAttention.
 
     Args:
       feature_transform: A non-linear transform of the keys and queries.
         Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "l2", "identity". If <is_short_seq> = True, it is recommended to choose
-        feature_transform as "l2".
+        "identity".
       num_random_features: Number of random features to be used for projection.
         If num_random_features <= 0, no projection is used before transform.
       seed: The seed to begin drawing random features. Once the seed is set, the
@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         (default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
+      scale: The value to scale the dot product as described in `Attention Is
+        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
      **kwargs: The same arguments as the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     #   1. inference
     #   2. no redraw
     self._seed = seed
     super().__init__(**kwargs)
+    if scale is None:
+      self._scale = 1.0 / math.sqrt(float(self._key_dim))
+    else:
+      self._scale = scale
     self._projection_matrix = None
     if num_random_features > 0:
       self._projection_matrix = create_projection_matrix(
@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
-
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       else:
         projection_matrix = self._projection_matrix
 
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
+      # Note: we suspect splitting the scale to key, query yields smaller
+      # approximation variance when random projection is used.
+      # For simplicity, we also split when there's no random projection.
+      key *= math.sqrt(self._scale)
+      query *= math.sqrt(self._scale)
+    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
+    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
 
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
-      return attention_output
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key, value)
       denominator = 1.0 / (
          tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
          _NUMERIC_STABLER)
-      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+      attention_output = tf.einsum(
+          "BTNH,BNDH,BTN->BTND", query, kv, denominator)
+    return attention_output
   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         "redraw": self._redraw,
         "is_short_seq": self._is_short_seq,
         "begin_kernel": self._begin_kernel,
+        "scale": self._scale,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
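For intuition about the scale handling above: the dot-product scale 1/sqrt(dk) is now split multiplicatively between query and key before the feature transform. Below is a minimal self-contained sketch (toy shapes and a plain exp feature map, not the Model Garden layer itself) of the linearized attention computed by the long-sequence branch:

```python
import math
import tensorflow as tf

def linear_attention(query, key, value, scale=None):
  """Sketch: softmax(QK^T)V approximated as f(Q)(f(K)^T V), linear in T."""
  dk = query.shape[-1]
  scale = scale or 1.0 / math.sqrt(float(dk))
  # Split the scale onto both sides, as the commit does for the kernel path.
  query = query * math.sqrt(scale)
  key = key * math.sqrt(scale)
  # Positive feature map (plain exp, shifted for numerical stability).
  query = tf.math.exp(query - tf.reduce_max(query, axis=[1, 2, 3], keepdims=True))
  key = tf.math.exp(key - tf.reduce_max(key, axis=[1, 2, 3], keepdims=True))
  kv = tf.einsum("BSNH,BSND->BNDH", key, value)
  denominator = 1.0 / (
      tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) + 1e-6)
  return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)

out = linear_attention(
    tf.random.normal([2, 16, 4, 8]),   # query [B, T, N, H]
    tf.random.normal([2, 16, 4, 8]),   # key   [B, S, N, H]
    tf.random.normal([2, 16, 4, 8]))   # value [B, S, N, D]
print(out.shape)  # (2, 16, 4, 8)
```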
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention
 
-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
+_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
......
@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
         "'tensorflow-text-nightly'.")
 
-def _iterative_vectorized_fair_share(capacity: tf.Tensor,
-                                     limit: Union[int, tf.Tensor]):
-  """Iterative algorithm for the max-min fairness assignment.
-
-  Reference: https://en.wikipedia.org/wiki/Max-min_fairness
-  The idea is, for each example with some number of segments and a limit on
-  the total segment length allowed, we grant each segment a fair share of the
-  limit. For example, if every segment has the same length, there is no work
-  to do. If one segment has below-average length, its surplus share will be
-  split among the others fairly. In this way, the longest segment will be the
-  shortest among all potential capacity assignments.
-
-  Args:
-    capacity: A rank-2 Tensor of #Segments x Batch.
-    limit: The largest permissible number of tokens in total across one
-      example.
-
-  Returns:
-    A rank-2 Tensor with a new segment capacity assignment such that
-    the total number of tokens in each example does not exceed the `limit`.
-  """
-  # First, we calculate the lower bound of the capacity assignment.
-  per_seg_limit = limit // capacity.shape[0]
-  limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
-  lower_bound = tf.minimum(capacity, limit_mask)
-
-  # This step makes up the capacity that already satisfies the capacity limit.
-  remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
-  remaining_cap_mat = capacity - lower_bound
-  new_cap = lower_bound + remaining_cap_mat * tf.cast(
-      tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
-      tf.int64)
-
-  # Process iteratively. This step is O(#segments), see analysis below.
-  while True:
-    remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
-    remaining_cap = capacity - new_cap
-    masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
-    remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
-    masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
-                                     tf.int64) * remaining_limit
-    # The total remaining segment limit is different for each example.
-    per_seg_limit = masked_remaining_limit // (
-        tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
-        remaining_cap_col_slots)  # +1 to make sure 0/0 = 0
-    # Note that at each step, at least one more segment is fulfilled, or the
-    # loop finishes. The idea is: if the remaining per-example limit exceeds
-    # the smallest ask among segments, the smallest segment's ask is
-    # fulfilled. Otherwise, all remaining segments are truncated and the
-    # assignment is finished.
-    if tf.math.reduce_sum(per_seg_limit) > 0:
-      remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
-      new_cap = new_cap + remaining_slots_mat * per_seg_limit
-    else:
-      # Leftover assignment of a limit that is smaller than #slots.
-      new_remained_assignment_mask = tf.cast(
-          (tf.cumsum(masked_remaining_slots, axis=0) <= masked_remaining_limit)
-          & (masked_remaining_slots > 0), tf.int64)
-      new_cap = new_cap + new_remained_assignment_mask
-      break
-  return new_cap
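The removed helper's contract is easier to see on a tiny example. Here is a toy plain-Python reference of max-min fair sharing (hypothetical, single-example batch, not the vectorized TF version above): three segments of lengths 1, 4, and 7 under a total limit of 9 receive capacities [1, 4, 4], so the longest segment ends up as short as possible.

```python
def max_min_fair_share(lengths, limit):
  """Toy reference of max-min fair capacity assignment (not the TF version)."""
  caps = [0] * len(lengths)
  need = list(lengths)
  while limit > 0 and any(need):
    active = [i for i, n in enumerate(need) if n > 0]
    # Equal share per still-needy segment; a leftover smaller than the
    # number of slots goes to the earlier segments, as in the TF code.
    share = max(limit // len(active), 1)
    for i in active:
      grant = min(share, need[i], limit)
      caps[i] += grant
      need[i] -= grant
      limit -= grant
      if limit == 0:
        break
  return caps

assert max_min_fair_share([1, 4, 7], 9) == [1, 4, 4]
assert max_min_fair_share([5, 5], 3) == [2, 1]
```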
-def round_robin_truncate_inputs(
-    inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
-    limit: Union[int, tf.Tensor],
-) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
-  """Truncates a list of batched segments to fit a per-example length limit.
-
-  Available space is assigned one token at a time in a round-robin fashion
-  to the inputs that still need some, until the limit is reached.
-  (Or equivalently: the longest input is truncated by one token until the
-  total length of inputs fits the limit.) Examples that fit the limit as
-  passed in remain unchanged.
-
-  Args:
-    inputs: A list of rank-2 RaggedTensors. The i-th example is given by
-      the i-th row in each list element, that is, `inputs[:][i, :]`.
-    limit: The largest permissible number of tokens in total across one
-      example.
-
-  Returns:
-    A list of rank-2 RaggedTensors at corresponding indices with the inputs,
-    in which the rows of each RaggedTensor have been truncated such that
-    the total number of tokens in each example does not exceed the `limit`.
-  """
-  if not isinstance(inputs, (list, tuple)):
-    return round_robin_truncate_inputs([inputs], limit)[0]
-  limit = tf.cast(limit, tf.int64)
-  if not all(rt.shape.rank == 2 for rt in inputs):
-    raise ValueError("All inputs must have shape [batch_size, (items)]")
-  if len(inputs) == 1:
-    return [_truncate_row_lengths(inputs[0], limit)]
-  elif len(inputs) == 2:
-    size_a, size_b = [rt.row_lengths() for rt in inputs]
-    # Here's a brain-twister: This does round-robin assignment of quota
-    # to both inputs until the limit is reached. Hint: consider separately
-    # the cases of zero, one, or two inputs exceeding half the limit.
-    floor_half = limit // 2
-    ceil_half = limit - floor_half
-    quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
-    quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
-    return [_truncate_row_lengths(inputs[0], quota_a),
-            _truncate_row_lengths(inputs[1], quota_b)]
-  else:
-    # Note that we don't merge with the 2-input case because the full
-    # algorithm is more expensive.
-    capacity = tf.stack([rt.row_lengths() for rt in inputs])  # #Segments x B
-    new_capacity = _iterative_vectorized_fair_share(capacity, limit)
-    return [
-        _truncate_row_lengths(inputs[i], new_capacity[i])
-        for i in range(capacity.shape[0])
-    ]
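The two-segment quota formula above repays a worked check. A sketch with hypothetical sizes (limit 5, so floor_half=2 and ceil_half=3):

```python
import tensorflow as tf

def two_segment_quota(size_a, size_b, limit):
  # Mirrors the (removed) two-input branch above.
  floor_half = limit // 2
  ceil_half = limit - floor_half
  quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
  quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
  return quota_a, quota_b

# sizes (4, 4) -> quotas (3, 2): both exceed half the limit, budget split 3/2.
# sizes (1, 9) -> quotas (1, 4): a's unused half of the budget flows to b.
print(two_segment_quota(tf.constant(4), tf.constant(4), tf.constant(5)))
print(two_segment_quota(tf.constant(1), tf.constant(9), tf.constant(5)))
```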
 def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
                           new_lengths: tf.Tensor) -> tf.RaggedTensor:
   """Truncates the rows of `ragged_tensor` to the given row lengths."""
@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
       # fall back to some ad-hoc truncation.
       num_special_tokens = len(inputs) + 1
       if truncator == "round_robin":
-        trimmed_segments = round_robin_truncate_inputs(
-            inputs, seq_length - num_special_tokens)
+        trimmed_segments = text.RoundRobinTrimmer(seq_length -
+                                                  num_special_tokens).trim(inputs)
       elif truncator == "waterfall":
         trimmed_segments = text.WaterfallTrimmer(
             seq_length - num_special_tokens).trim(inputs)
......
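The replacement used above, `text.RoundRobinTrimmer` from tensorflow_text, implements the same budgeted round-robin policy. A hedged usage sketch (toy values; the output comments are the expected round-robin result, assuming the trimmer distributes a total per-example budget):

```python
import tensorflow as tf
import tensorflow_text as text

seg_a = tf.ragged.constant([[11, 12, 13, 14], [21, 22]])
seg_b = tf.ragged.constant([[31, 32, 33, 34], [41]])
trimmed_a, trimmed_b = text.RoundRobinTrimmer(max_seq_length=5).trim(
    [seg_a, seg_b])
print(trimmed_a)  # expected [[11, 12, 13], [21, 22]]
print(trimmed_b)  # expected [[31, 32], [41]]
```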
@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
 from official.nlp.modeling.layers import text_layers
 
-class RoundRobinTruncatorTest(tf.test.TestCase):
-
-  def _test_input(self, start, lengths):
-    return tf.ragged.constant([[start + 10 * j + i
-                                for i in range(length)]
-                               for j, length in enumerate(lengths)],
-                              dtype=tf.int32)
-
-  def test_single_segment(self):
-    # Single segment.
-    single_input = self._test_input(11, [4, 5, 6])
-    expected_single_output = tf.ragged.constant(
-        [[11, 12, 13, 14],
-         [21, 22, 23, 24, 25],
-         [31, 32, 33, 34, 35],  # Truncated.
-        ])
-    self.assertAllEqual(
-        expected_single_output,
-        text_layers.round_robin_truncate_inputs(single_input, limit=5))
-    # Test wrapping in a singleton list.
-    actual_single_list_output = text_layers.round_robin_truncate_inputs(
-        [single_input], limit=5)
-    self.assertIsInstance(actual_single_list_output, list)
-    self.assertAllEqual(expected_single_output, actual_single_list_output[0])
-
-  def test_two_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
-    expected_a = tf.ragged.constant(
-        [[111],
-         [121, 122],
-         [131, 132],
-         [141, 142, 143],
-         [151, 152, 153],  # Truncated.
-         [161, 162, 163],  # Truncated.
-        ])
-    expected_b = tf.ragged.constant(
-        [[211],
-         [221, 222, 223],
-         [231, 232, 233],  # Truncated.
-         [241, 242],
-         [251, 252],
-         [261, 262],  # Truncated.
-        ])
-    actual_a, actual_b = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b], limit=5)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-
-  def test_three_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
-    input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
-    seg_limit = 8
-    expected_a = tf.ragged.constant([
-        [111],
-        [121, 122],
-        [131, 132],
-        [141, 142, 143],
-        [151, 152, 153, 154],
-        [161, 162, 163],  # Truncated.
-        [171]
-    ])
-    expected_b = tf.ragged.constant([
-        [211],
-        [221, 222, 223],
-        [231, 232, 233],  # Truncated.
-        [241, 242],
-        [251, 252],
-        [261, 262, 263],  # Truncated.
-        [271, 272, 273, 274]  # Truncated.
-    ])
-    expected_c = tf.ragged.constant([
-        [311],
-        [321, 322, 323],
-        [331, 332, 333],  # Truncated.
-        [341, 342],
-        [351, 352],
-        [361, 362],  # Truncated.
-        [371, 372, 373]  # Truncated.
-    ])
-    actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b, input_c], limit=seg_limit)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-    self.assertAllEqual(expected_c, actual_c)
-    input_cap = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
-        axis=0)
-    per_example_usage = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
-        axis=0)
-    self.assertTrue(all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
-
 # This test covers the in-process behavior of a BertTokenizer layer.
 # For saving, restoring, and the restored behavior (incl. shape inference),
 # see nlp/tools/export_tfhub_lib_test.py.
......
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
     dropout_rate: The dropout probability of the cls head.
     use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
       encoder.
+    head_name: Name of the classification head.
     cls_head: (Optional) The layer instance to use for the classifier head.
       It should take in the output from network and produce the final logits.
       If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
-      'use_encoder_pooler') will be ignored.
+      'use_encoder_pooler', 'head_name') will be ignored.
   """
 
   def __init__(self,
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                dropout_rate=0.1,
                use_encoder_pooler=True,
+               head_name='sentence_prediction',
               cls_head=None,
                **kwargs):
     self.num_classes = num_classes
+    self.head_name = head_name
     self.initializer = initializer
     self.use_encoder_pooler = use_encoder_pooler
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
         num_classes=num_classes,
         initializer=initializer,
         dropout_rate=dropout_rate,
-        name='sentence_prediction')
+        name=head_name)
     predictions = classifier(cls_inputs)
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
     return {
         'network': self._network,
         'num_classes': self.num_classes,
+        'head_name': self.head_name,
         'initializer': self.initializer,
         'use_encoder_pooler': self.use_encoder_pooler,
         'cls_head': self._cls_head,
......
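The point of `head_name` is that two classifiers built on one encoder no longer collide on the hard-coded 'sentence_prediction' variable scope. A hedged sketch (hypothetical sizes and head names):

```python
from official.nlp.modeling import models, networks

encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
topic_model = models.BertClassifier(
    network=encoder, num_classes=3, head_name='topic_prediction')
sentiment_model = models.BertClassifier(
    network=encoder, num_classes=2, head_name='sentiment_prediction')
```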
@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
 
   def _embedding_linear(self, embedding_matrix, x):
     """Uses embeddings as linear transformation weights."""
+    embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
+    x = tf.cast(x, dtype=self.compute_dtype)
     batch_size = tf.shape(x)[0]
     length = tf.shape(x)[1]
     hidden_size = tf.shape(x)[2]
     vocab_size = tf.shape(embedding_matrix)[0]
 
     x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
+    logits = tf.matmul(x, embedding_matrix, transpose_b=True)
 
     return tf.reshape(logits, [batch_size, length, vocab_size])
......
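`_embedding_linear` is the usual weight-tying trick: the [vocab, hidden] embedding table doubles as the output projection, and the new casts simply make both operands agree with the layer's compute dtype. A minimal sketch with toy sizes:

```python
import tensorflow as tf

vocab_size, hidden_size = 32, 8
embedding = tf.random.normal([vocab_size, hidden_size])
decoder_out = tf.random.normal([2, 5, hidden_size])  # [batch, length, hidden]
flat = tf.reshape(decoder_out, [-1, hidden_size])
logits = tf.reshape(
    tf.matmul(flat, embedding, transpose_b=True), [2, 5, vocab_size])
print(logits.shape)  # (2, 5, 32)
```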
@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
       Defaults to a RandomNormal initializer.
     summary_type: Method used to summarize a sequence into a compact vector.
     dropout_rate: The dropout probability of the cls head.
+    head_name: Name of the classification head.
   """
 
   def __init__(
@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
       initializer: tf.keras.initializers.Initializer = 'random_normal',
       summary_type: str = 'last',
       dropout_rate: float = 0.1,
+      head_name: str = 'sentence_prediction',
       **kwargs):
     super().__init__(**kwargs)
     self._network = network
@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
         'num_classes': num_classes,
         'summary_type': summary_type,
         'dropout_rate': dropout_rate,
+        'head_name': head_name,
     }
 
     if summary_type == 'last':
@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
         initializer=initializer,
         dropout_rate=dropout_rate,
         cls_token_idx=cls_token_idx,
-        name='sentence_prediction')
+        name=head_name)
 
   def call(self, inputs: Mapping[str, Any]):
     input_ids = inputs['input_word_ids']
......
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       parameter is originally added for ELECTRA model which needs to tie the
       generator embeddings with the discriminator embeddings.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    norm_first: Whether to normalize inputs to the attention and intermediate
+      dense layers. If set to False, the output of the attention and
+      intermediate dense layers is normalized instead.
   """
 
   def __init__(self,
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
                embedding_width=None,
                embedding_layer=None,
                dict_outputs=False,
+               norm_first=False,
                **kwargs):
     # b/164516224
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
         initializer=initializer,
         output_range=output_range,
         embedding_width=embedding_width,
-        embedding_layer=embedding_layer)
+        embedding_layer=embedding_layer,
+        norm_first=norm_first)
     self._embedding_layer_instance = embedding_layer
......
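For readers unfamiliar with the convention, `norm_first` toggles pre-norm versus post-norm residual blocks. A minimal sketch (hypothetical `sublayer`; the flag is presumably threaded through to the encoder's transformer blocks):

```python
import tensorflow as tf

norm = tf.keras.layers.LayerNormalization()

def post_norm_block(x, sublayer):  # norm_first=False: original BERT ordering
  return norm(x + sublayer(x))

def pre_norm_block(x, sublayer):   # norm_first=True: normalize the inputs
  return x + sublayer(norm(x))
```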
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         output_range=-1,
         embedding_width=16,
         dict_outputs=True,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
......
@@ -15,7 +15,7 @@
 """Sampling module for top_k, top_p and greedy decoding."""
 
 import abc
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Optional
 import numpy as np
 import tensorflow as tf
@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
   ], -1)
 
   # Scatter sorted indices to original indexes.
-  indices_to_remove = scatter_values_on_batch_indices(
-      sorted_indices_to_remove, sorted_indices)
-  top_p_logits = set_tensor_by_indices_to_value(
-      logits, indices_to_remove, np.NINF)
+  indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
+                                                      sorted_indices)
+  top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
+                                                np.NINF)
   return top_p_logits
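For intuition about what `sample_top_p` keeps, a small self-contained sketch (toy logits, illustrative threshold): sort, take the exclusive cumulative probability, and keep the smallest prefix that covers top_p.

```python
import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])
probs = tf.nn.softmax(tf.sort(logits, direction="DESCENDING"), axis=-1)
cumulative = tf.cumsum(probs, axis=-1, exclusive=True)
keep = cumulative < 0.9  # tokens inside the top-p nucleus, in sorted order
print(keep)  # [[ True  True  True False]]
```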
@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
   tensor_shape = decoding_module.shape_list(batch_indices)
   broad_casted_batch_dims = tf.reshape(
       tf.broadcast_to(
-          tf.expand_dims(tf.range(tensor_shape[0]), axis=-1),
-          tensor_shape), [1, -1])
+          tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
+      [1, -1])
   pair_indices = tf.transpose(
       tf.concat([broad_casted_batch_dims,
                  tf.reshape(batch_indices, [1, -1])], 0))
-  return tf.scatter_nd(pair_indices,
-                       tf.reshape(values, [-1]), tensor_shape)
+  return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
 
 def set_tensor_by_indices_to_value(input_tensor, indices, value):
@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
     input_tensor: float (batch_size, dim)
     indices: bool (batch_size, dim)
     value: float scalar
+
   Returns:
     output_tensor: same shape as input_tensor.
   """
@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
   def __init__(self,
                symbols_to_logits_fn,
-               length_normalization_fn: Callable[[int, tf.DType], float],
                vocab_size: int,
                max_decode_length: int,
                eos_id: int,
                padded_decode: bool,
+               length_normalization_fn: Optional[Callable[[int, tf.DType],
+                                                          float]] = None,
               top_k=0,
               top_p=1.0,
               sample_temperature=0.0,
@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
     self.max_decode_length = max_decode_length
     self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
     self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
-    self.sample_temperature = tf.convert_to_tensor(sample_temperature,
-                                                   dtype=tf.float32)
+    self.sample_temperature = tf.convert_to_tensor(
+        sample_temperature, dtype=tf.float32)
     self.enable_greedy = enable_greedy
     super(SamplingModule, self).__init__(
         length_normalization_fn=length_normalization_fn, dtype=dtype)
@@ -330,12 +331,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
     return state, state_shape_invariants
 
-  def _get_new_alive_state(
-      self,
-      new_seq: tf.Tensor,
-      new_log_probs: tf.Tensor,
-      new_finished_flags: tf.Tensor,
-      new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
+  def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
+                           new_finished_flags: tf.Tensor,
+                           new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
     """Gather the sequences that are still alive.
 
     This function resets the sequences in the alive_state that are finished.
@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
         decoding_module.StateKeys.ALIVE_CACHE: new_cache
     }
 
-  def _get_new_finished_state(self,
-                              state: Dict[str, Any],
-                              new_seq: tf.Tensor,
+  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                               new_log_probs: tf.Tensor,
                               new_finished_flags: tf.Tensor,
                               batch_size: int) -> Dict[str, tf.Tensor]:
@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
       length_norm = self.length_normalization_fn(self.max_decode_length + 1,
                                                  self.dtype)
       alive_log_probs = alive_log_probs / length_norm
-    seq_cond = decoding_module.expand_to_same_rank(
-        finished_cond, finished_seq)
-    score_cond = decoding_module.expand_to_same_rank(
-        finished_cond, finished_scores)
+    seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
+    score_cond = decoding_module.expand_to_same_rank(finished_cond,
+                                                     finished_scores)
     finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
     finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores
......
@@ -66,4 +66,5 @@ def main(_):
 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
@@ -29,17 +29,16 @@ import timeit
 import traceback
 import typing
 
+from absl import logging
 import numpy as np
-import six
 from six.moves import queue
 import tensorflow as tf
-from absl import logging
+from tensorflow.python.tpu.datasets import StreamingFilesDataset
 
 from official.recommendation import constants as rconst
 from official.recommendation import movielens
 from official.recommendation import popen_helper
 from official.recommendation import stat_utils
-from tensorflow.python.tpu.datasets import StreamingFilesDataset
 
 SUMMARY_TEMPLATE = """General:
 {spacer}Num users: {num_users}
@@ -119,6 +118,7 @@ class DatasetManager(object):
     """Convert NumPy arrays into a TFRecords entry."""
 
     def create_int_feature(values):
+      values = np.squeeze(values)
       return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
 
     feature_dict = {
......
@@ -23,21 +23,19 @@ import os
 import pickle
 import time
 import timeit
+import typing
+from typing import Dict, Text, Tuple
 
-# pylint: disable=wrong-import-order
 from absl import logging
 import numpy as np
 import pandas as pd
 import tensorflow as tf
-import typing
-from typing import Dict, Text, Tuple
-# pylint: enable=wrong-import-order
 
 from official.recommendation import constants as rconst
 from official.recommendation import data_pipeline
 from official.recommendation import movielens
 
 _EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
                         rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
                         rconst.USER_MAP, rconst.ITEM_MAP)
@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
     logging.info("Writing raw data cache.")
     with tf.io.gfile.GFile(cache_path, "wb") as f:
-      pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+      pickle.dump(data, f, protocol=4)
 
   # TODO(robieta): MLPerf cache clear.
   return data, valid_cache
......
@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
 export EXPERIMENT_NAME=my_experiment_name
 export BUCKET_NAME="gs://my_dlrm_bucket"
 export DATA_DIR="${BUCKET_NAME}/data"
+export EMBEDDING_DIM=32
 
 python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
 --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
@@ -126,8 +127,8 @@ task:
     global_batch_size: 16384
   model:
     num_dense_features: 13
-    bottom_mlp: [512,256,128]
-    embedding_dim: 128
+    bottom_mlp: [512,256,${EMBEDDING_DIM}]
+    embedding_dim: ${EMBEDDING_DIM}
     top_mlp: [1024,1024,512,256,1]
     interaction: 'dot'
     vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
@@ -135,8 +136,8 @@ task:
                   39979771, 25641295, 39664984, 585935, 12972, 108, 36]
 trainer:
   use_orbit: true
-  validation_interval: 90000
-  checkpoint_interval: 100000
+  validation_interval: 85352
+  checkpoint_interval: 85352
   validation_steps: 5440
   train_steps: 256054
   steps_per_loop: 1000
@@ -154,7 +155,9 @@ Training on GPUs is similar to TPU training. Only the distribution strategy
 needs to be updated and the number of GPUs provided (for 4 GPUs):
 
 ```shell
-python3 official/recommendation/ranking/main.py --mode=train_and_eval \
+export EMBEDDING_DIM=8
+
+python3 official/recommendation/ranking/train.py --mode=train_and_eval \
 --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
 runtime:
   distribution_strategy: 'mirrored'
......
@@ -12,6 +12,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Helper functions for running models in a distributed setting."""
-# pylint: disable=wildcard-import
-from official.common.distribute_utils import *
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 """Ranking Model configuration definition."""
-from typing import Optional, List
+from typing import Optional, List, Union
 import dataclasses
 from official.core import exp_factory
@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
     num_dense_features: Number of dense features.
     vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
       with the order of the input data.
-    embedding_dim: Embedding dimension.
+    embedding_dim: An integer or a list of embedding table dimensions.
+      If it's an integer then all tables will have the same embedding
+      dimension. If it's a list then its length should match the length of
+      `vocab_sizes`.
+    size_threshold: A threshold for table sizes below which a Keras embedding
+      layer is used, and above which a TPU embedding layer is used. If it's -1
+      then only the Keras embedding layer will be used for all tables; if it's
+      0 then only the TPU embedding layer will be used.
     bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
       features.
     top_mlp: The sizes of hidden layers for top MLP.
@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
   """
   num_dense_features: int = 13
   vocab_sizes: List[int] = dataclasses.field(default_factory=list)
-  embedding_dim: int = 8
+  embedding_dim: Union[int, List[int]] = 8
+  size_threshold: int = 50_000
   bottom_mlp: List[int] = dataclasses.field(default_factory=list)
   top_mlp: List[int] = dataclasses.field(default_factory=list)
   interaction: str = 'dot'
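A hedged example of the new fields (hypothetical sizes): with `size_threshold=50_000`, the first table below would land on the TPU embedding and the two small ones on plain Keras embeddings.

```python
from official.recommendation.ranking.configs import config

model = config.ModelConfig(
    vocab_sizes=[1_000_000, 10_000, 500],
    embedding_dim=[16, 16, 16],  # per-table dims; a plain int would broadcast
    size_threshold=50_000)
```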
@@ -188,7 +195,7 @@ def default_config() -> Config:
       runtime=cfg.RuntimeConfig(),
       task=Task(
           model=ModelConfig(
-              embedding_dim=4,
+              embedding_dim=8,
               vocab_sizes=vocab_sizes,
               bottom_mlp=[64, 32, 4],
               top_mlp=[64, 32, 1]),
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -136,7 +136,7 @@ class CriteoTsvReader:
     num_replicas = ctx.num_replicas_in_sync if ctx else 1
 
     if params.is_training:
-      dataset_size = 10000 * batch_size * num_replicas
+      dataset_size = 1000 * batch_size * num_replicas
     else:
       dataset_size = 1000 * batch_size * num_replicas
     dense_tensor = tf.random.uniform(
@@ -169,6 +169,7 @@ class CriteoTsvReader:
           'sparse_features': sparse_tensor_elements}, label_tensor
 
     dataset = tf.data.Dataset.from_tensor_slices(input_elem)
+    dataset = dataset.cache()
     if params.is_training:
       dataset = dataset.repeat()
......
@@ -17,8 +17,8 @@
 from absl.testing import parameterized
 import tensorflow as tf
 
-from official.recommendation.ranking import data_pipeline
 from official.recommendation.ranking.configs import config
+from official.recommendation.ranking.data import data_pipeline
 
 class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
......
@@ -15,7 +15,7 @@
 """Task for the Ranking model."""
 
 import math
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 import tensorflow as tf
 import tensorflow_recommenders as tfrs
@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
 from official.core import base_task
 from official.core import config_definitions
 from official.recommendation.ranking import common
-from official.recommendation.ranking import data_pipeline
 from official.recommendation.ranking.configs import config
+from official.recommendation.ranking.data import data_pipeline
 
 RuntimeConfig = config_definitions.RuntimeConfig
 
 def _get_tpu_embedding_feature_config(
     vocab_sizes: List[int],
-    embedding_dim: int,
+    embedding_dim: Union[int, List[int]],
     table_name_prefix: str = 'embedding_table'
 ) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
   """Returns TPU embedding feature config.
 
+  The i'th table config will have a vocab size of vocab_sizes[i] and an
+  embedding dimension of embedding_dim if embedding_dim is an int, or
+  embedding_dim[i] if embedding_dim is a list.
+
   Args:
     vocab_sizes: List of sizes of categories/id's in the table.
-    embedding_dim: Embedding dimension.
+    embedding_dim: An integer or a list of embedding table dimensions.
     table_name_prefix: a prefix for embedding tables.
 
   Returns:
     A dictionary of feature_name, FeatureConfig pairs.
   """
+  if isinstance(embedding_dim, List):
+    if len(vocab_sizes) != len(embedding_dim):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          f'length of embedding_dim: {len(embedding_dim)}')
+  elif isinstance(embedding_dim, int):
+    embedding_dim = [embedding_dim] * len(vocab_sizes)
+  else:
+    raise ValueError('embedding_dim is not either a list or an int, got '
+                     f'{type(embedding_dim)}')
+
   feature_config = {}
   for i, vocab_size in enumerate(vocab_sizes):
     table_config = tf.tpu.experimental.embedding.TableConfig(
         vocabulary_size=vocab_size,
-        dim=embedding_dim,
+        dim=embedding_dim[i],
         combiner='mean',
         initializer=tf.initializers.TruncatedNormal(
-            mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
+            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
         name=table_name_prefix + '_%s' % i)
     feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
         table=table_config)
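A hedged call sketch (hypothetical sizes) showing the per-table dimensions in use:

```python
feature_config = _get_tpu_embedding_feature_config(
    vocab_sizes=[1_000_000, 10_000, 500],
    embedding_dim=[64, 16, 8])
print(feature_config['0'].table.dim)  # 64
```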
@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
     """Task initialization.
 
     Args:
-      params: the RannkingModel task configuration instance.
+      params: the RankingModel task configuration instance.
       optimizer_config: Optimizer configuration instance.
       logging_dir: a string pointing to where the model, summaries etc. will be
         saved.
@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
           self.optimizer_config.embedding_optimizer)
       embedding_optimizer.learning_rate = lr_callable
 
-    emb_feature_config = _get_tpu_embedding_feature_config(
-        vocab_sizes=self.task_config.model.vocab_sizes,
-        embedding_dim=self.task_config.model.embedding_dim)
+    feature_config = _get_tpu_embedding_feature_config(
+        embedding_dim=self.task_config.model.embedding_dim,
+        vocab_sizes=self.task_config.model.vocab_sizes)
 
-    tpu_embedding = tfrs.layers.embedding.TPUEmbedding(
-        emb_feature_config, embedding_optimizer)
+    embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
+        feature_config=feature_config,
+        optimizer=embedding_optimizer,
+        size_threshold=self.task_config.model.size_threshold)
 
     if self.task_config.model.interaction == 'dot':
-      feature_interaction = tfrs.layers.feature_interaction.DotInteraction()
+      feature_interaction = tfrs.layers.feature_interaction.DotInteraction(
+          skip_gather=True)
     elif self.task_config.model.interaction == 'cross':
       feature_interaction = tf.keras.Sequential([
           tf.keras.layers.Concatenate(),
@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
           f'is not supported; it must be either \'dot\' or \'cross\'.')
 
     model = tfrs.experimental.models.Ranking(
-        embedding_layer=tpu_embedding,
+        embedding_layer=embedding_layer,
         bottom_stack=tfrs.layers.blocks.MLP(
             units=self.task_config.model.bottom_mlp, final_activation='relu'),
         feature_interaction=feature_interaction,
@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
   @property
   def optimizer_config(self) -> config.OptimizationConfig:
     return self._optimizer_config