Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
09a38af1
Commit
09a38af1
authored
Jul 16, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 385256430
parent
078eaaf3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
80 additions
and
22 deletions
+80
-22
official/recommendation/ranking/configs/__init__.py
official/recommendation/ranking/configs/__init__.py
+14
-0
official/recommendation/ranking/configs/config.py
official/recommendation/ranking/configs/config.py
+11
-4
official/recommendation/ranking/data/__init__.py
official/recommendation/ranking/data/__init__.py
+14
-0
official/recommendation/ranking/data/data_pipeline.py
official/recommendation/ranking/data/data_pipeline.py
+2
-1
official/recommendation/ranking/data/data_pipeline_test.py
official/recommendation/ranking/data/data_pipeline_test.py
+1
-1
official/recommendation/ranking/task.py
official/recommendation/ranking/task.py
+33
-15
official/recommendation/ranking/task_test.py
official/recommendation/ranking/task_test.py
+3
-1
official/recommendation/ranking/train_test.py
official/recommendation/ranking/train_test.py
+2
-0
No files found.
official/recommendation/ranking/configs/__init__.py
0 → 100644
View file @
09a38af1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/recommendation/ranking/configs/config.py
View file @
09a38af1
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
"""Ranking Model configuration definition."""
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Union
import
dataclasses
from
official.core
import
exp_factory
...
...
@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
If it's an integer then all tables will have the same embedding dimension.
If it's a list then the length should match with `vocab_sizes`.
size_threshold: A threshold for table sizes below which a keras
embedding layer is used, and above which a TPU embedding layer is used.
If it's -1 then only the Keras embedding layer will be used for all tables;
if it's 0 then only the TPU embedding layer will be used.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
top_mlp: The sizes of hidden layers for top MLP.
...
...
@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
"""
num_dense_features
:
int
=
13
vocab_sizes
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
embedding_dim
:
int
=
8
embedding_dim
:
Union
[
int
,
List
[
int
]]
=
8
size_threshold
:
int
=
50_000
bottom_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
top_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
interaction
:
str
=
'dot'
...
...
@@ -188,7 +195,7 @@ def default_config() -> Config:
runtime
=
cfg
.
RuntimeConfig
(),
task
=
Task
(
model
=
ModelConfig
(
embedding_dim
=
4
,
embedding_dim
=
8
,
vocab_sizes
=
vocab_sizes
,
bottom_mlp
=
[
64
,
32
,
4
],
top_mlp
=
[
64
,
32
,
1
]),
...
...
official/recommendation/ranking/data/__init__.py
0 → 100644
View file @
09a38af1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/recommendation/ranking/data_pipeline.py
→
official/recommendation/ranking/data
/data
_pipeline.py
View file @
09a38af1
...
...
@@ -136,7 +136,7 @@ class CriteoTsvReader:
num_replicas
=
ctx
.
num_replicas_in_sync
if
ctx
else
1
if
params
.
is_training
:
dataset_size
=
1000
0
*
batch_size
*
num_replicas
dataset_size
=
1000
*
batch_size
*
num_replicas
else
:
dataset_size
=
1000
*
batch_size
*
num_replicas
dense_tensor
=
tf
.
random
.
uniform
(
...
...
@@ -169,6 +169,7 @@ class CriteoTsvReader:
'sparse_features'
:
sparse_tensor_elements
},
label_tensor
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
input_elem
)
dataset
=
dataset
.
cache
()
if
params
.
is_training
:
dataset
=
dataset
.
repeat
()
...
...
official/recommendation/ranking/data_pipeline_test.py
→
official/recommendation/ranking/data
/data
_pipeline_test.py
View file @
09a38af1
...
...
@@ -17,8 +17,8 @@
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.recommendation.ranking
import
data_pipeline
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.data
import
data_pipeline
class
DataPipelineTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/recommendation/ranking/task.py
View file @
09a38af1
...
...
@@ -15,7 +15,7 @@
"""Task for the Ranking model."""
import
math
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Union
import
tensorflow
as
tf
import
tensorflow_recommenders
as
tfrs
...
...
@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.recommendation.ranking
import
common
from
official.recommendation.ranking
import
data_pipeline
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.data
import
data_pipeline
RuntimeConfig
=
config_definitions
.
RuntimeConfig
def
_get_tpu_embedding_feature_config
(
vocab_sizes
:
List
[
int
],
embedding_dim
:
int
,
embedding_dim
:
Union
[
int
,
List
[
int
]]
,
table_name_prefix
:
str
=
'embedding_table'
)
->
Dict
[
str
,
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
]:
"""Returns TPU embedding feature config.
The i'th table config will have a vocab size of vocab_sizes[i] and an embedding
dimension of embedding_dim (if embedding_dim is an int) or embedding_dim[i] (if
embedding_dim is a list).
Args:
vocab_sizes: List of sizes of categories/id's in the table.
embedding_dim:
E
mbedding dimension.
embedding_dim:
An integer or a list of e
mbedding
table
dimension
s
.
table_name_prefix: a prefix for embedding tables.
Returns:
A dictionary of feature_name, FeatureConfig pairs.
"""
if
isinstance
(
embedding_dim
,
List
):
if
len
(
vocab_sizes
)
!=
len
(
embedding_dim
):
raise
ValueError
(
f
'length of vocab_sizes:
{
len
(
vocab_sizes
)
}
is not equal to the '
f
'length of embedding_dim:
{
len
(
embedding_dim
)
}
'
)
elif
isinstance
(
embedding_dim
,
int
):
embedding_dim
=
[
embedding_dim
]
*
len
(
vocab_sizes
)
else
:
raise
ValueError
(
'embedding_dim is not either a list or an int, got '
f
'
{
type
(
embedding_dim
)
}
'
)
feature_config
=
{}
for
i
,
vocab_size
in
enumerate
(
vocab_sizes
):
table_config
=
tf
.
tpu
.
experimental
.
embedding
.
TableConfig
(
vocabulary_size
=
vocab_size
,
dim
=
embedding_dim
,
dim
=
embedding_dim
[
i
]
,
combiner
=
'mean'
,
initializer
=
tf
.
initializers
.
TruncatedNormal
(
mean
=
0.0
,
stddev
=
1
/
math
.
sqrt
(
embedding_dim
)),
mean
=
0.0
,
stddev
=
1
/
math
.
sqrt
(
embedding_dim
[
i
]
)),
name
=
table_name_prefix
+
'_%s'
%
i
)
feature_config
[
str
(
i
)]
=
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
(
table
=
table_config
)
...
...
@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
"""Task initialization.
Args:
params: the Ran
n
kingModel task configuration instance.
params: the RankingModel task configuration instance.
optimizer_config: Optimizer configuration instance.
logging_dir: a string pointing to where the model, summaries etc. will be
saved.
...
...
@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
self
.
optimizer_config
.
embedding_optimizer
)
embedding_optimizer
.
learning_rate
=
lr_callable
emb_
feature_config
=
_get_tpu_embedding_feature_config
(
vocab_sizes
=
self
.
task_config
.
model
.
vocab_sizes
,
embedding_dim
=
self
.
task_config
.
model
.
embedding_dim
)
feature_config
=
_get_tpu_embedding_feature_config
(
embedding_dim
=
self
.
task_config
.
model
.
embedding_dim
,
vocab_sizes
=
self
.
task_config
.
model
.
vocab_sizes
)
tpu_embedding
=
tfrs
.
layers
.
embedding
.
TPUEmbedding
(
emb_feature_config
,
embedding_optimizer
)
embedding_layer
=
tfrs
.
experimental
.
layers
.
embedding
.
PartialTPUEmbedding
(
feature_config
=
feature_config
,
optimizer
=
embedding_optimizer
,
size_threshold
=
self
.
task_config
.
model
.
size_threshold
)
if
self
.
task_config
.
model
.
interaction
==
'dot'
:
feature_interaction
=
tfrs
.
layers
.
feature_interaction
.
DotInteraction
()
feature_interaction
=
tfrs
.
layers
.
feature_interaction
.
DotInteraction
(
skip_gather
=
True
)
elif
self
.
task_config
.
model
.
interaction
==
'cross'
:
feature_interaction
=
tf
.
keras
.
Sequential
([
tf
.
keras
.
layers
.
Concatenate
(),
...
...
@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
f
'is not supported it must be either
\'
dot
\'
or
\'
cross
\'
.'
)
model
=
tfrs
.
experimental
.
models
.
Ranking
(
embedding_layer
=
tpu_
embedding
,
embedding_layer
=
embedding
_layer
,
bottom_stack
=
tfrs
.
layers
.
blocks
.
MLP
(
units
=
self
.
task_config
.
model
.
bottom_mlp
,
final_activation
=
'relu'
),
feature_interaction
=
feature_interaction
,
...
...
@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
@
property
def
optimizer_config
(
self
)
->
config
.
OptimizationConfig
:
return
self
.
_optimizer_config
official/recommendation/ranking/task_test.py
View file @
09a38af1
...
...
@@ -18,8 +18,8 @@ from absl.testing import parameterized
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.recommendation.ranking
import
data_pipeline
from
official.recommendation.ranking
import
task
from
official.recommendation.ranking.data
import
data_pipeline
class
TaskTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
@@ -34,6 +34,8 @@ class TaskTest(parameterized.TestCase, tf.test.TestCase):
params
.
task
.
train_data
.
global_batch_size
=
16
params
.
task
.
validation_data
.
global_batch_size
=
16
params
.
task
.
model
.
vocab_sizes
=
[
40
,
12
,
11
,
13
,
2
,
5
]
params
.
task
.
model
.
embedding_dim
=
8
params
.
task
.
model
.
bottom_mlp
=
[
64
,
32
,
8
]
params
.
task
.
use_synthetic_data
=
True
params
.
task
.
model
.
num_dense_features
=
5
...
...
official/recommendation/ranking/train_test.py
View file @
09a38af1
...
...
@@ -40,6 +40,8 @@ def _get_params_override(vocab_sizes,
'task'
:
{
'model'
:
{
'vocab_sizes'
:
vocab_sizes
,
'embedding_dim'
:
[
8
]
*
len
(
vocab_sizes
),
'bottom_mlp'
:
[
64
,
32
,
8
],
'interaction'
:
interaction
,
},
'train_data'
:
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment