ModelZoo / ResNet50_tensorflow · Commit 78c43ef1

Authored Jul 26, 2021 by Gunho Park

    Merge branch 'master' of https://github.com/tensorflow/models

Parents: 67cfc95b, e3c7e300

Changes: 227 files in the full merge; this page shows 20 changed files with 157 additions and 319 deletions (+157 −319).
Files on this page:

- official/nlp/modeling/layers/kernel_attention.py (+36 −36)
- official/nlp/modeling/layers/kernel_attention_test.py (+1 −1)
- official/nlp/modeling/layers/text_layers.py (+2 −117)
- official/nlp/modeling/layers/text_layers_test.py (+0 −96)
- official/nlp/modeling/models/bert_classifier.py (+6 −2)
- official/nlp/modeling/models/seq2seq_transformer.py (+3 −1)
- official/nlp/modeling/models/xlnet.py (+4 −1)
- official/nlp/modeling/networks/bert_encoder.py (+6 −1)
- official/nlp/modeling/networks/bert_encoder_test.py (+2 −1)
- official/nlp/modeling/ops/sampling_module.py (+20 −25)
- official/nlp/train.py (+1 −0)
- official/recommendation/data_pipeline.py (+3 −3)
- official/recommendation/data_preprocessing.py (+4 −6)
- official/recommendation/ranking/README.md (+8 −5)
- official/recommendation/ranking/configs/__init__.py (+0 −3)
- official/recommendation/ranking/configs/config.py (+11 −4)
- official/recommendation/ranking/data/__init__.py (+14 −0)
- official/recommendation/ranking/data/data_pipeline.py (+2 −1)
- official/recommendation/ranking/data/data_pipeline_test.py (+1 −1)
- official/recommendation/ranking/task.py (+33 −15)
official/nlp/modeling/layers/kernel_attention.py (+36 −36)

```diff
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
   return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)


-def _generalized_kernel(x, projection_matrix, is_query, f, h,
-                        data_normalizer_fn=None):
+def _generalized_kernel(x, projection_matrix, f, h):
   """Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.

   Args:
     x: The feature being transformed with shape [B, T, N, H].
     projection_matrix: The matrix with shape [M, H] that we project x to,
       where M is the number of projections.
-    is_query: Whether the transform is a query or key. This transform is
-      symmetric if the argument is not used.
     f: A non-linear function applied on x or projected x.
     h: A multiplier which is a function of x, applied after x is projected
       and transformed. Only applied if projection_matrix is not None.
-    data_normalizer_fn: A function which takes x and returns a scalar that
-      normalizes the data.

   Returns:
     Transformed feature.
   """
-  # No asymmetric operations.
-  del is_query
-  if data_normalizer_fn is not None:
-    x = data_normalizer_fn(x)
   if projection_matrix is None:
     return h(x) * f(x)
@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
             x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
         h=lambda x: tf.math.exp(-0.5 * tf.math.reduce_sum(
-            tf.math.square(x), axis=-1, keepdims=True)),
-        data_normalizer_fn=lambda x: x / tf.math.sqrt(tf.math.sqrt(
-            tf.cast(tf.shape(x)[-1], tf.float32)))),
+            tf.math.square(x), axis=-1, keepdims=True)),
+    ),
     "expmod":
         functools.partial(
             _generalized_kernel,
             # Avoid exp explosion by shifting.
             f=lambda x: tf.math.exp(
                 x - tf.math.reduce_max(x, axis=[1, 2, 3], keepdims=True)),
             h=lambda x: tf.math.exp(-0.5 * tf.math.sqrt(
                 tf.cast(tf.shape(x)[-1], tf.float32))),
-            data_normalizer_fn=lambda x: x / tf.math.sqrt(tf.math.sqrt(
-                tf.cast(tf.shape(x)[-1], tf.float32)))),
-    "l2":
-        functools.partial(
-            _generalized_kernel,
-            f=lambda x: x,
-            h=lambda x: tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32)),
-            data_normalizer_fn=lambda x: x),
+        ),
     "identity":
-        lambda x, projection_matrix, is_query: x
+        functools.partial(_generalized_kernel, f=lambda x: x, h=lambda x: 1)
 }
 # pylint: enable=g-long-lambda
@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
   Rethinking Attention with Performers
   (https://arxiv.org/abs/2009.14794)
-  - exp (Lemma 1, positive), relu, l2
+  - exp (Lemma 1, positive), relu
   - random/deterministic projection

   Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
                redraw=False,
                is_short_seq=False,
                begin_kernel=0,
+               scale=None,
                **kwargs):
     r"""Constructor of KernelAttention.

     Args:
       feature_transform: A non-linear transform of the keys and queries.
-        Possible transforms are "elu", "relu", "square", "exp", "expmod",
-        "l2", "identity". If <is_short_seq> = True, it is recommended to
-        choose feature_transform as "l2".
+        Possible transforms are "elu", "relu", "square", "exp", "expmod",
+        "identity".
       num_random_features: Number of random features to be used for
         projection. If num_random_features <= 0, no projection is used before
         the transform.
       seed: The seed to begin drawing random features. Once the seed is set,
@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         (default option).
       begin_kernel: Apply kernel_attention after this sequence id and apply
         softmax attention before this.
+      scale: The value to scale the dot product as described in `Attention Is
+        All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
       **kwargs: The same arguments as the `MultiHeadAttention` layer.
     """
     if feature_transform not in _TRANSFORM_MAP:
@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     # 1. inference
     # 2. no redraw
     self._seed = seed
     super().__init__(**kwargs)
+    if scale is None:
+      self._scale = 1.0 / math.sqrt(float(self._key_dim))
+    else:
+      self._scale = scale
     self._projection_matrix = None
     if num_random_features > 0:
       self._projection_matrix = create_projection_matrix(
@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
     Returns:
       attention_output: Multi-headed outputs of attention computation.
     """
-
     projection_matrix = None
     if self._num_random_features > 0:
       if self._redraw and training:
@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       else:
         projection_matrix = self._projection_matrix
-    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix, False)
-    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix, True)
+    if is_short_seq:
+      # Note: Applying scalar multiply at the smaller end of einsum improves
+      # XLA performance, but may introduce slight numeric differences in
+      # the Transformer attention head.
+      query = query * self._scale
+    else:
+      # Note: we suspect splitting the scale to key, query yields smaller
+      # approximation variance when random projection is used.
+      # For simplicity, we also split when there's no random projection.
+      key *= math.sqrt(self._scale)
+      query *= math.sqrt(self._scale)
+    key = _TRANSFORM_MAP[feature_transform](key, projection_matrix)
+    query = _TRANSFORM_MAP[feature_transform](query, projection_matrix)
     if attention_mask is not None:
       key = tf.einsum("BSNH,BS->BSNH", key, attention_mask)
@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
       attention_scores = tf.einsum("BTNH,BSNH->BTSN", query, key)
       attention_scores = tf.nn.softmax(attention_scores, axis=2)
       attention_output = tf.einsum("BTSN,BSNH->BTNH", attention_scores, value)
       return attention_output
     else:
       kv = tf.einsum("BSNH,BSND->BNDH", key, value)
       denominator = 1.0 / (
           tf.einsum("BTNH,BNH->BTN", query, tf.reduce_sum(key, axis=1)) +
           _NUMERIC_STABLER)
-      return tf.einsum("BTNH,BNDH,BTN->BTND", query, kv, denominator)
+      attention_output = tf.einsum("BTNH,BNDH,BTN->BTND", query, kv,
+                                   denominator)
+      return attention_output

   def _build_from_signature(self, query, value, key=None):
     super()._build_from_signature(query=query, value=value, key=key)
@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
         "redraw": self._redraw,
         "is_short_seq": self._is_short_seq,
         "begin_kernel": self._begin_kernel,
+        "scale": self._scale,
     }
     base_config = super().get_config()
     return dict(list(base_config.items()) + list(config.items()))
```
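The heart of this layer is the Performer trick visible in the `kv` and `denominator` lines above: once queries and keys pass through a positive feature transform, attention can be computed without ever materializing the O(T²) score matrix, by regrouping the matrix products. A minimal NumPy sketch of that regrouping, written for this page and not part of the commit:

```python
# Demonstrates that phi(q) @ (phi(k).T @ v), with a per-row normalizer,
# equals full (unnormalized-softmax) attention when phi is applied first.
import numpy as np

def linear_attention(q, k, v, eps=1e-6):
  """q, k: [T, H] transformed features; v: [T, D]. Returns [T, D]."""
  kv = k.T @ v                      # [H, D]: O(T*H*D) instead of O(T^2)
  denom = q @ k.sum(axis=0) + eps   # [T]: per-query normalizer
  return (q @ kv) / denom[:, None]

rng = np.random.default_rng(0)
T, H, D = 8, 4, 5
# An "exp"-style feature map keeps all entries positive.
q, k = np.exp(rng.normal(size=(T, H))), np.exp(rng.normal(size=(T, H)))
v = rng.normal(size=(T, D))

scores = q @ k.T                                      # quadratic baseline
full = (scores @ v) / scores.sum(axis=1, keepdims=True)
assert np.allclose(full, linear_attention(q, k, v), atol=1e-5)
```

The commit also splits the new softmax scale as √s on each of query and key; for a plain dot product, (√s·q)·(√s·k) = s·(q·k), and the in-diff comment notes the split is kept even without random projection, for simplicity.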
official/nlp/modeling/layers/kernel_attention_test.py (+1 −1)

```diff
@@ -21,7 +21,7 @@ import tensorflow as tf
 from official.nlp.modeling.layers import kernel_attention as attention


-_FEATURE_TRANSFORM = ['relu', 'elu', 'exp', 'l2']
+_FEATURE_TRANSFORM = ['relu', 'elu', 'exp']
 _REDRAW = [True, False]
 _TRAINING = [True, False]
 _IS_SHORT_SEQ = [True, False]
```
official/nlp/modeling/layers/text_layers.py (+2 −117)

```diff
@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
         "'tensorflow-text-nightly'.")


-def _iterative_vectorized_fair_share(capacity: tf.Tensor,
-                                     limit: Union[int, tf.Tensor]):
-  """Iterative algorithm for the max-min fairness algorithm.
-
-  Reference: https://en.wikipedia.org/wiki/Max-min_fairness
-
-  The idea is that for each example, with some number of segments and a limit
-  on the total segment length allowed, we grant each segment a fair share of
-  the limit. For example, if every segment has the same length, there is no
-  work to do. If one segment has below-average length, its share is split
-  among the others fairly. In this way, the longest segment will be the
-  shortest among all potential capacity assignments.
-
-  Args:
-    capacity: A rank-2 Tensor of #Segments x Batch.
-    limit: The largest permissible number of tokens in total across one
-      example.
-
-  Returns:
-    A rank-2 Tensor with a new segment capacity assignment such that the
-    total number of tokens in each example does not exceed the `limit`.
-  """
-  # First, calculate the lower bound of the capacity assignment.
-  per_seg_limit = limit // capacity.shape[0]
-  limit_mask = tf.ones(capacity.shape, dtype=tf.int64) * per_seg_limit
-  lower_bound = tf.minimum(capacity, limit_mask)
-  # This step makes up the capacity that already satisfies the capacity limit.
-  remaining_cap_sum = limit - tf.math.reduce_sum(lower_bound, axis=0)
-  remaining_cap_mat = capacity - lower_bound
-  new_cap = lower_bound + remaining_cap_mat * tf.cast(
-      tf.math.reduce_sum(remaining_cap_mat, axis=0) <= remaining_cap_sum,
-      tf.int64)
-  # Process iteratively. This step is O(#segments), see analysis below.
-  while True:
-    remaining_limit = limit - tf.math.reduce_sum(new_cap, axis=0)
-    remaining_cap = capacity - new_cap
-    masked_remaining_slots = tf.cast(remaining_cap > 0, tf.int64)
-    remaining_cap_col_slots = tf.reduce_sum(masked_remaining_slots, axis=0)
-    masked_remaining_limit = tf.cast(remaining_cap_col_slots > 0,
-                                     tf.int64) * remaining_limit
-    # The total remaining segment limit is different for each example.
-    per_seg_limit = masked_remaining_limit // (
-        tf.cast(remaining_cap_col_slots <= 0, tf.int64) +
-        remaining_cap_col_slots)  # +1 to make sure 0/0 = 0.
-    # Note that at each step, at least one more segment is fulfilled, or the
-    # loop finishes. The idea is: if the remaining per-example limit exceeds
-    # the smallest ask among segments, the smallest segment's ask is
-    # fulfilled. Otherwise, all remaining segments are truncated and the
-    # assignment is finished.
-    if tf.math.reduce_sum(per_seg_limit) > 0:
-      remaining_slots_mat = tf.cast(remaining_cap > 0, tf.int64)
-      new_cap = new_cap + remaining_slots_mat * per_seg_limit
-    else:
-      # Leftover assignment of a limit that is smaller than #slots.
-      new_remained_assignment_mask = tf.cast(
-          (tf.cumsum(masked_remaining_slots, axis=0) <=
-           masked_remaining_limit) & (masked_remaining_slots > 0), tf.int64)
-      new_cap = new_cap + new_remained_assignment_mask
-      break
-  return new_cap
-
-
-def round_robin_truncate_inputs(
-    inputs: Union[tf.RaggedTensor, List[tf.RaggedTensor]],
-    limit: Union[int, tf.Tensor],
-) -> Union[tf.RaggedTensor, List[tf.RaggedTensor]]:
-  """Truncates a list of batched segments to fit a per-example length limit.
-
-  Available space is assigned one token at a time in a round-robin fashion
-  to the inputs that still need some, until the limit is reached.
-  (Or equivalently: the longest input is truncated by one token until the
-  total length of inputs fits the limit.) Examples that fit the limit as
-  passed in remain unchanged.
-
-  Args:
-    inputs: A list of rank-2 RaggedTensors. The i-th example is given by
-      the i-th row in each list element, that is, `inputs[:][i, :]`.
-    limit: The largest permissible number of tokens in total across one
-      example.
-
-  Returns:
-    A list of rank-2 RaggedTensors at indices corresponding to the inputs,
-    in which the rows of each RaggedTensor have been truncated such that
-    the total number of tokens in each example does not exceed the `limit`.
-  """
-  if not isinstance(inputs, (list, tuple)):
-    return round_robin_truncate_inputs([inputs], limit)[0]
-  limit = tf.cast(limit, tf.int64)
-  if not all(rt.shape.rank == 2 for rt in inputs):
-    raise ValueError("All inputs must have shape [batch_size, (items)]")
-  if len(inputs) == 1:
-    return [_truncate_row_lengths(inputs[0], limit)]
-  elif len(inputs) == 2:
-    size_a, size_b = [rt.row_lengths() for rt in inputs]
-    # Here's a brain-twister: This does round-robin assignment of quota
-    # to both inputs until the limit is reached. Hint: consider separately
-    # the cases of zero, one, or two inputs exceeding half the limit.
-    floor_half = limit // 2
-    ceil_half = limit - floor_half
-    quota_a = tf.minimum(size_a, ceil_half + tf.nn.relu(floor_half - size_b))
-    quota_b = tf.minimum(size_b, floor_half + tf.nn.relu(ceil_half - size_a))
-    return [_truncate_row_lengths(inputs[0], quota_a),
-            _truncate_row_lengths(inputs[1], quota_b)]
-  else:
-    # Note that we don't merge this with the 2-input case because the full
-    # algorithm is more expensive.
-    capacity = tf.stack([rt.row_lengths() for rt in inputs])  # #Segments x B
-    new_capacity = _iterative_vectorized_fair_share(capacity, limit)
-    return [_truncate_row_lengths(inputs[i], new_capacity[i])
-            for i in range(capacity.shape[0])]
-
-
 def _truncate_row_lengths(ragged_tensor: tf.RaggedTensor,
                           new_lengths: tf.Tensor) -> tf.RaggedTensor:
   """Truncates the rows of `ragged_tensor` to the given row lengths."""
@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
     # fall back to some ad-hoc truncation.
     num_special_tokens = len(inputs) + 1
     if truncator == "round_robin":
-      trimmed_segments = round_robin_truncate_inputs(
-          inputs, seq_length - num_special_tokens)
+      trimmed_segments = text.RoundRobinTrimmer(
+          seq_length - num_special_tokens).trim(inputs)
     elif truncator == "waterfall":
       trimmed_segments = text.WaterfallTrimmer(
           seq_length - num_special_tokens).trim(inputs)
```
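The deleted helper is replaced by tensorflow_text's `RoundRobinTrimmer`, which budgets the per-example token limit across segments one token at a time, the same semantics as the removed code. A hedged usage sketch (requires the tensorflow-text package; the output values are what round-robin allocation should produce, not copied from a test run):

```python
import tensorflow as tf
import tensorflow_text as text

seg_a = tf.ragged.constant([[1, 2, 3, 4], [10, 11]])
seg_b = tf.ragged.constant([[5, 6, 7], [12, 13, 14, 15]])

# At most 5 tokens total per example, shared across both segments.
trimmer = text.RoundRobinTrimmer(5)
trimmed_a, trimmed_b = trimmer.trim([seg_a, seg_b])
print(trimmed_a)  # expected: [[1, 2, 3], [10, 11]]
print(trimmed_b)  # expected: [[5, 6], [12, 13, 14]]
```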
official/nlp/modeling/layers/text_layers_test.py (+0 −96)

```diff
@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
 from official.nlp.modeling.layers import text_layers


-class RoundRobinTruncatorTest(tf.test.TestCase):
-
-  def _test_input(self, start, lengths):
-    return tf.ragged.constant(
-        [[start + 10 * j + i for i in range(length)]
-         for j, length in enumerate(lengths)],
-        dtype=tf.int32)
-
-  def test_single_segment(self):
-    # Single segment.
-    single_input = self._test_input(11, [4, 5, 6])
-    expected_single_output = tf.ragged.constant(
-        [[11, 12, 13, 14],
-         [21, 22, 23, 24, 25],
-         [31, 32, 33, 34, 35],  # Truncated.
-        ])
-    self.assertAllEqual(
-        expected_single_output,
-        text_layers.round_robin_truncate_inputs(single_input, limit=5))
-    # Test wrapping in a singleton list.
-    actual_single_list_output = text_layers.round_robin_truncate_inputs(
-        [single_input], limit=5)
-    self.assertIsInstance(actual_single_list_output, list)
-    self.assertAllEqual(expected_single_output, actual_single_list_output[0])
-
-  def test_two_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5])
-    expected_a = tf.ragged.constant(
-        [[111],
-         [121, 122],
-         [131, 132],
-         [141, 142, 143],
-         [151, 152, 153],  # Truncated.
-         [161, 162, 163],  # Truncated.
-        ])
-    expected_b = tf.ragged.constant(
-        [[211],
-         [221, 222, 223],
-         [231, 232, 233],  # Truncated.
-         [241, 242],
-         [251, 252],
-         [261, 262],  # Truncated.
-        ])
-    actual_a, actual_b = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b], limit=5)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-
-  def test_three_segments(self):
-    input_a = self._test_input(111, [1, 2, 2, 3, 4, 5, 1])
-    input_b = self._test_input(211, [1, 3, 4, 2, 2, 5, 8])
-    input_c = self._test_input(311, [1, 3, 4, 2, 2, 5, 10])
-    seg_limit = 8
-    expected_a = tf.ragged.constant([
-        [111],
-        [121, 122],
-        [131, 132],
-        [141, 142, 143],
-        [151, 152, 153, 154],
-        [161, 162, 163],  # Truncated.
-        [171],
-    ])
-    expected_b = tf.ragged.constant([
-        [211],
-        [221, 222, 223],
-        [231, 232, 233],  # Truncated.
-        [241, 242],
-        [251, 252],
-        [261, 262, 263],  # Truncated.
-        [271, 272, 273, 274],  # Truncated.
-    ])
-    expected_c = tf.ragged.constant([
-        [311],
-        [321, 322, 323],
-        [331, 332, 333],  # Truncated.
-        [341, 342],
-        [351, 352],
-        [361, 362],  # Truncated.
-        [371, 372, 373],  # Truncated.
-    ])
-    actual_a, actual_b, actual_c = text_layers.round_robin_truncate_inputs(
-        [input_a, input_b, input_c], limit=seg_limit)
-    self.assertAllEqual(expected_a, actual_a)
-    self.assertAllEqual(expected_b, actual_b)
-    self.assertAllEqual(expected_c, actual_c)
-    input_cap = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [input_a, input_b, input_c]]),
-        axis=0)
-    per_example_usage = tf.math.reduce_sum(
-        tf.stack([rt.row_lengths() for rt in [actual_a, actual_b, actual_c]]),
-        axis=0)
-    self.assertTrue(
-        all(per_example_usage <= tf.minimum(seg_limit, input_cap)))
-
-
 # This test covers the in-process behavior of a BertTokenizer layer.
 # For saving, restoring, and the restored behavior (incl. shape inference),
 # see nlp/tools/export_tfhub_lib_test.py.
```
official/nlp/modeling/models/bert_classifier.py (+6 −2)

```diff
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
     dropout_rate: The dropout probability of the cls head.
     use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
       encoder.
+    head_name: Name of the classification head.
     cls_head: (Optional) The layer instance to use for the classifier head.
       It should take in the output from network and produce the final logits.
       If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
-      'use_encoder_pooler') will be ignored.
+      'use_encoder_pooler', 'head_name') will be ignored.
   """

   def __init__(self,
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                dropout_rate=0.1,
                use_encoder_pooler=True,
+               head_name='sentence_prediction',
                cls_head=None,
                **kwargs):
     self.num_classes = num_classes
+    self.head_name = head_name
     self.initializer = initializer
     self.use_encoder_pooler = use_encoder_pooler
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
           num_classes=num_classes,
           initializer=initializer,
           dropout_rate=dropout_rate,
-          name='sentence_prediction')
+          name=head_name)
       predictions = classifier(cls_inputs)
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
     return {
         'network': self._network,
         'num_classes': self.num_classes,
+        'head_name': self.head_name,
         'initializer': self.initializer,
         'use_encoder_pooler': self.use_encoder_pooler,
         'cls_head': self._cls_head,
```
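A hedged usage sketch of the new argument: `head_name` lets callers name the classification head instead of the previously hard-coded 'sentence_prediction'. The encoder construction below is illustrative only; the keyword arguments for `BertEncoder` are assumptions beyond what this diff shows.

```python
from official.nlp.modeling import models, networks

encoder = networks.BertEncoder(vocab_size=30522, num_layers=2)
classifier = models.BertClassifier(
    network=encoder,
    num_classes=3,
    head_name='nli_head')  # previously always 'sentence_prediction'

# The name round-trips through get_config(), per the change above.
assert classifier.get_config()['head_name'] == 'nli_head'
```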
official/nlp/modeling/models/seq2seq_transformer.py (+3 −1)

```diff
@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
   def _embedding_linear(self, embedding_matrix, x):
     """Uses embeddings as linear transformation weights."""
+    embedding_matrix = tf.cast(embedding_matrix, dtype=self.compute_dtype)
+    x = tf.cast(x, dtype=self.compute_dtype)
     batch_size = tf.shape(x)[0]
     length = tf.shape(x)[1]
     hidden_size = tf.shape(x)[2]
     vocab_size = tf.shape(embedding_matrix)[0]
     x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
+    logits = tf.matmul(x, embedding_matrix, transpose_b=True)

     return tf.reshape(logits, [batch_size, length, vocab_size])
```
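`_embedding_linear` ties the output projection to the token-embedding matrix; the change casts both operands to the layer's compute dtype up front instead of casting inside the matmul. A small NumPy sketch of the computation, written for this page rather than taken from the file:

```python
import numpy as np

B, T, H, V = 2, 3, 4, 10
E = np.random.randn(V, H).astype(np.float32)     # shared embedding matrix
x = np.random.randn(B, T, H).astype(np.float16)  # decoder output, half precision

# The added casts bring both sides to one dtype before the matmul, so
# mixed-precision activations cannot collide with float32 embeddings.
x32 = x.astype(np.float32)
logits = x32.reshape(-1, H) @ E.T                # [B*T, V], weight tying
logits = logits.reshape(B, T, V)
assert logits.shape == (2, 3, 10)
```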
official/nlp/modeling/models/xlnet.py (+4 −1)

```diff
@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
       Defaults to a RandomNormal initializer.
     summary_type: Method used to summarize a sequence into a compact vector.
     dropout_rate: The dropout probability of the cls head.
+    head_name: Name of the classification head.
   """

   def __init__(
@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
       initializer: tf.keras.initializers.Initializer = 'random_normal',
       summary_type: str = 'last',
       dropout_rate: float = 0.1,
+      head_name: str = 'sentence_prediction',
       **kwargs):
     super().__init__(**kwargs)
     self._network = network
@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
         'num_classes': num_classes,
         'summary_type': summary_type,
         'dropout_rate': dropout_rate,
+        'head_name': head_name,
     }

     if summary_type == 'last':
@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
           initializer=initializer,
           dropout_rate=dropout_rate,
           cls_token_idx=cls_token_idx,
-          name='sentence_prediction')
+          name=head_name)

   def call(self, inputs: Mapping[str, Any]):
     input_ids = inputs['input_word_ids']
```
official/nlp/modeling/networks/bert_encoder.py (+6 −1)

```diff
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       parameter is originally added for the ELECTRA model which needs to tie
       the generator embeddings with the discriminator embeddings.
     dict_outputs: Whether to use a dictionary as the model outputs.
+    norm_first: Whether to normalize inputs to attention and intermediate
+      dense layers. If set False, output of attention and intermediate dense
+      layers is normalized.
   """

   def __init__(
       self,
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
       embedding_width=None,
       embedding_layer=None,
       dict_outputs=False,
+      norm_first=False,
       **kwargs):
     # b/164516224
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
         initializer=initializer,
         output_range=output_range,
         embedding_width=embedding_width,
-        embedding_layer=embedding_layer)
+        embedding_layer=embedding_layer,
+        norm_first=norm_first)
     self._embedding_layer_instance = embedding_layer
```
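`norm_first` selects pre-layer-norm ordering: layer normalization is applied to each sublayer's input rather than to the residual sum. A schematic sketch of the two orderings (not code from this repository; `attn`, `ffn`, `ln1`, and `ln2` stand for the attention, feed-forward, and LayerNorm sublayers):

```python
def transformer_block(x, attn, ffn, ln1, ln2, norm_first):
  """Schematic pre-LN vs. post-LN residual block."""
  if norm_first:
    # Pre-LN: normalize the sublayer *input*; the residual path stays clean.
    x = x + attn(ln1(x))
    x = x + ffn(ln2(x))
  else:
    # Post-LN (the default here): normalize the residual *sum*.
    x = ln1(x + attn(x))
    x = ln2(x + ffn(x))
  return x
```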
official/nlp/modeling/networks/bert_encoder_test.py (+2 −1)

```diff
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
         output_range=-1,
         embedding_width=16,
         dict_outputs=True,
-        embedding_layer=None)
+        embedding_layer=None,
+        norm_first=False)
     network = bert_encoder.BertEncoder(**kwargs)
     expected_config = dict(kwargs)
     expected_config["activation"] = tf.keras.activations.serialize(
```
official/nlp/modeling/ops/sampling_module.py (+20 −25)

```diff
@@ -15,7 +15,7 @@
 """Sampling module for top_k, top_p and greedy decoding."""

 import abc
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Optional
 import numpy as np
 import tensorflow as tf
@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
   ], -1)

   # Scatter sorted indices to original indexes.
-  indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove,
-                                                      sorted_indices)
-  top_p_logits = set_tensor_by_indices_to_value(logits, indices_to_remove,
-                                                np.NINF)
+  indices_to_remove = scatter_values_on_batch_indices(
+      sorted_indices_to_remove, sorted_indices)
+  top_p_logits = set_tensor_by_indices_to_value(
+      logits, indices_to_remove, np.NINF)
   return top_p_logits
@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
   tensor_shape = decoding_module.shape_list(batch_indices)
   broad_casted_batch_dims = tf.reshape(
-      tf.broadcast_to(tf.expand_dims(tf.range(tensor_shape[0]), axis=-1),
-                      tensor_shape),
-      [1, -1])
+      tf.broadcast_to(
+          tf.expand_dims(tf.range(tensor_shape[0]), axis=-1), tensor_shape),
+      [1, -1])
   pair_indices = tf.transpose(
       tf.concat([broad_casted_batch_dims,
                  tf.reshape(batch_indices, [1, -1])], 0))
-  return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), tensor_shape)
+  return tf.scatter_nd(
+      pair_indices, tf.reshape(values, [-1]), tensor_shape)


 def set_tensor_by_indices_to_value(input_tensor, indices, value):
@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
     input_tensor: float (batch_size, dim)
     indices: bool (batch_size, dim)
     value: float scalar
+
   Returns:
     output_tensor: same shape as input_tensor.
   """
@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):

   def __init__(self,
                symbols_to_logits_fn,
-               length_normalization_fn: Callable[[int, tf.DType], float],
                vocab_size: int,
                max_decode_length: int,
                eos_id: int,
                padded_decode: bool,
+               length_normalization_fn: Optional[
+                   Callable[[int, tf.DType], float]] = None,
                top_k=0,
                top_p=1.0,
                sample_temperature=0.0,
@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
     self.max_decode_length = max_decode_length
     self.top_k = tf.convert_to_tensor(top_k, dtype=tf.int32)
     self.top_p = tf.convert_to_tensor(top_p, dtype=tf.float32)
-    self.sample_temperature = tf.convert_to_tensor(sample_temperature,
-                                                   dtype=tf.float32)
+    self.sample_temperature = tf.convert_to_tensor(
+        sample_temperature, dtype=tf.float32)
     self.enable_greedy = enable_greedy
     super(SamplingModule, self).__init__(
         length_normalization_fn=length_normalization_fn, dtype=dtype)
@@ -330,12 +331,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
     return state, state_shape_invariants

-  def _get_new_alive_state(self,
-                           new_seq: tf.Tensor,
-                           new_log_probs: tf.Tensor,
-                           new_finished_flags: tf.Tensor,
-                           new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
+  def _get_new_alive_state(self, new_seq: tf.Tensor, new_log_probs: tf.Tensor,
+                           new_finished_flags: tf.Tensor,
+                           new_cache: Dict[str, tf.Tensor]) -> Dict[str, Any]:
     """Gather the sequences that are still alive.

     This function resets the sequences in the alive_state that are finished.
@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
         decoding_module.StateKeys.ALIVE_CACHE: new_cache
     }

-  def _get_new_finished_state(self,
-                              state: Dict[str, Any],
-                              new_seq: tf.Tensor,
+  def _get_new_finished_state(self, state: Dict[str, Any], new_seq: tf.Tensor,
                               new_log_probs: tf.Tensor,
                               new_finished_flags: tf.Tensor,
                               batch_size: int) -> Dict[str, tf.Tensor]:
@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
       length_norm = self.length_normalization_fn(self.max_decode_length + 1,
                                                  self.dtype)
       alive_log_probs = alive_log_probs / length_norm

-    seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
-    score_cond = decoding_module.expand_to_same_rank(finished_cond,
-                                                     finished_scores)
+    seq_cond = decoding_module.expand_to_same_rank(finished_cond, finished_seq)
+    score_cond = decoding_module.expand_to_same_rank(
+        finished_cond, finished_scores)
     finished_seq = tf.where(seq_cond, finished_seq, alive_seq)
     finished_scores = tf.where(score_cond, finished_scores, alive_log_probs)
     return finished_seq, finished_scores
```
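`sample_top_p`, together with `scatter_values_on_batch_indices` and `set_tensor_by_indices_to_value`, implements nucleus (top-p) filtering: sort the logits, keep the smallest prefix whose cumulative probability covers `top_p`, and push everything else to negative infinity before sampling. A self-contained NumPy sketch of the same idea, single example and no batching, written for this page:

```python
import numpy as np

def top_p_filter(logits, top_p):
  order = np.argsort(logits)[::-1]                 # descending by logit
  probs = np.exp(logits[order]) / np.exp(logits).sum()
  # Drop a token once the probability mass *before* it already exceeds top_p.
  exclusive_cum = np.cumsum(probs) - probs         # always keeps the top token
  sorted_remove = exclusive_cum > top_p
  remove = np.zeros_like(sorted_remove)
  remove[order] = sorted_remove                    # scatter back to vocab order
  out = logits.copy()
  out[remove] = -np.inf
  return out

logits = np.log(np.array([0.5, 0.3, 0.1, 0.1]))
# With top_p=0.75, the 0.5 and 0.3 tokens survive; both 0.1 tails are masked.
print(top_p_filter(logits, 0.75))
```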
official/nlp/train.py (+1 −0)

```diff
@@ -66,4 +66,5 @@ def main(_):
 if __name__ == '__main__':
   tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
   app.run(main)
```
official/recommendation/data_pipeline.py (+3 −3)

```diff
@@ -29,17 +29,16 @@ import timeit
 import traceback
 import typing

+from absl import logging
 import numpy as np
 import six
 from six.moves import queue
 import tensorflow as tf
-
-from absl import logging
+from tensorflow.python.tpu.datasets import StreamingFilesDataset

 from official.recommendation import constants as rconst
 from official.recommendation import movielens
 from official.recommendation import popen_helper
 from official.recommendation import stat_utils
-from tensorflow.python.tpu.datasets import StreamingFilesDataset


 SUMMARY_TEMPLATE = """General:
 {spacer}Num users: {num_users}
@@ -119,6 +118,7 @@ class DatasetManager(object):
     """Convert NumPy arrays into a TFRecords entry."""

     def create_int_feature(values):
+      values = np.squeeze(values)
      return tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(values)))

     feature_dict = {
```
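The added `np.squeeze` guards against `values` arriving with a trailing unit dimension, e.g. shape [N, 1] sliced from a 2-D batch: `list()` over such an array yields length-1 arrays rather than plain ints, which the proto constructor may reject or warn about depending on the numpy/protobuf versions in use. A short sketch of this reading of the change (assumption-level, not from the commit):

```python
import numpy as np
import tensorflow as tf

values = np.array([[1], [2], [3]])  # shape [3, 1], a column from a 2-D batch
print(list(values))                  # [array([1]), array([2]), array([3])]

flat = np.squeeze(values)            # shape [3]: plain scalars
feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=list(flat)))
print(feature)                       # int64_list { value: 1 value: 2 value: 3 }
```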
official/recommendation/data_preprocessing.py (+4 −6)

```diff
@@ -23,21 +23,19 @@ import os
 import pickle
 import time
 import timeit
-
-# pylint: disable=wrong-import-order
+import typing
+from typing import Dict, Text, Tuple
+
 from absl import logging
 import numpy as np
 import pandas as pd
 import tensorflow as tf
-import typing
-from typing import Dict, Text, Tuple
-# pylint: enable=wrong-import-order

 from official.recommendation import constants as rconst
 from official.recommendation import data_pipeline
 from official.recommendation import movielens

 _EXPECTED_CACHE_KEYS = (rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY,
                         rconst.EVAL_USER_KEY, rconst.EVAL_ITEM_KEY,
                         rconst.USER_MAP, rconst.ITEM_MAP)
@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
     logging.info("Writing raw data cache.")
     with tf.io.gfile.GFile(cache_path, "wb") as f:
-      pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+      pickle.dump(data, f, protocol=4)

   # TODO(robieta): MLPerf cache clear.
   return data, valid_cache
```
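Pinning `protocol=4` (available since Python 3.4) instead of `HIGHEST_PROTOCOL` fixes the on-disk cache format: an interpreter where `HIGHEST_PROTOCOL` is 5 (Python 3.8+) would otherwise write caches that older readers cannot load, while protocol 4 still supports objects over 4 GiB. A two-line illustration:

```python
import pickle

data = {"users": list(range(10))}
blob = pickle.dumps(data, protocol=4)  # stable format across Python >= 3.4
assert pickle.loads(blob) == data
```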
official/recommendation/ranking/README.md (+8 −5)

````diff
@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
 export EXPERIMENT_NAME=my_experiment_name
 export BUCKET_NAME="gs://my_dlrm_bucket"
 export DATA_DIR="${BUCKET_NAME}/data"
+export EMBEDDING_DIM=32

 python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
 --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
@@ -126,8 +127,8 @@ task:
     global_batch_size: 16384
   model:
     num_dense_features: 13
-    bottom_mlp: [512,256,128]
-    embedding_dim: 128
+    bottom_mlp: [512,256,${EMBEDDING_DIM}]
+    embedding_dim: ${EMBEDDING_DIM}
     top_mlp: [1024,1024,512,256,1]
     interaction: 'dot'
     vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
@@ -135,8 +136,8 @@ task:
         39979771, 25641295, 39664984, 585935, 12972, 108, 36]
   trainer:
     use_orbit: true
-    validation_interval: 90000
-    checkpoint_interval: 100000
+    validation_interval: 85352
+    checkpoint_interval: 85352
     validation_steps: 5440
     train_steps: 256054
     steps_per_loop: 1000
@@ -154,7 +155,9 @@ Training on GPUs are similar to TPU training. Only distribution strategy needs
 to be updated and number of GPUs provided (for 4 GPUs):

 ```shell
-python3 official/recommendation/ranking/main.py --mode=train_and_eval \
+export EMBEDDING_DIM=8
+
+python3 official/recommendation/ranking/train.py --mode=train_and_eval \
 --model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
 runtime:
   distribution_strategy: 'mirrored'
````
official/utils/misc/distribution_utils.py → official/recommendation/ranking/configs/__init__.py (+0 −3)

```diff
@@ -12,6 +12,3 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Helper functions for running models in a distributed setting."""
-# pylint: disable=wildcard-import
-from official.common.distribute_utils import *
```
official/recommendation/ranking/configs/config.py (+11 −4)

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Ranking Model configuration definition."""
-from typing import Optional, List
+from typing import Optional, List, Union
 import dataclasses

 from official.core import exp_factory
@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
     num_dense_features: Number of dense features.
     vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
       with the order of the input data.
-    embedding_dim: Embedding dimension.
+    embedding_dim: An integer or a list of embedding table dimensions.
+      If it's an integer then all tables will have the same embedding
+      dimension. If it's a list then its length should match `vocab_sizes`.
+    size_threshold: A threshold for table sizes below which a Keras
+      embedding layer is used, and above which a TPU embedding layer is used.
+      If it's -1 then only the Keras embedding layer is used for all tables;
+      if 0 then only the TPU embedding layer is used.
     bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
       features.
     top_mlp: The sizes of hidden layers for top MLP.
@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
   """
   num_dense_features: int = 13
   vocab_sizes: List[int] = dataclasses.field(default_factory=list)
-  embedding_dim: int = 8
+  embedding_dim: Union[int, List[int]] = 8
+  size_threshold: int = 50_000
   bottom_mlp: List[int] = dataclasses.field(default_factory=list)
   top_mlp: List[int] = dataclasses.field(default_factory=list)
   interaction: str = 'dot'
@@ -188,7 +195,7 @@ def default_config() -> Config:
       runtime=cfg.RuntimeConfig(),
       task=Task(
           model=ModelConfig(
-              embedding_dim=4,
+              embedding_dim=8,
              vocab_sizes=vocab_sizes,
              bottom_mlp=[64, 32, 4],
              top_mlp=[64, 32, 1]),
```
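A hedged usage sketch of the widened field: `embedding_dim` now accepts a per-table list, which must line up with `vocab_sizes` one-to-one. The field names come from the diff; the surrounding values are illustrative.

```python
from official.recommendation.ranking.configs import config

# One embedding dimension per vocab table.
model_cfg = config.ModelConfig(
    vocab_sizes=[1_000_000, 5_000, 200],
    embedding_dim=[64, 16, 8],
    bottom_mlp=[512, 256, 64],
    top_mlp=[1024, 1024, 512, 256, 1])

# An int still broadcasts the same dimension to every table.
uniform_cfg = config.ModelConfig(
    vocab_sizes=[1_000_000, 5_000, 200],
    embedding_dim=32,
    bottom_mlp=[512, 256, 32],
    top_mlp=[64, 32, 1])
```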
official/recommendation/ranking/data/__init__.py (new file, +14 −0)

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
official/recommendation/ranking/data_pipeline.py → official/recommendation/ranking/data/data_pipeline.py (+2 −1)

```diff
@@ -136,7 +136,7 @@ class CriteoTsvReader:
     num_replicas = ctx.num_replicas_in_sync if ctx else 1
     if params.is_training:
-      dataset_size = 10000 * batch_size * num_replicas
+      dataset_size = 1000 * batch_size * num_replicas
     else:
       dataset_size = 1000 * batch_size * num_replicas
     dense_tensor = tf.random.uniform(
@@ -169,6 +169,7 @@ class CriteoTsvReader:
         'sparse_features': sparse_tensor_elements,
     }, label_tensor

     dataset = tf.data.Dataset.from_tensor_slices(input_elem)
+    dataset = dataset.cache()
     if params.is_training:
       dataset = dataset.repeat()
```
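A sketch of the pattern the added `dataset.cache()` follows: caching before `repeat()` materializes one pass over the (synthetic) data once, so later epochs replay it instead of regenerating random tensors every step.

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.random.uniform([8, 3]))
ds = ds.cache()   # first pass fills the cache
ds = ds.repeat()  # subsequent epochs read from the cache
ds = ds.batch(4)
for batch in ds.take(4):
  pass  # batches from epoch 2 onward are identical to epoch 1
```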
official/recommendation/ranking/data_pipeline_test.py → official/recommendation/ranking/data/data_pipeline_test.py (+1 −1)

```diff
@@ -17,8 +17,8 @@
 from absl.testing import parameterized
 import tensorflow as tf

-from official.recommendation.ranking import data_pipeline
 from official.recommendation.ranking.configs import config
+from official.recommendation.ranking.data import data_pipeline


 class DataPipelineTest(parameterized.TestCase, tf.test.TestCase):
```
official/recommendation/ranking/task.py (+33 −15)

```diff
@@ -15,7 +15,7 @@
 """Task for the Ranking model."""

 import math
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union

 import tensorflow as tf
 import tensorflow_recommenders as tfrs
@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
 from official.core import base_task
 from official.core import config_definitions
 from official.recommendation.ranking import common
-from official.recommendation.ranking import data_pipeline
 from official.recommendation.ranking.configs import config
+from official.recommendation.ranking.data import data_pipeline

 RuntimeConfig = config_definitions.RuntimeConfig


 def _get_tpu_embedding_feature_config(
     vocab_sizes: List[int],
-    embedding_dim: int,
+    embedding_dim: Union[int, List[int]],
     table_name_prefix: str = 'embedding_table'
 ) -> Dict[str, tf.tpu.experimental.embedding.FeatureConfig]:
   """Returns TPU embedding feature config.

+  The i'th table config will have a vocab size of vocab_sizes[i] and an
+  embedding dimension of embedding_dim (if embedding_dim is an int) or
+  embedding_dim[i] (if embedding_dim is a list).
   Args:
     vocab_sizes: List of sizes of categories/id's in the table.
-    embedding_dim: Embedding dimension.
+    embedding_dim: An integer or a list of embedding table dimensions.
     table_name_prefix: a prefix for embedding tables.

   Returns:
     A dictionary of feature_name, FeatureConfig pairs.
   """
+  if isinstance(embedding_dim, List):
+    if len(vocab_sizes) != len(embedding_dim):
+      raise ValueError(
+          f'length of vocab_sizes: {len(vocab_sizes)} is not equal to the '
+          f'length of embedding_dim: {len(embedding_dim)}')
+  elif isinstance(embedding_dim, int):
+    embedding_dim = [embedding_dim] * len(vocab_sizes)
+  else:
+    raise ValueError('embedding_dim is not either a list or an int, got '
+                     f'{type(embedding_dim)}')
+
   feature_config = {}

   for i, vocab_size in enumerate(vocab_sizes):
     table_config = tf.tpu.experimental.embedding.TableConfig(
         vocabulary_size=vocab_size,
-        dim=embedding_dim,
+        dim=embedding_dim[i],
         combiner='mean',
         initializer=tf.initializers.TruncatedNormal(
-            mean=0.0, stddev=1 / math.sqrt(embedding_dim)),
+            mean=0.0, stddev=1 / math.sqrt(embedding_dim[i])),
         name=table_name_prefix + '_%s' % i)
     feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
         table=table_config)
@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
     """Task initialization.

     Args:
-      params: the RannkingModel task configuration instance.
+      params: the RankingModel task configuration instance.
       optimizer_config: Optimizer configuration instance.
       logging_dir: a string pointing to where the model, summaries etc. will
         be saved.
@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
           self.optimizer_config.embedding_optimizer)
       embedding_optimizer.learning_rate = lr_callable

-    emb_feature_config = _get_tpu_embedding_feature_config(
-        vocab_sizes=self.task_config.model.vocab_sizes,
-        embedding_dim=self.task_config.model.embedding_dim)
-
-    tpu_embedding = tfrs.layers.embedding.TPUEmbedding(emb_feature_config,
-                                                       embedding_optimizer)
+    feature_config = _get_tpu_embedding_feature_config(
+        embedding_dim=self.task_config.model.embedding_dim,
+        vocab_sizes=self.task_config.model.vocab_sizes)
+
+    embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
+        feature_config=feature_config,
+        optimizer=embedding_optimizer,
+        size_threshold=self.task_config.model.size_threshold)

     if self.task_config.model.interaction == 'dot':
-      feature_interaction = tfrs.layers.feature_interaction.DotInteraction()
+      feature_interaction = tfrs.layers.feature_interaction.DotInteraction(
+          skip_gather=True)
     elif self.task_config.model.interaction == 'cross':
       feature_interaction = tf.keras.Sequential([
           tf.keras.layers.Concatenate(),
@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
           f'is not supported it must be either \'dot\' or \'cross\'.')

     model = tfrs.experimental.models.Ranking(
-        embedding_layer=tpu_embedding,
+        embedding_layer=embedding_layer,
         bottom_stack=tfrs.layers.blocks.MLP(
             units=self.task_config.model.bottom_mlp, final_activation='relu'),
         feature_interaction=feature_interaction,
@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):

   @property
   def optimizer_config(self) -> config.OptimizationConfig:
     return self._optimizer_config
```
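A hedged sketch of the layer swap above: `PartialTPUEmbedding` (from tensorflow_recommenders' experimental API, as used in the diff) routes tables below `size_threshold` to dense Keras embeddings and the rest to the TPU embedding core. The values below are illustrative, the optimizer choice is an assumption, and some versions may require a TPU strategy at construction time.

```python
import math
import tensorflow as tf
import tensorflow_recommenders as tfrs

vocab_sizes = [100_000, 10_000, 500]   # one table per sparse feature
embedding_dims = [32, 16, 8]

feature_config = {}
for i, (vocab, dim) in enumerate(zip(vocab_sizes, embedding_dims)):
  table = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=vocab, dim=dim, combiner='mean',
      initializer=tf.initializers.TruncatedNormal(
          mean=0.0, stddev=1 / math.sqrt(dim)),
      name=f'embedding_table_{i}')
  feature_config[str(i)] = tf.tpu.experimental.embedding.FeatureConfig(
      table=table)

embedding_layer = tfrs.experimental.layers.embedding.PartialTPUEmbedding(
    feature_config=feature_config,
    optimizer=tf.tpu.experimental.embedding.Adagrad(learning_rate=0.05),
    size_threshold=50_000)  # only the 100k table goes to the TPU embedding
```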