Commit c351b6f6 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

DLRM and DCN v2 ranking models.

PiperOrigin-RevId: 375529985
parent 136cf614
@@ -66,9 +66,11 @@ In the near future, we will add:
### Recommendation

| Model | Reference (Paper) |
| -------------------------------- | ----------------- |
| [DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091) |
| [DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
## How to get started with the official models
# TF Model Garden Ranking Models
## Overview
This is an implementation of [DLRM](https://arxiv.org/abs/1906.00091) and
[DCN v2](https://arxiv.org/abs/2008.13535) ranking models that can be used for
tasks such as CTR prediction.
The model inputs are numerical and categorical features, and the output is a
scalar (for example, a click probability).
The model can be trained and evaluated on GPU, TPU, and CPU. Deep ranking
models are both memory intensive (for embedding tables and lookups) and compute
intensive (for the deep MLP networks). CPUs are best suited for large sparse
embedding lookups, GPUs for fast compute, and TPUs are designed for both.
When training on TPUs we use
[TPUEmbedding layer](https://github.com/tensorflow/recommenders/blob/main/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py)
for categorical features. TPU embedding supports large embedding tables with
fast lookup; the size of the embedding tables scales linearly with the size of
the TPU pod: up to 96 GB for TPU v3-8, 6.14 TB for v3-512, and 24.6 TB for
v3-2048.
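As a rough sanity check, the 96 GB figure for TPU v3-8 can be reproduced from the Criteo vocabulary sizes used in the configs below (a back-of-the-envelope sketch; it assumes float32 parameters and `embedding_dim: 128` as in the TPU training example):

```python
# Criteo vocabulary sizes, as listed in the training configs below.
vocab_sizes = [
    39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951,
    2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295,
    39664984, 585935, 12972, 108, 36,
]
embedding_dim = 128   # as in the TPU example below
bytes_per_param = 4   # float32

total_rows = sum(vocab_sizes)
table_bytes = total_rows * embedding_dim * bytes_per_param
print(total_rows)                # 187767399 embedding rows in total
print(round(table_bytes / 1e9))  # 96 -> ~96 GB, matching the TPU v3-8 figure
```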
The model code is in the
[TensorFlow Recommenders](https://github.com/tensorflow/recommenders/tree/main/tensorflow_recommenders/experimental/models)
library, while the input pipeline, configuration, and training loop are here.
## Prerequisites
To get started, download the code from the TensorFlow Models GitHub repository
or use a pre-installed Google Cloud VM. You also need to install the
[TensorFlow Recommenders](https://www.tensorflow.org/recommenders) library.
```bash
git clone https://github.com/tensorflow/models.git
pip install -r models/official/requirements.txt
pip install tensorflow-recommenders
export PYTHONPATH=$PYTHONPATH:$(pwd)/models
```
Make sure to use TensorFlow 2.4+.
## Dataset
The models can be trained on various datasets. Two commonly used ones are the
[Criteo Terabyte](https://labs.criteo.com/2013/12/download-terabyte-click-logs/)
and [Criteo Kaggle](https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/)
datasets.
You can also train on synthetic data by setting the flag `use_synthetic_data=True`.
### Download
The dataset is the Terabyte click logs dataset provided by Criteo. Follow the
[instructions](https://labs.criteo.com/2013/12/download-terabyte-click-logs/) at
the Criteo website to download the data.
Note that the dataset is large (~1TB).
### Preprocess the data
Data preprocessing steps are summarized below.
Integer feature processing steps, sequentially:
1. Missing values are replaced with zeros.
2. Negative values are replaced with zeros.
3. Integer features are transformed by log(x+1) and are hence tf.float32.
Categorical features:
1. Categorical data is bucketized to tf.int32.
2. Optionally, the resulting integers are hashed down to a smaller range.
   This is necessary to reduce the sizes of the large tables; a simple hashing
   function such as the modulus suffices, i.e. `feature_value % MAX_INDEX`.
The vocabulary sizes resulting from pre-processing are passed to the model
trainer via the `model.vocab_sizes` config.
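The per-feature transformations above can be sketched in a few lines (illustrative only; `MAX_INDEX` and the sample values are hypothetical):

```python
import math

MAX_INDEX = 100_000  # hypothetical hash bucket count for one large table

def transform_integer_feature(x):
    """Missing/negative values -> 0, then log(x + 1), yielding a float."""
    if x is None or x < 0:
        x = 0
    return math.log(x + 1)

def transform_categorical_feature(feature_value):
    """Bucketized id, optionally hashed down with a simple modulus."""
    return feature_value % MAX_INDEX

print(transform_integer_feature(None))           # 0.0
print(transform_integer_feature(10))             # log(11) ~= 2.398
print(transform_categorical_feature(39884413))   # 84413
```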
The full dataset is composed of 24 directories. Partition the data into training
and eval sets, for example days 1-23 for training and day 24 for evaluation.
Training and eval datasets are expected to be saved as many tab-separated
values (TSV) files in the following format: numerical features, categorical
features, and the label.
On each row of a TSV file, the first `num_dense_features` values are numerical
features, the next `len(vocab_sizes)` values are categorical features, and the
last value is the label (either 0 or 1). The i-th categorical feature is
expected to be an integer in the range `[0, vocab_sizes[i])`.
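For illustration, a toy row with 2 numerical features and 3 categorical features (hypothetical vocab sizes `[10, 20, 30]`) could be parsed like this (a plain-Python sketch of the row format only; the real pipeline uses `tf.io.decode_csv`):

```python
num_dense_features = 2
vocab_sizes = [10, 20, 30]

row = "0.5\t1.25\t3\t19\t7\t1"  # 2 numerical, 3 categorical, then the label
fields = row.split("\t")

dense = [float(v) for v in fields[:num_dense_features]]
sparse = [int(v) for v in
          fields[num_dense_features:num_dense_features + len(vocab_sizes)]]
label = int(fields[-1])

# Each categorical value must fall inside its table's vocabulary.
assert all(0 <= v < size for v, size in zip(sparse, vocab_sizes))
print(dense, sparse, label)  # [0.5, 1.25] [3, 19, 7] 1
```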
## Train and Evaluate
To train the DLRM model, use the dot-product feature interaction, i.e.
`interaction: 'dot'`; to train the DCN v2 model, use `interaction: 'cross'`.
### Training on TPU
```shell
export TPU_NAME=my-dlrm-tpu
export EXPERIMENT_NAME=my_experiment_name
export BUCKET_NAME="gs://my_dlrm_bucket"
export DATA_DIR="${BUCKET_NAME}/data"
python3 official/recommendation/ranking/main.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime:
distribution_strategy='tpu'
task:
use_synthetic_data: false
train_data:
input_path: '${DATA_DIR}/train/*'
global_batch_size: 16384
validation_data:
input_path: '${DATA_DIR}/eval/*'
global_batch_size: 16384
model:
num_dense_features: 13
bottom_mlp: [512,256,128]
embedding_dim: 128
top_mlp: [1024,1024,512,256,1]
interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
38532951, 2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14,
39979771, 25641295, 39664984, 585935, 12972, 108, 36]
trainer:
use_orbit: true
validation_interval: 90000
checkpoint_interval: 100000
validation_steps: 5440
train_steps: 256054
steps_per_execution: 1000
"
```
The data directory should have two subdirectories:
* $DATA_DIR/train
* $DATA_DIR/eval
### Training on GPU
Training on GPUs is similar to TPU training; only the distribution strategy
needs to be updated and the number of GPUs provided (here, 4 GPUs):
```shell
python3 official/recommendation/ranking/main.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime:
distribution_strategy: 'mirrored'
num_gpus: 4
...
"
```
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags and common definitions for Ranking Models."""
from absl import flags
import tensorflow as tf
from official.common import flags as tfm_flags
FLAGS = flags.FLAGS
def define_flags() -> None:
"""Defines flags for training the Ranking model."""
tfm_flags.define_flags()
FLAGS.set_default(name='experiment', value='dlrm_criteo')
FLAGS.set_default(name='mode', value='train_and_eval')
flags.DEFINE_integer(
name='seed',
default=None,
help='This value will be used to seed both NumPy and TensorFlow.')
flags.DEFINE_string(
name='profile_steps',
default='20,40',
help='Save profiling data to model dir at given range of global steps. '
'The value must be a comma separated pair of positive integers, '
'specifying the first and last step to profile. For example, '
'"--profile_steps=2,4" triggers the profiler to process 3 steps, starting'
' from the 2nd step. Note that profiler has a non-trivial performance '
'overhead, and the output file can be gigantic if profiling many steps.')
@tf.keras.utils.register_keras_serializable(package='RANKING')
class WarmUpAndPolyDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate callable for the embeddings.
Linear warmup on [0, warmup_steps], then
constant on [warmup_steps, decay_start_steps],
and polynomial decay on [decay_start_steps, decay_start_steps + decay_steps].
"""
def __init__(self,
batch_size: int,
decay_exp: float = 2.0,
learning_rate: float = 40.0,
warmup_steps: int = 8000,
decay_steps: int = 12000,
decay_start_steps: int = 10000):
super(WarmUpAndPolyDecay, self).__init__()
self.batch_size = batch_size
self.decay_exp = decay_exp
self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.decay_steps = decay_steps
self.decay_start_steps = decay_start_steps
def __call__(self, step):
decay_exp = self.decay_exp
learning_rate = self.learning_rate
warmup_steps = self.warmup_steps
decay_steps = self.decay_steps
decay_start_steps = self.decay_start_steps
scal = self.batch_size / 2048
adj_lr = learning_rate * scal
if warmup_steps == 0:
return adj_lr
warmup_lr = step / warmup_steps * adj_lr
global_step = tf.cast(step, tf.float32)
decay_steps = tf.cast(decay_steps, tf.float32)
decay_start_step = tf.cast(decay_start_steps, tf.float32)
warmup_lr = tf.cast(warmup_lr, tf.float32)
steps_since_decay_start = global_step - decay_start_step
already_decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
decay_lr = adj_lr * (
(decay_steps - already_decayed_steps) / decay_steps)**decay_exp
decay_lr = tf.maximum(0.0001, decay_lr)
lr = tf.where(
global_step < warmup_steps, warmup_lr,
tf.where(
tf.logical_and(decay_steps > 0, global_step > decay_start_step),
decay_lr, adj_lr))
lr = tf.maximum(0.01, lr)
return lr
def get_config(self):
return {
'batch_size': self.batch_size,
'decay_exp': self.decay_exp,
'learning_rate': self.learning_rate,
'warmup_steps': self.warmup_steps,
'decay_steps': self.decay_steps,
'decay_start_steps': self.decay_start_steps
}
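# For intuition, the warmup/scaling behavior of WarmUpAndPolyDecay can be
# mirrored in plain Python (an illustrative sketch using the default
# hyperparameters above; not part of the library):
def _warmup_reference(step, batch_size=2048, learning_rate=40.0,
                      warmup_steps=8000):
  """Plain-Python reference for the warmup region of the schedule."""
  adj_lr = learning_rate * batch_size / 2048  # linear batch-size scaling
  if warmup_steps == 0 or step >= warmup_steps:
    return adj_lr
  # Linear warmup, with the same 0.01 floor applied to the final rate.
  return max(0.01, step / warmup_steps * adj_lr)

# Halfway through warmup the rate is half the (scaled) base rate, and
# doubling the batch size doubles the rate.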
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ranking Model configuration definition."""
from typing import Optional, List
import dataclasses
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling.hyperparams import config_definitions as cfg
@dataclasses.dataclass
class LearningRateConfig(hyperparams.Config):
"""Learning rate scheduler config."""
learning_rate: float = 1.25
warmup_steps: int = 8000
decay_steps: int = 30000
decay_start_steps: int = 70000
decay_exp: float = 2
@dataclasses.dataclass
class OptimizationConfig(hyperparams.Config):
"""Embedding Optimizer config."""
lr_config: LearningRateConfig = LearningRateConfig()
embedding_optimizer: str = 'SGD'
@dataclasses.dataclass
class DataConfig(hyperparams.Config):
"""Dataset config for training and evaluation."""
input_path: str = ''
global_batch_size: int = 0
is_training: bool = True
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
sharding: bool = True
num_shards_per_host: int = 8
@dataclasses.dataclass
class ModelConfig(hyperparams.Config):
"""Configuration for training.
Attributes:
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
embedding_dim: Embedding dimension.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
top_mlp: The sizes of hidden layers for top MLP.
interaction: Interaction can be one of the following:
'dot', 'cross'.
"""
num_dense_features: int = 13
vocab_sizes: List[int] = dataclasses.field(default_factory=list)
embedding_dim: int = 8
bottom_mlp: List[int] = dataclasses.field(default_factory=list)
top_mlp: List[int] = dataclasses.field(default_factory=list)
interaction: str = 'dot'
@dataclasses.dataclass
class Loss(hyperparams.Config):
"""Configuration for Loss.
Attributes:
label_smoothing: The amount of label smoothing applied to the
binary cross-entropy loss; 0.0 means no smoothing.
"""
label_smoothing: float = 0.0
@dataclasses.dataclass
class Task(hyperparams.Config):
"""The model config."""
init_checkpoint: str = ''
model: ModelConfig = ModelConfig()
train_data: DataConfig = DataConfig(is_training=True)
validation_data: DataConfig = DataConfig(is_training=False)
loss: Loss = Loss()
use_synthetic_data: bool = False
@dataclasses.dataclass
class TimeHistoryConfig(hyperparams.Config):
"""Configuration for the TimeHistory callback.
Attributes:
log_steps: Interval of steps between logging of batch level stats.
"""
log_steps: Optional[int] = None
@dataclasses.dataclass
class TrainerConfig(cfg.TrainerConfig):
"""Configuration for training.
Attributes:
train_steps: The number of steps used to train.
validation_steps: The number of steps used to eval.
validation_interval: The number of training steps to run between
evaluations.
callbacks: An instance of CallbacksConfig.
use_orbit: Whether to use orbit library with custom training loop or
compile/fit API.
enable_metrics_in_training: Whether to enable metrics during training.
tensorboard: An instance of TensorboardConfig.
time_history: Config of TimeHistory callback.
optimizer_config: An `OptimizerConfig` instance for embedding optimizer.
Defaults to None.
"""
train_steps: int = 0
# Sets validation steps to be -1 to evaluate the entire dataset.
validation_steps: int = -1
validation_interval: int = 70000
callbacks: cfg.CallbacksConfig = cfg.CallbacksConfig()
use_orbit: bool = False
enable_metrics_in_training: bool = True
tensorboard: cfg.TensorboardConfig = cfg.TensorboardConfig()
time_history: TimeHistoryConfig = TimeHistoryConfig(log_steps=5000)
optimizer_config: OptimizationConfig = OptimizationConfig()
NUM_TRAIN_EXAMPLES = 4195197692
NUM_EVAL_EXAMPLES = 89137318
train_batch_size = 16384
eval_batch_size = 16384
steps_per_epoch = NUM_TRAIN_EXAMPLES // train_batch_size
vocab_sizes = [
39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951,
2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295,
39664984, 585935, 12972, 108, 36
]
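# Sanity check (illustrative): the step counts used in the README examples
# follow directly from the dataset and batch sizes above.
assert 4195197692 // 16384 == 256054  # NUM_TRAIN_EXAMPLES // train_batch_size
assert 89137318 // 16384 == 5440      # NUM_EVAL_EXAMPLES // eval_batch_size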
@dataclasses.dataclass
class Config(hyperparams.Config):
"""Configuration to train the RankingModel.
By default it configures DLRM model on criteo dataset.
Attributes:
runtime: A `RuntimeConfig` instance.
task: `Task` instance.
trainer: A `TrainerConfig` instance.
"""
runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
task: Task = Task(
model=ModelConfig(
embedding_dim=8,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 8],
top_mlp=[64, 32, 1]),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
is_training=False,
global_batch_size=eval_batch_size))
trainer: TrainerConfig = TrainerConfig(
train_steps=2 * steps_per_epoch,
validation_interval=steps_per_epoch,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig())
restrictions: dataclasses.InitVar[Optional[List[str]]] = None
def default_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
embedding_dim=4,
vocab_sizes=vocab_sizes,
bottom_mlp=[64, 32, 4],
top_mlp=[64, 32, 1]),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=2 * steps_per_epoch,
validation_interval=steps_per_epoch,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
@exp_factory.register_config_factory('dlrm_criteo')
def dlrm_criteo_tb_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
num_dense_features=13,
vocab_sizes=vocab_sizes,
bottom_mlp=[512, 256, 64],
embedding_dim=64,
top_mlp=[1024, 1024, 512, 256, 1],
interaction='dot'),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=steps_per_epoch,
validation_interval=steps_per_epoch // 2,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
@exp_factory.register_config_factory('dcn_criteo')
def dcn_criteo_tb_config() -> Config:
return Config(
runtime=cfg.RuntimeConfig(),
task=Task(
model=ModelConfig(
num_dense_features=13,
vocab_sizes=vocab_sizes,
bottom_mlp=[512, 256, 64],
embedding_dim=64,
top_mlp=[1024, 1024, 512, 256, 1],
interaction='cross'),
loss=Loss(label_smoothing=0.0),
train_data=DataConfig(
global_batch_size=train_batch_size,
is_training=True,
sharding=True),
validation_data=DataConfig(
global_batch_size=eval_batch_size,
is_training=False,
sharding=False)),
trainer=TrainerConfig(
train_steps=steps_per_epoch,
validation_interval=steps_per_epoch // 2,
validation_steps=NUM_EVAL_EXAMPLES // eval_batch_size,
enable_metrics_in_training=True,
optimizer_config=OptimizationConfig()),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for DLRM config."""
from absl.testing import parameterized
import tensorflow as tf
from official.recommendation.ranking.configs import config
class ConfigTest(tf.test.TestCase, parameterized.TestCase):
def test_configs(self):
criteo_config = config.default_config()
self.assertIsInstance(criteo_config, config.Config)
self.assertIsInstance(criteo_config.task, config.Task)
self.assertIsInstance(criteo_config.task.model, config.ModelConfig)
self.assertIsInstance(criteo_config.task.train_data,
config.DataConfig)
self.assertIsInstance(criteo_config.task.validation_data,
config.DataConfig)
criteo_config.task.train_data.is_training = None
with self.assertRaises(KeyError):
criteo_config.validate()
if __name__ == '__main__':
tf.test.main()
runtime:
distribution_strategy: 'tpu'
task:
model:
bottom_mlp: [512, 256, 64]
embedding_dim: 64
num_dense_features: 13
top_mlp: [1024, 1024, 512, 256, 1]
interaction: 'cross'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951, 2953546, 403346,
10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295, 39664984, 585935, 12972,
108, 36]
train_data:
global_batch_size: 16384
input_path: path_to_training_data_dir/*
is_training: true
num_shards_per_host: 4
sharding: true
validation_data:
global_batch_size: 16384
input_path: path_to_eval_data_dir/*
is_training: false
sharding: false
trainer:
checkpoint_interval: 85352
eval_tf_function: true
eval_tf_while_loop: false
max_to_keep: 5
train_steps: 256054
train_tf_function: true
train_tf_while_loop: true
use_orbit: true
validation_interval: 85352
validation_steps: 5440
validation_summary_subdir: 'validation'
runtime:
distribution_strategy: 'tpu'
task:
model:
bottom_mlp: [512, 256, 64]
embedding_dim: 64
num_dense_features: 13
top_mlp: [1024, 1024, 512, 256, 1]
interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951, 2953546, 403346,
10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295, 39664984, 585935, 12972,
108, 36]
train_data:
global_batch_size: 16384
input_path: path_to_training_data_dir/*
is_training: true
num_shards_per_host: 4
sharding: true
validation_data:
global_batch_size: 16384
input_path: path_to_eval_data_dir/*
is_training: false
sharding: false
trainer:
checkpoint_interval: 85352
eval_tf_function: true
eval_tf_while_loop: false
max_to_keep: 5
train_steps: 256054
train_tf_function: true
train_tf_while_loop: true
use_orbit: true
validation_interval: 85352
validation_steps: 5440
validation_summary_subdir: 'validation'
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for the Ranking model.
This module defines various input datasets for the Ranking model.
"""
from typing import List
import tensorflow as tf
from official.recommendation.ranking.configs import config
class CriteoTsvReader:
"""Input reader callable for pre-processed Criteo data.
Raw Criteo data is assumed to be preprocessed in the following way:
1. Missing values are replaced with zeros.
2. Negative values are replaced with zeros.
3. Integer features are transformed by log(x+1) and are hence tf.float32.
4. Categorical data is bucketized and are hence tf.int32.
"""
def __init__(self,
file_pattern: str,
params: config.DataConfig,
num_dense_features: int,
vocab_sizes: List[int],
use_synthetic_data: bool = False):
self._file_pattern = file_pattern
self._params = params
self._num_dense_features = num_dense_features
self._vocab_sizes = vocab_sizes
self._use_synthetic_data = use_synthetic_data
def __call__(self, ctx: tf.distribute.InputContext) -> tf.data.Dataset:
params = self._params
# Per replica batch size.
batch_size = ctx.get_per_replica_batch_size(
params.global_batch_size) if ctx else params.global_batch_size
if self._use_synthetic_data:
return self._generate_synthetic_data(ctx, batch_size)
@tf.function
def _parse_fn(example: tf.Tensor):
"""Parser function for pre-processed Criteo TSV records."""
label_defaults = [[0.0]]
dense_defaults = [
[0.0] for _ in range(self._num_dense_features)
]
num_sparse_features = len(self._vocab_sizes)
categorical_defaults = [
[0] for _ in range(num_sparse_features)
]
record_defaults = label_defaults + dense_defaults + categorical_defaults
fields = tf.io.decode_csv(
example, record_defaults, field_delim='\t', na_value='-1')
num_labels = 1
label = tf.reshape(fields[0], [batch_size, 1])
features = {}
num_dense = len(dense_defaults)
dense_features = []
offset = num_labels
for idx in range(num_dense):
dense_features.append(fields[idx + offset])
features['dense_features'] = tf.stack(dense_features, axis=1)
offset += num_dense
features['sparse_features'] = {}
for idx in range(num_sparse_features):
features['sparse_features'][str(idx)] = fields[idx + offset]
return features, label
filenames = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
# Shard the full dataset according to host number.
# Each host will get 1 / num_of_hosts portion of the data.
if params.sharding and ctx and ctx.num_input_pipelines > 1:
filenames = filenames.shard(ctx.num_input_pipelines,
ctx.input_pipeline_id)
num_shards_per_host = 1
if params.sharding:
num_shards_per_host = params.num_shards_per_host
def make_dataset(shard_index):
filenames_for_shard = filenames.shard(num_shards_per_host, shard_index)
dataset = tf.data.TextLineDataset(filenames_for_shard)
if params.is_training:
dataset = dataset.repeat()
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(_parse_fn,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
indices = tf.data.Dataset.range(num_shards_per_host)
dataset = indices.interleave(
map_func=make_dataset,
cycle_length=params.cycle_length,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
def _generate_synthetic_data(self, ctx: tf.distribute.InputContext,
batch_size: int) -> tf.data.Dataset:
"""Creates synthetic data based on the parameter batch size.
Args:
ctx: Input Context
batch_size: per replica batch size.
Returns:
The synthetic dataset.
"""
params = self._params
num_dense = self._num_dense_features
num_replicas = ctx.num_replicas_in_sync if ctx else 1
if params.is_training:
dataset_size = 10000 * batch_size * num_replicas
else:
dataset_size = 1000 * batch_size * num_replicas
dense_tensor = tf.random.uniform(
shape=(dataset_size, num_dense), maxval=1.0, dtype=tf.float32)
sparse_tensors = []
for size in self._vocab_sizes:
sparse_tensors.append(
tf.random.uniform(
shape=(dataset_size,), maxval=int(size), dtype=tf.int32))
sparse_tensor_elements = {
str(i): sparse_tensors[i] for i in range(len(sparse_tensors))
}
# the mean is in [0, 1] interval.
dense_tensor_mean = tf.math.reduce_mean(dense_tensor, axis=1)
sparse_tensors = tf.stack(sparse_tensors, axis=-1)
sparse_tensors_mean = tf.math.reduce_sum(sparse_tensors, axis=1)
# the mean is in [0, 1] interval.
sparse_tensors_mean = tf.cast(sparse_tensors_mean, dtype=tf.float32)
sparse_tensors_mean /= sum(self._vocab_sizes)
# the label is in [0, 1] interval.
label_tensor = (dense_tensor_mean + sparse_tensors_mean) / 2.0
# Using the threshold 0.5 to convert to 0/1 labels.
label_tensor = tf.cast(label_tensor + 0.5, tf.int32)
input_elem = {'dense_features': dense_tensor,
'sparse_features': sparse_tensor_elements}, label_tensor
dataset = tf.data.Dataset.from_tensor_slices(input_elem)
if params.is_training:
dataset = dataset.repeat()
return dataset.batch(batch_size, drop_remainder=True)
def train_input_fn(params: config.Task) -> CriteoTsvReader:
"""Returns callable object of batched training examples.
Args:
params: hyperparams to create input pipelines.
Returns:
CriteoTsvReader callable for training dataset.
"""
return CriteoTsvReader(
file_pattern=params.train_data.input_path,
params=params.train_data,
vocab_sizes=params.model.vocab_sizes,
num_dense_features=params.model.num_dense_features,
use_synthetic_data=params.use_synthetic_data)
def eval_input_fn(params: config.Task) -> CriteoTsvReader:
"""Returns callable object of batched eval examples.
Args:
params: hyperparams to create input pipelines.
Returns:
CriteoTsvReader callable for eval dataset.
"""
return CriteoTsvReader(
file_pattern=params.validation_data.input_path,
params=params.validation_data,
vocab_sizes=params.model.vocab_sizes,
num_dense_features=params.model.num_dense_features,
use_synthetic_data=params.use_synthetic_data)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for data_pipeline."""
from absl.testing import parameterized
import tensorflow as tf
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.named_parameters(('Train', True),
('Eval', False))
def testSyntheticDataPipeline(self, is_training):
task = config.Task(
model=config.ModelConfig(
embedding_dim=4,
num_dense_features=8,
vocab_sizes=[40, 12, 11, 13, 2, 5],
bottom_mlp=[64, 32, 4],
top_mlp=[64, 32, 1]),
train_data=config.DataConfig(global_batch_size=16),
validation_data=config.DataConfig(global_batch_size=16),
use_synthetic_data=True)
num_dense_features = task.model.num_dense_features
num_sparse_features = len(task.model.vocab_sizes)
batch_size = task.train_data.global_batch_size
if is_training:
dataset = data_pipeline.train_input_fn(task)
else:
dataset = data_pipeline.eval_input_fn(task)
dataset_iter = iter(dataset(ctx=None))
# Consume full batches and validate shapes.
for _ in range(10):
features, label = next(dataset_iter)
dense_features = features['dense_features']
sparse_features = features['sparse_features']
self.assertEqual(dense_features.shape, [batch_size, num_dense_features])
self.assertLen(sparse_features, num_sparse_features)
for _, val in sparse_features.items():
self.assertEqual(val.shape, [batch_size])
self.assertEqual(label.shape, [batch_size])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Task for the Ranking model."""
import math
from typing import Dict, List, Optional
import tensorflow as tf
import tensorflow_recommenders as tfrs
from official.core import base_task
from official.core import config_definitions
from official.recommendation.ranking import common
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking.configs import config
RuntimeConfig = config_definitions.RuntimeConfig
def _get_tpu_embedding_feature_config(
vocab_sizes: List[int],
embedding_dim: int,
table_name_prefix: str = 'embedding_table'
) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
"""Returns TPU embedding feature config.
Args:
vocab_sizes: List of sizes of categories/id's in the table.
embedding_dim: Embedding dimension.
table_name_prefix: a prefix for embedding tables.
Returns:
A dictionary of feature_name, FeatureConfig pairs.
"""
feature_config = {}
for i, vocab_size in enumerate(vocab_sizes):
table_config = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=vocab_size,
dim=embedding_dim,
combiner='mean',
initializer=tf.initializers.TruncatedNormal(
mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
name=table_name_prefix + '_%s' % i)
feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
table=table_config)
return feature_config
class RankingTask(base_task.Task):
"""A task for Ranking Model."""
def __init__(self,
params: config.Task,
optimizer_config: config.OptimizationConfig,
logging_dir: Optional[str] = None,
steps_per_execution: int = 1,
name: Optional[str] = None):
"""Task initialization.
Args:
params: the RannkingModel task configuration instance.
optimizer_config: Optimizer configuration instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved.
steps_per_execution: Int. Defaults to 1. The number of batches to run
during each `tf.function` call. It's used for compile/fit API.
name: the task name.
"""
super().__init__(params, logging_dir, name=name)
self._optimizer_config = optimizer_config
self._steps_per_execution = steps_per_execution
def build_inputs(self, params, input_context=None):
"""Builds classification input."""
dataset = data_pipeline.CriteoTsvReader(
file_pattern=params.input_path,
params=params,
vocab_sizes=self.task_config.model.vocab_sizes,
num_dense_features=self.task_config.model.num_dense_features,
use_synthetic_data=self.task_config.use_synthetic_data)
return dataset(input_context)
@classmethod
def create_optimizer(cls, optimizer_config: config.OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None) -> None:
"""See base class. Return None, optimizer is set in `build_model`."""
return None
def build_model(self) -> tf.keras.Model:
"""Creates Ranking model architecture and Optimizers.
The RankingModel uses different optimizers/learning rates for embedding
variables and dense variables.
Returns:
A Ranking model instance.
"""
lr_config = self.optimizer_config.lr_config
lr_callable = common.WarmUpAndPolyDecay(
batch_size=self.task_config.train_data.global_batch_size,
decay_exp=lr_config.decay_exp,
learning_rate=lr_config.learning_rate,
warmup_steps=lr_config.warmup_steps,
decay_steps=lr_config.decay_steps,
decay_start_steps=lr_config.decay_start_steps)
dense_optimizer = tf.keras.optimizers.Adam()
embedding_optimizer = tf.keras.optimizers.get(
self.optimizer_config.embedding_optimizer)
embedding_optimizer.learning_rate = lr_callable
emb_feature_config = _get_tpu_embedding_feature_config(
vocab_sizes=self.task_config.model.vocab_sizes,
embedding_dim=self.task_config.model.embedding_dim)
tpu_embedding = tfrs.layers.embedding.TPUEmbedding(
emb_feature_config, embedding_optimizer)
if self.task_config.model.interaction == 'dot':
feature_interaction = tfrs.layers.feature_interaction.DotInteraction()
elif self.task_config.model.interaction == 'cross':
feature_interaction = tf.keras.Sequential([
tf.keras.layers.Concatenate(),
tfrs.layers.feature_interaction.Cross()
])
else:
      raise ValueError(
          f'params.task.model.interaction {self.task_config.model.interaction} '
          "is not supported; it must be either 'dot' or 'cross'.")
model = tfrs.experimental.models.Ranking(
embedding_layer=tpu_embedding,
bottom_stack=tfrs.layers.blocks.MLP(
units=self.task_config.model.bottom_mlp, final_activation='relu'),
feature_interaction=feature_interaction,
top_stack=tfrs.layers.blocks.MLP(
units=self.task_config.model.top_mlp, final_activation='sigmoid'),
)
optimizer = tfrs.experimental.optimizers.CompositeOptimizer([
(embedding_optimizer, lambda: model.embedding_trainable_variables),
(dense_optimizer, lambda: model.dense_trainable_variables),
])
model.compile(optimizer, steps_per_execution=self._steps_per_execution)
return model
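The `WarmUpAndPolyDecay` schedule wired into the embedding optimizer above follows the usual warm-up-then-polynomial-decay shape. A simplified, framework-free sketch of that shape (the real implementation lives in `common.py` and may differ in details):

```python
def warmup_and_poly_decay(step, base_lr, warmup_steps, decay_start_steps,
                          decay_steps, decay_exp):
  """Simplified schedule: linear warmup, flat plateau, polynomial decay."""
  if step < warmup_steps:
    # Linear ramp from 0 to base_lr.
    return base_lr * step / warmup_steps
  if step < decay_start_steps:
    return base_lr
  # Polynomial decay toward 0 over decay_steps, then stay at 0.
  progress = min((step - decay_start_steps) / decay_steps, 1.0)
  return base_lr * (1.0 - progress) ** decay_exp


# Warm up for 10 steps, hold until step 20, then decay over 100 steps.
print(warmup_and_poly_decay(5, 1.0, 10, 20, 100, 2))  # 0.5
```

Passing a callable like this as `learning_rate` lets the optimizer re-evaluate it at every step, which is why `train_loop_end` in the trainer below checks `callable(optimizer.learning_rate)` before logging.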
def train_step(
self,
inputs: Dict[str, tf.Tensor],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[tf.keras.metrics.Metric]] = None) -> tf.Tensor:
"""See base class."""
# All metrics need to be passed through the RankingModel.
assert metrics == model.metrics
return model.train_step(inputs)
def validation_step(
self,
inputs: Dict[str, tf.Tensor],
model: tf.keras.Model,
metrics: Optional[List[tf.keras.metrics.Metric]] = None) -> tf.Tensor:
"""See base class."""
# All metrics need to be passed through the RankingModel.
assert metrics == model.metrics
return model.test_step(inputs)
@property
def optimizer_config(self) -> config.OptimizationConfig:
return self._optimizer_config
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for task."""
from absl.testing import parameterized
import tensorflow as tf
from official.core import exp_factory
from official.recommendation.ranking import data_pipeline
from official.recommendation.ranking import task
class TaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(('dlrm_criteo', True),
('dlrm_criteo', False),
('dcn_criteo', True),
('dcn_criteo', False))
def test_task(self, config_name, is_training):
params = exp_factory.get_exp_config(config_name)
params.task.train_data.global_batch_size = 16
params.task.validation_data.global_batch_size = 16
params.task.model.vocab_sizes = [40, 12, 11, 13, 2, 5]
params.task.use_synthetic_data = True
params.task.model.num_dense_features = 5
ranking_task = task.RankingTask(params.task,
params.trainer.optimizer_config)
if is_training:
dataset = data_pipeline.train_input_fn(params.task)
else:
dataset = data_pipeline.eval_input_fn(params.task)
iterator = iter(dataset(ctx=None))
model = ranking_task.build_model()
if is_training:
ranking_task.train_step(next(iterator), model, model.optimizer,
metrics=model.metrics)
else:
ranking_task.validation_step(next(iterator), model, metrics=model.metrics)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train and evaluate the Ranking model."""
from typing import Dict
from absl import app
from absl import flags
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_trainer
from official.core import train_lib
from official.core import train_utils
from official.recommendation.ranking import common
from official.recommendation.ranking.task import RankingTask
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
FLAGS = flags.FLAGS
class RankingTrainer(base_trainer.Trainer):
"""A trainer for Ranking Model.
The RankingModel has two optimizers for embedding and non embedding weights.
Overriding `train_loop_end` method to log learning rates for each optimizer.
"""
def train_loop_end(self) -> Dict[str, float]:
"""See base class."""
self.join()
# Checks if the model numeric status is stable and conducts the checkpoint
# recovery accordingly.
if self._recovery:
self._recovery.maybe_recover(self.train_loss.result().numpy(),
self.global_step.numpy())
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
for i, optimizer in enumerate(self.optimizer.optimizers):
lr_key = f'{type(optimizer).__name__}_{i}_learning_rate'
if callable(optimizer.learning_rate):
logs[lr_key] = optimizer.learning_rate(self.global_step)
else:
logs[lr_key] = optimizer.learning_rate
return logs
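The loop above iterates `self.optimizer.optimizers`, which exists because the model was compiled with a `CompositeOptimizer` pairing each optimizer with a variable getter. The core idea can be sketched in plain Python with a hypothetical `CompositeSGD` (not the `tfrs` class; variable groups stand in for the embedding/dense split):

```python
class CompositeSGD:
  """Applies a distinct learning rate to each variable group."""

  def __init__(self, groups):
    # groups: list of (learning_rate, get_params) pairs, mirroring
    # CompositeOptimizer's (optimizer, variable getter) pairs.
    self.groups = groups

  def apply_gradients(self, grads_per_group):
    for (lr, get_params), grads in zip(self.groups, grads_per_group):
      params = get_params()
      for j, g in enumerate(grads):
        params[j] -= lr * g  # plain in-place SGD update


# Embedding variables get a larger step size than dense ones.
embedding, dense = [1.0, 2.0], [10.0]
opt = CompositeSGD([(0.1, lambda: embedding), (0.01, lambda: dense)])
opt.apply_gradients([[1.0, 1.0], [2.0]])
print(embedding, dense)  # [0.9, 1.9] [9.98]
```

Using getters rather than captured variable lists matters in the real model: the embedding and dense variables may not exist until the model is built, so the lambdas defer the lookup, exactly as in `build_model` above.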
def main(_) -> None:
"""Train and evaluate the Ranking model."""
params = train_utils.parse_configuration(FLAGS)
mode = FLAGS.mode
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils.serialize_config(params, model_dir)
if FLAGS.seed is not None:
logging.info('Setting tf seed.')
tf.random.set_seed(FLAGS.seed)
task = RankingTask(
params=params.task,
optimizer_config=params.trainer.optimizer_config,
logging_dir=model_dir,
steps_per_execution=params.trainer.steps_per_loop,
name='RankingTask')
enable_tensorboard = params.trainer.callbacks.enable_tensorboard
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with strategy.scope():
model = task.build_model()
if params.trainer.use_orbit:
with strategy.scope():
checkpoint_exporter = train_utils.maybe_create_best_ckpt_exporter(
params, model_dir)
trainer = RankingTrainer(
config=params,
task=task,
model=model,
optimizer=model.optimizer,
train='train' in mode,
evaluate='eval' in mode,
checkpoint_exporter=checkpoint_exporter)
train_lib.run_experiment(
distribution_strategy=strategy,
task=task,
mode=mode,
params=params,
model_dir=model_dir,
trainer=trainer)
else: # Compile/fit
train_dataset = None
if 'train' in mode:
train_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.train_data)
eval_dataset = None
if 'eval' in mode:
eval_dataset = orbit.utils.make_distributed_dataset(
strategy, task.build_inputs, params.task.validation_data)
checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
if latest_checkpoint:
checkpoint.restore(latest_checkpoint)
logging.info('Loaded checkpoint %s', latest_checkpoint)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=model.optimizer.iterations,
checkpoint_interval=params.trainer.checkpoint_interval)
checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
time_callback = keras_utils.TimeHistory(
params.task.train_data.global_batch_size,
params.trainer.time_history.log_steps,
logdir=model_dir if enable_tensorboard else None)
callbacks = [checkpoint_callback, time_callback]
if enable_tensorboard:
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=model_dir,
update_freq=min(1000, params.trainer.validation_interval),
profile_batch=FLAGS.profile_steps)
callbacks.append(tensorboard_callback)
num_epochs = (params.trainer.train_steps //
params.trainer.validation_interval)
current_step = model.optimizer.iterations.numpy()
initial_epoch = current_step // params.trainer.validation_interval
eval_steps = params.trainer.validation_steps if 'eval' in mode else None
if mode in ['train', 'train_and_eval']:
logging.info('Training started')
history = model.fit(
train_dataset,
initial_epoch=initial_epoch,
epochs=num_epochs,
steps_per_epoch=params.trainer.validation_interval,
validation_data=eval_dataset,
validation_steps=eval_steps,
callbacks=callbacks,
)
model.summary()
logging.info('Train history: %s', history.history)
elif mode == 'eval':
logging.info('Evaluation started')
validation_output = model.evaluate(eval_dataset, steps=eval_steps)
logging.info('Evaluation output: %s', validation_output)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
common.define_flags()
app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for ranking model and associated functionality."""
import json
import os
from absl import flags
from absl.testing import parameterized
import tensorflow as tf
from official.recommendation.ranking import common
from official.recommendation.ranking import train
FLAGS = flags.FLAGS
def _get_params_override(vocab_sizes,
interaction='dot',
use_orbit=True,
strategy='mirrored'):
  # Update `data_dir` if `use_synthetic_data` is set to False.
data_dir = ''
return json.dumps({
'runtime': {
'distribution_strategy': strategy,
},
'task': {
'model': {
'vocab_sizes': vocab_sizes,
'interaction': interaction,
},
'train_data': {
'input_path': os.path.join(data_dir, 'train/*'),
'global_batch_size': 16,
},
'validation_data': {
'input_path': os.path.join(data_dir, 'eval/*'),
'global_batch_size': 16,
},
'use_synthetic_data': True,
},
'trainer': {
'use_orbit': use_orbit,
'validation_interval': 20,
'validation_steps': 20,
'train_steps': 40,
},
})
class TrainTest(parameterized.TestCase, tf.test.TestCase):
def setUp(self):
super().setUp()
self._temp_dir = self.get_temp_dir()
self._model_dir = os.path.join(self._temp_dir, 'model_dir')
tf.io.gfile.makedirs(self._model_dir)
FLAGS.model_dir = self._model_dir
FLAGS.tpu = ''
def tearDown(self):
tf.io.gfile.rmtree(self._model_dir)
super().tearDown()
@parameterized.named_parameters(
('DlrmOneDeviceCTL', 'one_device', 'dot', True),
('DlrmOneDevice', 'one_device', 'dot', False),
('DcnOneDeviceCTL', 'one_device', 'cross', True),
('DcnOneDevice', 'one_device', 'cross', False),
('DlrmTPUCTL', 'tpu', 'dot', True),
('DlrmTPU', 'tpu', 'dot', False),
('DcnTPUCTL', 'tpu', 'cross', True),
('DcnTPU', 'tpu', 'cross', False),
('DlrmMirroredCTL', 'Mirrored', 'dot', True),
('DlrmMirrored', 'Mirrored', 'dot', False),
('DcnMirroredCTL', 'Mirrored', 'cross', True),
('DcnMirrored', 'Mirrored', 'cross', False),
)
def testTrainEval(self, strategy, interaction, use_orbit=True):
# Set up simple trainer with synthetic data.
# By default the mode must be `train_and_eval`.
self.assertEqual(FLAGS.mode, 'train_and_eval')
vocab_sizes = [40, 12, 11, 13]
FLAGS.params_override = _get_params_override(vocab_sizes=vocab_sizes,
interaction=interaction,
use_orbit=use_orbit,
strategy=strategy)
train.main('unused_args')
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))
@parameterized.named_parameters(
('DlrmTPUCTL', 'tpu', 'dot', True),
('DlrmTPU', 'tpu', 'dot', False),
('DcnTPUCTL', 'tpu', 'cross', True),
('DcnTPU', 'tpu', 'cross', False),
('DlrmMirroredCTL', 'Mirrored', 'dot', True),
('DlrmMirrored', 'Mirrored', 'dot', False),
('DcnMirroredCTL', 'Mirrored', 'cross', True),
('DcnMirrored', 'Mirrored', 'cross', False),
)
def testTrainThenEval(self, strategy, interaction, use_orbit=True):
# Set up simple trainer with synthetic data.
vocab_sizes = [40, 12, 11, 13]
FLAGS.params_override = _get_params_override(vocab_sizes=vocab_sizes,
interaction=interaction,
use_orbit=use_orbit,
strategy=strategy)
# Training.
FLAGS.mode = 'train'
train.main('unused_args')
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(self._model_dir, 'params.yaml')))
# Evaluation.
FLAGS.mode = 'eval'
train.main('unused_args')
if __name__ == '__main__':
common.define_flags()
tf.test.main()
...@@ -12,6 +12,7 @@ tensorflow-hub>=0.6.0
tensorflow-model-optimization>=0.4.1
tensorflow-datasets
tensorflow-addons
tensorflow-recommenders>=0.5.0
dataclasses;python_version<"3.7"
gin-config
tf_slim>=1.1.0
...