Commit 09a38af1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 385256430
parent 078eaaf3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Ranking Model configuration definition."""
from typing import Optional, List
from typing import Optional, List, Union
import dataclasses
from official.core import exp_factory
......@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
If it's an integer then all tables will have the same embedding dimension.
If it's a list then its length must match the length of `vocab_sizes`.
size_threshold: A threshold for table sizes below which a Keras
embedding layer is used, and above which a TPU embedding layer is used.
If it's -1 then only the Keras embedding layer will be used for all tables;
if it's 0 then only the TPU embedding layer will be used.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
top_mlp: The sizes of hidden layers for top MLP.
......@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
"""
num_dense_features: int = 13
vocab_sizes: List[int] = dataclasses.field(default_factory=list)
embedding_dim: int = 8
embedding_dim: Union[int, List[int]] = 8
size_threshold: int = 50_000
bottom_mlp: List[int] = dataclasses.field(default_factory=list)
top_mlp: List[int] = dataclasses.field(default_factory=list)
interaction: str = 'dot'
......@@ -188,7 +195,7 @@ def default_config() -> Config:
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
embedding_dim=4,
embedding_dim=8,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 4],
top_mlp=[64, 32, 1]),
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -136,7 +136,7 @@ class CriteoTsvReader:
num_replicas = ctx.num_replicas_in_sync if ctx else 1
if params.is_training:
dataset_size = 10000 * batch_size * num_replicas
dataset_size = 1000 * batch_size * num_replicas
else:
dataset_size = 1000 * batch_size * num_replicas
dense_tensor = tf.random.uniform(
......@@ -169,6 +169,7 @@ class CriteoTsvReader:
'sparse_features': sparse_tensor_elements}, label_tensor
dataset = tf.data.Dataset.from_tensor_slices(input_elem)
dataset = dataset.cache()
if params.is_training:
dataset = dataset.repeat()
......
......@@ -17,8 +17,8 @@
from absl.testing import parameterized
import tensorflow as tf
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
from official.recommendation.ranking.data import data_pipeline
class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
......
......@@ -15,7 +15,7 @@
"""Task for the Ranking model."""
import math
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
import tensorflow as tf
import tensorflow_recommenders as tfrs
......@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
from official.core import base_task
from official.core import config_definitions
from official.recommendation.ranking import common
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
from official.recommendation.ranking.data import data_pipeline
RuntimeConfig = config_definitions.RuntimeConfig
def _get_tpu_embedding_feature_config(
vocab_sizes: List[int],
embedding_dim: int,
embedding_dim: Union[int, List[int]],
table_name_prefix: str = 'embedding_table'
) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
"""Returns TPU embedding feature config.
The i'th table config will have a vocab size of vocab_sizes[i] and an
embedding dimension of embedding_dim (if embedding_dim is an int) or
embedding_dim[i] (if embedding_dim is a list).
Args:
vocab_sizes: List of sizes of categories/id's in the table.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
table_name_prefix: a prefix for embedding tables.
Returns:
A dictionary of feature_name, FeatureConfig pairs.
"""
if isinstance(embedding_dim, List):
if len(vocab_sizes) != len(embedding_dim):
raise ValueError(
f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
f'length of embedding_dim: {len(embedding_dim)}')
elif isinstance(embedding_dim, int):
embedding_dim = [embedding_dim] * len(vocab_sizes)
else:
raise ValueError('embedding_dim is not either a list or an int, got '
f'{type(embedding_dim)}')
feature_config = {}
for i, vocab_size in enumerate(vocab_sizes):
table_config = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=vocab_size,
dim=embedding_dim,
dim=embedding_dim[i],
combiner='mean',
initializer=tf.initializers.TruncatedNormal(
mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
name=table_name_prefix + '_%s' % i)
feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
table=table_config)
......@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
"""Task initialization.
Args:
params: the RannkingModel task configuration instance.
params: the RankingModel task configuration instance.
optimizer_config: Optimizer configuration instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved.
......@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
self.optimizer_config.embedding_optimizer)
embedding_optimizer.learning_rate = lr_callable
emb_feature_config = _get_tpu_embedding_feature_config(
vocab_sizes=self.task_config.model.vocab_sizes,
embedding_dim=self.task_config.model.embedding_dim)
feature_config = _get_tpu_embedding_feature_config(
embedding_dim=self.task_config.model.embedding_dim,
vocab_sizes=self.task_config.model.vocab_sizes)
tpu_embedding = tfrs.layers.embedding.TPUEmbedding(
emb_feature_config, embedding_optimizer)
embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
feature_config=feature_config,
optimizer=embedding_optimizer,
size_threshold=self.task_config.model.size_threshold)
if self.task_config.model.interaction == 'dot':
feature_interaction = tfrs.layers.feature_interaction.DotInteraction()
feature_interaction = tfrs.layers.feature_interaction.DotInteraction(
skip_gather=True)
elif self.task_config.model.interaction == 'cross':
feature_interaction = tf.keras.Sequential([
tf.keras.layers.Concatenate(),
......@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
f'is not supported it must be either \'dot\' or \'cross\'.')
model = tfrs.experimental.models.Ranking(
embedding_layer=tpu_embedding,
embedding_layer=embedding_layer,
bottom_stack=tfrs.layers.blocks.MLP(
units=self.task_config.model.bottom_mlp, final_activation='relu'),
feature_interaction=feature_interaction,
......@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
@property
def optimizer_config(self) -> config.OptimizationConfig:
return self._optimizer_config
......@@ -18,8 +18,8 @@ from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking import task
from official.recommendation.ranking.data import data_pipeline
class TaskTest(parameterized.TestCase, tf.test.TestCase):
......@@ -34,6 +34,8 @@ class TaskTest(parameterized.TestCase, tf.test.TestCase):
params.task.train_data.global_batch_size = 16
params.task.validation_data.global_batch_size = 16
params.task.model.vocab_sizes = [40, 12, 11, 13, 2, 5]
params.task.model.embedding_dim = 8
params.task.model.bottom_mlp = [64, 32, 8]
params.task.use_synthetic_data = True
params.task.model.num_dense_features = 5
......
......@@ -40,6 +40,8 @@ def _get_params_override(vocab_sizes,
'task': {
'model': {
'vocab_sizes': vocab_sizes,
'embedding_dim': [8] * len(vocab_sizes),
'bottom_mlp': [64, 32, 8],
'interaction': interaction,
},
'train_data': {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment