Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
78c43ef1
"vscode:/vscode.git/clone" did not exist on "abf75ac039420f7a4ab64a419416dd493b906742"
Commit
78c43ef1
authored
Jul 26, 2021
by
Gunho Park
Browse files
Merge branch 'master' of
https://github.com/tensorflow/models
parents
67cfc95b
e3c7e300
Changes
227
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
157 additions
and
319 deletions
+157
-319
official/nlp/modeling/layers/kernel_attention.py
official/nlp/modeling/layers/kernel_attention.py
+36
-36
official/nlp/modeling/layers/kernel_attention_test.py
official/nlp/modeling/layers/kernel_attention_test.py
+1
-1
official/nlp/modeling/layers/text_layers.py
official/nlp/modeling/layers/text_layers.py
+2
-117
official/nlp/modeling/layers/text_layers_test.py
official/nlp/modeling/layers/text_layers_test.py
+0
-96
official/nlp/modeling/models/bert_classifier.py
official/nlp/modeling/models/bert_classifier.py
+6
-2
official/nlp/modeling/models/seq2seq_transformer.py
official/nlp/modeling/models/seq2seq_transformer.py
+3
-1
official/nlp/modeling/models/xlnet.py
official/nlp/modeling/models/xlnet.py
+4
-1
official/nlp/modeling/networks/bert_encoder.py
official/nlp/modeling/networks/bert_encoder.py
+6
-1
official/nlp/modeling/networks/bert_encoder_test.py
official/nlp/modeling/networks/bert_encoder_test.py
+2
-1
official/nlp/modeling/ops/sampling_module.py
official/nlp/modeling/ops/sampling_module.py
+20
-25
official/nlp/train.py
official/nlp/train.py
+1
-0
official/recommendation/data_pipeline.py
official/recommendation/data_pipeline.py
+3
-3
official/recommendation/data_preprocessing.py
official/recommendation/data_preprocessing.py
+4
-6
official/recommendation/ranking/README.md
official/recommendation/ranking/README.md
+8
-5
official/recommendation/ranking/configs/__init__.py
official/recommendation/ranking/configs/__init__.py
+0
-3
official/recommendation/ranking/configs/config.py
official/recommendation/ranking/configs/config.py
+11
-4
official/recommendation/ranking/data/__init__.py
official/recommendation/ranking/data/__init__.py
+14
-0
official/recommendation/ranking/data/data_pipeline.py
official/recommendation/ranking/data/data_pipeline.py
+2
-1
official/recommendation/ranking/data/data_pipeline_test.py
official/recommendation/ranking/data/data_pipeline_test.py
+1
-1
official/recommendation/ranking/task.py
official/recommendation/ranking/task.py
+33
-15
No files found.
official/nlp/modeling/layers/kernel_attention.py
View file @
78c43ef1
...
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
...
@@ -85,30 +85,20 @@ def create_projection_matrix(m, d, seed=None):
return
tf
.
linalg
.
matmul
(
tf
.
linalg
.
diag
(
multiplier
),
final_matrix
)
return
tf
.
linalg
.
matmul
(
tf
.
linalg
.
diag
(
multiplier
),
final_matrix
)
def
_generalized_kernel
(
x
,
projection_matrix
,
is_query
,
f
,
h
,
def
_generalized_kernel
(
x
,
projection_matrix
,
f
,
h
):
data_normalizer_fn
=
None
):
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
"""Generalized kernel in RETHINKING ATTENTION WITH PERFORMERS.
Args:
Args:
x: The feature being transformed with shape [B, T, N ,H].
x: The feature being transformed with shape [B, T, N ,H].
projection_matrix: The matrix with shape [M, H] that we projecct x to, where
projection_matrix: The matrix with shape [M, H] that we projecct x to, where
M is the number of projections.
M is the number of projections.
is_query: Whether the transform is a query or key. This transform is
symmetric is the argument is not used.
f: A non-linear function applied on x or projected x.
f: A non-linear function applied on x or projected x.
h: A muliplier which is a function of x applied after projected and
h: A muliplier which is a function of x applied after projected and
transformed. Only applied if projection_matrix is not None.
transformed. Only applied if projection_matrix is not None.
data_normalizer_fn: A function which takes x and returns a scalar that
normalize data.
Returns:
Returns:
Transformed feature.
Transformed feature.
"""
"""
# No asymmetric operations.
del
is_query
if
data_normalizer_fn
is
not
None
:
x
=
data_normalizer_fn
(
x
)
if
projection_matrix
is
None
:
if
projection_matrix
is
None
:
return
h
(
x
)
*
f
(
x
)
return
h
(
x
)
*
f
(
x
)
...
@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
...
@@ -139,26 +129,18 @@ _TRANSFORM_MAP = {
x
-
tf
.
math
.
reduce_max
(
x
,
axis
=
[
1
,
2
,
3
],
keepdims
=
True
)),
x
-
tf
.
math
.
reduce_max
(
x
,
axis
=
[
1
,
2
,
3
],
keepdims
=
True
)),
h
=
lambda
x
:
tf
.
math
.
exp
(
h
=
lambda
x
:
tf
.
math
.
exp
(
-
0.5
*
tf
.
math
.
reduce_sum
(
-
0.5
*
tf
.
math
.
reduce_sum
(
tf
.
math
.
square
(
x
),
axis
=-
1
,
keepdims
=
True
)),
tf
.
math
.
square
(
x
),
axis
=-
1
,
keepdims
=
True
)),),
data_normalizer_fn
=
lambda
x
:
x
/
(
tf
.
math
.
sqrt
(
tf
.
math
.
sqrt
(
tf
.
cast
(
tf
.
shape
(
x
)[
-
1
],
tf
.
float32
))))),
"expmod"
:
"expmod"
:
functools
.
partial
(
functools
.
partial
(
_generalized_kernel
,
_generalized_kernel
,
# Avoid exp explosion by shifting.
# Avoid exp explosion by shifting.
f
=
lambda
x
:
tf
.
math
.
exp
(
f
=
lambda
x
:
tf
.
math
.
exp
(
x
-
tf
.
math
.
reduce_max
(
x
-
tf
.
math
.
reduce_max
(
x
,
axis
=
[
1
,
2
,
3
],
keepdims
=
True
)),
x
,
axis
=
[
1
,
2
,
3
],
keepdims
=
True
)),
h
=
lambda
x
:
tf
.
math
.
exp
(
h
=
lambda
x
:
tf
.
math
.
exp
(
-
0.5
*
tf
.
math
.
sqrt
(
-
0.5
*
tf
.
math
.
sqrt
(
tf
.
cast
(
tf
.
shape
(
x
)[
-
1
],
tf
.
float32
))),
tf
.
cast
(
tf
.
shape
(
x
)[
-
1
],
tf
.
float32
))),
data_normalizer_fn
=
lambda
x
:
x
/
),
(
tf
.
math
.
sqrt
(
tf
.
math
.
sqrt
(
tf
.
cast
(
tf
.
shape
(
x
)[
-
1
],
tf
.
float32
))))),
"identity"
:
"l2"
:
functools
.
partial
(
_generalized_kernel
,
f
=
lambda
x
:
x
,
h
=
lambda
x
:
1
)
functools
.
partial
(
_generalized_kernel
,
f
=
lambda
x
:
x
,
h
=
lambda
x
:
tf
.
math
.
sqrt
(
tf
.
cast
(
tf
.
shape
(
x
)[
-
1
],
tf
.
float32
)),
data_normalizer_fn
=
lambda
x
:
x
),
"identity"
:
lambda
x
,
projection_matrix
,
is_query
:
x
}
}
# pylint: enable=g-long-lambda
# pylint: enable=g-long-lambda
...
@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -170,7 +152,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Rethinking Attention with Performers
Rethinking Attention with Performers
(https://arxiv.org/abs/2009.14794)
(https://arxiv.org/abs/2009.14794)
- exp (Lemma 1, positive), relu
, l2
- exp (Lemma 1, positive), relu
- random/deterministic projection
- random/deterministic projection
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
...
@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -195,14 +177,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
redraw
=
False
,
redraw
=
False
,
is_short_seq
=
False
,
is_short_seq
=
False
,
begin_kernel
=
0
,
begin_kernel
=
0
,
scale
=
None
,
**
kwargs
):
**
kwargs
):
r
"""Constructor of KernelAttention.
r
"""Constructor of KernelAttention.
Args:
Args:
feature_transform: A non-linear transform of the keys and quries.
feature_transform: A non-linear transform of the keys and quries.
Possible transforms are "elu", "relu", "square", "exp", "expmod",
Possible transforms are "elu", "relu", "square", "exp", "expmod",
"l2", "identity". If <is_short_seq> = True, it is recommended to choose
"identity".
feature_transform as "l2".
num_random_features: Number of random features to be used for projection.
num_random_features: Number of random features to be used for projection.
if num_random_features <= 0, no production is used before transform.
if num_random_features <= 0, no production is used before transform.
seed: The seed to begin drawing random features. Once the seed is set, the
seed: The seed to begin drawing random features. Once the seed is set, the
...
@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -216,6 +198,8 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
(default option).
(default option).
begin_kernel: Apply kernel_attention after this sequence id and apply
begin_kernel: Apply kernel_attention after this sequence id and apply
softmax attention before this.
softmax attention before this.
scale: The value to scale the dot product as described in `Attention Is
All You Need`. If None, we use 1/sqrt(dk) as described in the paper.
**kwargs: The same arguments `MultiHeadAttention` layer.
**kwargs: The same arguments `MultiHeadAttention` layer.
"""
"""
if
feature_transform
not
in
_TRANSFORM_MAP
:
if
feature_transform
not
in
_TRANSFORM_MAP
:
...
@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -234,8 +218,11 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
# 1. inference
# 1. inference
# 2. no redraw
# 2. no redraw
self
.
_seed
=
seed
self
.
_seed
=
seed
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
if
scale
is
None
:
self
.
_scale
=
1.0
/
math
.
sqrt
(
float
(
self
.
_key_dim
))
else
:
self
.
_scale
=
scale
self
.
_projection_matrix
=
None
self
.
_projection_matrix
=
None
if
num_random_features
>
0
:
if
num_random_features
>
0
:
self
.
_projection_matrix
=
create_projection_matrix
(
self
.
_projection_matrix
=
create_projection_matrix
(
...
@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -275,7 +262,6 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
Returns:
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_output: Multi-headed outputs of attention computation.
"""
"""
projection_matrix
=
None
projection_matrix
=
None
if
self
.
_num_random_features
>
0
:
if
self
.
_num_random_features
>
0
:
if
self
.
_redraw
and
training
:
if
self
.
_redraw
and
training
:
...
@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -284,8 +270,20 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
else
:
else
:
projection_matrix
=
self
.
_projection_matrix
projection_matrix
=
self
.
_projection_matrix
key
=
_TRANSFORM_MAP
[
feature_transform
](
key
,
projection_matrix
,
False
)
if
is_short_seq
:
query
=
_TRANSFORM_MAP
[
feature_transform
](
query
,
projection_matrix
,
True
)
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query
=
query
*
self
.
_scale
else
:
# Note: we suspect spliting the scale to key, query yields smaller
# approximation variance when random projection is used.
# For simplicity, we also split when there's no random projection.
key
*=
math
.
sqrt
(
self
.
_scale
)
query
*=
math
.
sqrt
(
self
.
_scale
)
key
=
_TRANSFORM_MAP
[
feature_transform
](
key
,
projection_matrix
)
query
=
_TRANSFORM_MAP
[
feature_transform
](
query
,
projection_matrix
)
if
attention_mask
is
not
None
:
if
attention_mask
is
not
None
:
key
=
tf
.
einsum
(
"BSNH,BS->BSNH"
,
key
,
attention_mask
)
key
=
tf
.
einsum
(
"BSNH,BS->BSNH"
,
key
,
attention_mask
)
...
@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -294,13 +292,14 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
attention_scores
=
tf
.
einsum
(
"BTNH,BSNH->BTSN"
,
query
,
key
)
attention_scores
=
tf
.
einsum
(
"BTNH,BSNH->BTSN"
,
query
,
key
)
attention_scores
=
tf
.
nn
.
softmax
(
attention_scores
,
axis
=
2
)
attention_scores
=
tf
.
nn
.
softmax
(
attention_scores
,
axis
=
2
)
attention_output
=
tf
.
einsum
(
"BTSN,BSNH->BTNH"
,
attention_scores
,
value
)
attention_output
=
tf
.
einsum
(
"BTSN,BSNH->BTNH"
,
attention_scores
,
value
)
return
attention_output
else
:
else
:
kv
=
tf
.
einsum
(
"BSNH,BSND->BNDH"
,
key
,
value
)
kv
=
tf
.
einsum
(
"BSNH,BSND->BNDH"
,
key
,
value
)
denominator
=
1.0
/
(
denominator
=
1.0
/
(
tf
.
einsum
(
"BTNH,BNH->BTN"
,
query
,
tf
.
reduce_sum
(
key
,
axis
=
1
))
+
tf
.
einsum
(
"BTNH,BNH->BTN"
,
query
,
tf
.
reduce_sum
(
key
,
axis
=
1
))
+
_NUMERIC_STABLER
)
_NUMERIC_STABLER
)
return
tf
.
einsum
(
"BTNH,BNDH,BTN->BTND"
,
query
,
kv
,
denominator
)
attention_output
=
tf
.
einsum
(
"BTNH,BNDH,BTN->BTND"
,
query
,
kv
,
denominator
)
return
attention_output
def
_build_from_signature
(
self
,
query
,
value
,
key
=
None
):
def
_build_from_signature
(
self
,
query
,
value
,
key
=
None
):
super
().
_build_from_signature
(
query
=
query
,
value
=
value
,
key
=
key
)
super
().
_build_from_signature
(
query
=
query
,
value
=
value
,
key
=
key
)
...
@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
...
@@ -391,6 +390,7 @@ class KernelAttention(tf.keras.layers.MultiHeadAttention):
"redraw"
:
self
.
_redraw
,
"redraw"
:
self
.
_redraw
,
"is_short_seq"
:
self
.
_is_short_seq
,
"is_short_seq"
:
self
.
_is_short_seq
,
"begin_kernel"
:
self
.
_begin_kernel
,
"begin_kernel"
:
self
.
_begin_kernel
,
"scale"
:
self
.
_scale
,
}
}
base_config
=
super
().
get_config
()
base_config
=
super
().
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
official/nlp/modeling/layers/kernel_attention_test.py
View file @
78c43ef1
...
@@ -21,7 +21,7 @@ import tensorflow as tf
...
@@ -21,7 +21,7 @@ import tensorflow as tf
from
official.nlp.modeling.layers
import
kernel_attention
as
attention
from
official.nlp.modeling.layers
import
kernel_attention
as
attention
_FEATURE_TRANSFORM
=
[
'relu'
,
'elu'
,
'exp'
,
'l2'
]
_FEATURE_TRANSFORM
=
[
'relu'
,
'elu'
,
'exp'
]
_REDRAW
=
[
True
,
False
]
_REDRAW
=
[
True
,
False
]
_TRAINING
=
[
True
,
False
]
_TRAINING
=
[
True
,
False
]
_IS_SHORT_SEQ
=
[
True
,
False
]
_IS_SHORT_SEQ
=
[
True
,
False
]
...
...
official/nlp/modeling/layers/text_layers.py
View file @
78c43ef1
...
@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
...
@@ -33,121 +33,6 @@ def _check_if_tf_text_installed():
"'tensorflow-text-nightly'."
)
"'tensorflow-text-nightly'."
)
def
_iterative_vectorized_fair_share
(
capacity
:
tf
.
Tensor
,
limit
:
Union
[
int
,
tf
.
Tensor
]):
"""Iterative algorithm for max min fairness algorithm.
Reference: https://en.wikipedia.org/wiki/Max-min_fairness
The idea is for each example with some number of segments and a limit of
total segment length allowed, we grant each segment a fair share of the
limit. For example, if every segment has the same length, no work to do.
If one segment has below average length, its share will be spilt to others
fairly. In this way, the longest segment will be the shortest among all
potential capacity assignments.
Args:
capacity: A rank-2 Tensor of #Segments x Batch.
limit: The largest permissible number of tokens in total across one example.
Returns:
A rank-2 Tensor with new segment capacity assignment such that
the total number of tokens in each example does not exceed the `limit`.
"""
# Firstly, we calculate the lower bound of the capacity assignment.
per_seg_limit
=
limit
//
capacity
.
shape
[
0
]
limit_mask
=
tf
.
ones
(
capacity
.
shape
,
dtype
=
tf
.
int64
)
*
per_seg_limit
lower_bound
=
tf
.
minimum
(
capacity
,
limit_mask
)
# This step makes up the capacity that already statisfy the capacity limit.
remaining_cap_sum
=
limit
-
tf
.
math
.
reduce_sum
(
lower_bound
,
axis
=
0
)
remaining_cap_mat
=
capacity
-
lower_bound
new_cap
=
lower_bound
+
remaining_cap_mat
*
tf
.
cast
(
tf
.
math
.
reduce_sum
(
remaining_cap_mat
,
axis
=
0
)
<=
remaining_cap_sum
,
tf
.
int64
)
# Process iteratively. This step is O(#segments), see analysis below.
while
True
:
remaining_limit
=
limit
-
tf
.
math
.
reduce_sum
(
new_cap
,
axis
=
0
)
remaining_cap
=
capacity
-
new_cap
masked_remaining_slots
=
tf
.
cast
(
remaining_cap
>
0
,
tf
.
int64
)
remaining_cap_col_slots
=
tf
.
reduce_sum
(
masked_remaining_slots
,
axis
=
0
)
masked_remaining_limit
=
tf
.
cast
(
remaining_cap_col_slots
>
0
,
tf
.
int64
)
*
remaining_limit
# Total remaining segment limit is different for each example.
per_seg_limit
=
masked_remaining_limit
//
(
tf
.
cast
(
remaining_cap_col_slots
<=
0
,
tf
.
int64
)
+
remaining_cap_col_slots
)
# +1 to make sure 0/0 = 0
# Note that for each step, there is at least one more segment being
# fulfilled or the loop is finished.
# The idea is, if remaining per example limit > smallest among segments,
# the smallest segment ask is fullfilled. Otherwise, all remaining segments
# are truncated, the assignment is finished.
if
tf
.
math
.
reduce_sum
(
per_seg_limit
)
>
0
:
remaining_slots_mat
=
tf
.
cast
(
remaining_cap
>
0
,
tf
.
int64
)
new_cap
=
new_cap
+
remaining_slots_mat
*
per_seg_limit
else
:
# Leftover assignment of limit that is smaller than #slots.
new_remained_assignment_mask
=
tf
.
cast
(
(
tf
.
cumsum
(
masked_remaining_slots
,
axis
=
0
)
<=
masked_remaining_limit
)
&
(
masked_remaining_slots
>
0
),
tf
.
int64
)
new_cap
=
new_cap
+
new_remained_assignment_mask
break
return
new_cap
def
round_robin_truncate_inputs
(
inputs
:
Union
[
tf
.
RaggedTensor
,
List
[
tf
.
RaggedTensor
]],
limit
:
Union
[
int
,
tf
.
Tensor
],
)
->
Union
[
tf
.
RaggedTensor
,
List
[
tf
.
RaggedTensor
]]:
"""Truncates a list of batched segments to fit a per-example length limit.
Available space is assigned one token at a time in a round-robin fashion
to the inputs that still need some, until the limit is reached.
(Or equivalently: the longest input is truncated by one token until the total
length of inputs fits the limit.) Examples that fit the limit as passed in
remain unchanged.
Args:
inputs: A list of rank-2 RaggedTensors. The i-th example is given by
the i-th row in each list element, that is, `inputs[:][i, :]`.
limit: The largest permissible number of tokens in total across one example.
Returns:
A list of rank-2 RaggedTensors at corresponding indices with the inputs,
in which the rows of each RaggedTensor have been truncated such that
the total number of tokens in each example does not exceed the `limit`.
"""
if
not
isinstance
(
inputs
,
(
list
,
tuple
)):
return
round_robin_truncate_inputs
([
inputs
],
limit
)[
0
]
limit
=
tf
.
cast
(
limit
,
tf
.
int64
)
if
not
all
(
rt
.
shape
.
rank
==
2
for
rt
in
inputs
):
raise
ValueError
(
"All inputs must have shape [batch_size, (items)]"
)
if
len
(
inputs
)
==
1
:
return
[
_truncate_row_lengths
(
inputs
[
0
],
limit
)]
elif
len
(
inputs
)
==
2
:
size_a
,
size_b
=
[
rt
.
row_lengths
()
for
rt
in
inputs
]
# Here's a brain-twister: This does round-robin assignment of quota
# to both inputs until the limit is reached. Hint: consider separately
# the cases of zero, one, or two inputs exceeding half the limit.
floor_half
=
limit
//
2
ceil_half
=
limit
-
floor_half
quota_a
=
tf
.
minimum
(
size_a
,
ceil_half
+
tf
.
nn
.
relu
(
floor_half
-
size_b
))
quota_b
=
tf
.
minimum
(
size_b
,
floor_half
+
tf
.
nn
.
relu
(
ceil_half
-
size_a
))
return
[
_truncate_row_lengths
(
inputs
[
0
],
quota_a
),
_truncate_row_lengths
(
inputs
[
1
],
quota_b
)]
else
:
# Note that we don't merge with the 2 input case because the full algorithm
# is more expensive.
capacity
=
tf
.
stack
([
rt
.
row_lengths
()
for
rt
in
inputs
])
# #Segments x B
new_capacity
=
_iterative_vectorized_fair_share
(
capacity
,
limit
)
return
[
_truncate_row_lengths
(
inputs
[
i
],
new_capacity
[
i
])
for
i
in
range
(
capacity
.
shape
[
0
])
]
def
_truncate_row_lengths
(
ragged_tensor
:
tf
.
RaggedTensor
,
def
_truncate_row_lengths
(
ragged_tensor
:
tf
.
RaggedTensor
,
new_lengths
:
tf
.
Tensor
)
->
tf
.
RaggedTensor
:
new_lengths
:
tf
.
Tensor
)
->
tf
.
RaggedTensor
:
"""Truncates the rows of `ragged_tensor` to the given row lengths."""
"""Truncates the rows of `ragged_tensor` to the given row lengths."""
...
@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
...
@@ -675,8 +560,8 @@ class BertPackInputs(tf.keras.layers.Layer):
# fall back to some ad-hoc truncation.
# fall back to some ad-hoc truncation.
num_special_tokens
=
len
(
inputs
)
+
1
num_special_tokens
=
len
(
inputs
)
+
1
if
truncator
==
"round_robin"
:
if
truncator
==
"round_robin"
:
trimmed_segments
=
r
ound
_r
obin
_truncate_inputs
(
trimmed_segments
=
text
.
R
ound
R
obin
Trimmer
(
seq_length
-
inputs
,
seq_length
-
num_special_tokens
)
num_special_tokens
)
.
trim
(
inputs
)
elif
truncator
==
"waterfall"
:
elif
truncator
==
"waterfall"
:
trimmed_segments
=
text
.
WaterfallTrimmer
(
trimmed_segments
=
text
.
WaterfallTrimmer
(
seq_length
-
num_special_tokens
).
trim
(
inputs
)
seq_length
-
num_special_tokens
).
trim
(
inputs
)
...
...
official/nlp/modeling/layers/text_layers_test.py
View file @
78c43ef1
...
@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
...
@@ -24,102 +24,6 @@ from sentencepiece import SentencePieceTrainer
from
official.nlp.modeling.layers
import
text_layers
from
official.nlp.modeling.layers
import
text_layers
class
RoundRobinTruncatorTest
(
tf
.
test
.
TestCase
):
def
_test_input
(
self
,
start
,
lengths
):
return
tf
.
ragged
.
constant
([[
start
+
10
*
j
+
i
for
i
in
range
(
length
)]
for
j
,
length
in
enumerate
(
lengths
)],
dtype
=
tf
.
int32
)
def
test_single_segment
(
self
):
# Single segment.
single_input
=
self
.
_test_input
(
11
,
[
4
,
5
,
6
])
expected_single_output
=
tf
.
ragged
.
constant
(
[[
11
,
12
,
13
,
14
],
[
21
,
22
,
23
,
24
,
25
],
[
31
,
32
,
33
,
34
,
35
],
# Truncated.
])
self
.
assertAllEqual
(
expected_single_output
,
text_layers
.
round_robin_truncate_inputs
(
single_input
,
limit
=
5
))
# Test wrapping in a singleton list.
actual_single_list_output
=
text_layers
.
round_robin_truncate_inputs
(
[
single_input
],
limit
=
5
)
self
.
assertIsInstance
(
actual_single_list_output
,
list
)
self
.
assertAllEqual
(
expected_single_output
,
actual_single_list_output
[
0
])
def
test_two_segments
(
self
):
input_a
=
self
.
_test_input
(
111
,
[
1
,
2
,
2
,
3
,
4
,
5
])
input_b
=
self
.
_test_input
(
211
,
[
1
,
3
,
4
,
2
,
2
,
5
])
expected_a
=
tf
.
ragged
.
constant
(
[[
111
],
[
121
,
122
],
[
131
,
132
],
[
141
,
142
,
143
],
[
151
,
152
,
153
],
# Truncated.
[
161
,
162
,
163
],
# Truncated.
])
expected_b
=
tf
.
ragged
.
constant
(
[[
211
],
[
221
,
222
,
223
],
[
231
,
232
,
233
],
# Truncated.
[
241
,
242
],
[
251
,
252
],
[
261
,
262
],
# Truncated.
])
actual_a
,
actual_b
=
text_layers
.
round_robin_truncate_inputs
(
[
input_a
,
input_b
],
limit
=
5
)
self
.
assertAllEqual
(
expected_a
,
actual_a
)
self
.
assertAllEqual
(
expected_b
,
actual_b
)
def
test_three_segments
(
self
):
input_a
=
self
.
_test_input
(
111
,
[
1
,
2
,
2
,
3
,
4
,
5
,
1
])
input_b
=
self
.
_test_input
(
211
,
[
1
,
3
,
4
,
2
,
2
,
5
,
8
])
input_c
=
self
.
_test_input
(
311
,
[
1
,
3
,
4
,
2
,
2
,
5
,
10
])
seg_limit
=
8
expected_a
=
tf
.
ragged
.
constant
([
[
111
],
[
121
,
122
],
[
131
,
132
],
[
141
,
142
,
143
],
[
151
,
152
,
153
,
154
],
[
161
,
162
,
163
],
# Truncated
[
171
]
])
expected_b
=
tf
.
ragged
.
constant
([
[
211
],
[
221
,
222
,
223
],
[
231
,
232
,
233
],
# Truncated
[
241
,
242
],
[
251
,
252
],
[
261
,
262
,
263
],
# Truncated
[
271
,
272
,
273
,
274
]
# Truncated
])
expected_c
=
tf
.
ragged
.
constant
([
[
311
],
[
321
,
322
,
323
],
[
331
,
332
,
333
],
# Truncated
[
341
,
342
],
[
351
,
352
],
[
361
,
362
],
# Truncated
[
371
,
372
,
373
]
# Truncated
])
actual_a
,
actual_b
,
actual_c
=
text_layers
.
round_robin_truncate_inputs
(
[
input_a
,
input_b
,
input_c
],
limit
=
seg_limit
)
self
.
assertAllEqual
(
expected_a
,
actual_a
)
self
.
assertAllEqual
(
expected_b
,
actual_b
)
self
.
assertAllEqual
(
expected_c
,
actual_c
)
input_cap
=
tf
.
math
.
reduce_sum
(
tf
.
stack
([
rt
.
row_lengths
()
for
rt
in
[
input_a
,
input_b
,
input_c
]]),
axis
=
0
)
per_example_usage
=
tf
.
math
.
reduce_sum
(
tf
.
stack
([
rt
.
row_lengths
()
for
rt
in
[
actual_a
,
actual_b
,
actual_c
]]),
axis
=
0
)
self
.
assertTrue
(
all
(
per_example_usage
<=
tf
.
minimum
(
seg_limit
,
input_cap
)))
# This test covers the in-process behavior of a BertTokenizer layer.
# This test covers the in-process behavior of a BertTokenizer layer.
# For saving, restoring, and the restored behavior (incl. shape inference),
# For saving, restoring, and the restored behavior (incl. shape inference),
# see nlp/tools/export_tfhub_lib_test.py.
# see nlp/tools/export_tfhub_lib_test.py.
...
...
official/nlp/modeling/models/bert_classifier.py
View file @
78c43ef1
...
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
...
@@ -45,10 +45,11 @@ class BertClassifier(tf.keras.Model):
dropout_rate: The dropout probability of the cls head.
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
use_encoder_pooler: Whether to use the pooler layer pre-defined inside the
encoder.
encoder.
head_name: Name of the classification head.
cls_head: (Optional) The layer instance to use for the classifier head.
cls_head: (Optional) The layer instance to use for the classifier head.
It should take in the output from network and produce the final logits.
It should take in the output from network and produce the final logits.
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
If set, the arguments ('num_classes', 'initializer', 'dropout_rate',
'use_encoder_pooler') will be ignored.
'use_encoder_pooler'
, 'head_name'
) will be ignored.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
...
@@ -57,9 +58,11 @@ class BertClassifier(tf.keras.Model):
initializer
=
'glorot_uniform'
,
initializer
=
'glorot_uniform'
,
dropout_rate
=
0.1
,
dropout_rate
=
0.1
,
use_encoder_pooler
=
True
,
use_encoder_pooler
=
True
,
head_name
=
'sentence_prediction'
,
cls_head
=
None
,
cls_head
=
None
,
**
kwargs
):
**
kwargs
):
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
head_name
=
head_name
self
.
initializer
=
initializer
self
.
initializer
=
initializer
self
.
use_encoder_pooler
=
use_encoder_pooler
self
.
use_encoder_pooler
=
use_encoder_pooler
...
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
...
@@ -92,7 +95,7 @@ class BertClassifier(tf.keras.Model):
num_classes
=
num_classes
,
num_classes
=
num_classes
,
initializer
=
initializer
,
initializer
=
initializer
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
name
=
'sentence_prediction'
)
name
=
head_name
)
predictions
=
classifier
(
cls_inputs
)
predictions
=
classifier
(
cls_inputs
)
...
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
...
@@ -137,6 +140,7 @@ class BertClassifier(tf.keras.Model):
return
{
return
{
'network'
:
self
.
_network
,
'network'
:
self
.
_network
,
'num_classes'
:
self
.
num_classes
,
'num_classes'
:
self
.
num_classes
,
'head_name'
:
self
.
head_name
,
'initializer'
:
self
.
initializer
,
'initializer'
:
self
.
initializer
,
'use_encoder_pooler'
:
self
.
use_encoder_pooler
,
'use_encoder_pooler'
:
self
.
use_encoder_pooler
,
'cls_head'
:
self
.
_cls_head
,
'cls_head'
:
self
.
_cls_head
,
...
...
official/nlp/modeling/models/seq2seq_transformer.py
View file @
78c43ef1
...
@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
...
@@ -111,13 +111,15 @@ class Seq2SeqTransformer(tf.keras.Model):
def
_embedding_linear
(
self
,
embedding_matrix
,
x
):
def
_embedding_linear
(
self
,
embedding_matrix
,
x
):
"""Uses embeddings as linear transformation weights."""
"""Uses embeddings as linear transformation weights."""
embedding_matrix
=
tf
.
cast
(
embedding_matrix
,
dtype
=
self
.
compute_dtype
)
x
=
tf
.
cast
(
x
,
dtype
=
self
.
compute_dtype
)
batch_size
=
tf
.
shape
(
x
)[
0
]
batch_size
=
tf
.
shape
(
x
)[
0
]
length
=
tf
.
shape
(
x
)[
1
]
length
=
tf
.
shape
(
x
)[
1
]
hidden_size
=
tf
.
shape
(
x
)[
2
]
hidden_size
=
tf
.
shape
(
x
)[
2
]
vocab_size
=
tf
.
shape
(
embedding_matrix
)[
0
]
vocab_size
=
tf
.
shape
(
embedding_matrix
)[
0
]
x
=
tf
.
reshape
(
x
,
[
-
1
,
hidden_size
])
x
=
tf
.
reshape
(
x
,
[
-
1
,
hidden_size
])
logits
=
tf
.
matmul
(
x
,
tf
.
cast
(
embedding_matrix
,
x
.
dtype
),
transpose_b
=
True
)
logits
=
tf
.
matmul
(
x
,
embedding_matrix
,
transpose_b
=
True
)
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
vocab_size
])
return
tf
.
reshape
(
logits
,
[
batch_size
,
length
,
vocab_size
])
...
...
official/nlp/modeling/models/xlnet.py
View file @
78c43ef1
...
@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
...
@@ -171,6 +171,7 @@ class XLNetClassifier(tf.keras.Model):
Defaults to a RandomNormal initializer.
Defaults to a RandomNormal initializer.
summary_type: Method used to summarize a sequence into a compact vector.
summary_type: Method used to summarize a sequence into a compact vector.
dropout_rate: The dropout probability of the cls head.
dropout_rate: The dropout probability of the cls head.
head_name: Name of the classification head.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
...
@@ -180,6 +181,7 @@ class XLNetClassifier(tf.keras.Model):
initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'random_normal'
,
initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'random_normal'
,
summary_type
:
str
=
'last'
,
summary_type
:
str
=
'last'
,
dropout_rate
:
float
=
0.1
,
dropout_rate
:
float
=
0.1
,
head_name
:
str
=
'sentence_prediction'
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
_network
=
network
self
.
_network
=
network
...
@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
...
@@ -192,6 +194,7 @@ class XLNetClassifier(tf.keras.Model):
'num_classes'
:
num_classes
,
'num_classes'
:
num_classes
,
'summary_type'
:
summary_type
,
'summary_type'
:
summary_type
,
'dropout_rate'
:
dropout_rate
,
'dropout_rate'
:
dropout_rate
,
'head_name'
:
head_name
,
}
}
if
summary_type
==
'last'
:
if
summary_type
==
'last'
:
...
@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
...
@@ -207,7 +210,7 @@ class XLNetClassifier(tf.keras.Model):
initializer
=
initializer
,
initializer
=
initializer
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
cls_token_idx
=
cls_token_idx
,
cls_token_idx
=
cls_token_idx
,
name
=
'sentence_prediction'
)
name
=
head_name
)
def
call
(
self
,
inputs
:
Mapping
[
str
,
Any
]):
def
call
(
self
,
inputs
:
Mapping
[
str
,
Any
]):
input_ids
=
inputs
[
'input_word_ids'
]
input_ids
=
inputs
[
'input_word_ids'
]
...
...
official/nlp/modeling/networks/bert_encoder.py
View file @
78c43ef1
...
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
...
@@ -77,6 +77,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
parameter is originally added for ELECTRA model which needs to tie the
parameter is originally added for ELECTRA model which needs to tie the
generator embeddings with the discriminator embeddings.
generator embeddings with the discriminator embeddings.
dict_outputs: Whether to use a dictionary as the model outputs.
dict_outputs: Whether to use a dictionary as the model outputs.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
...
@@ -97,6 +100,7 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
embedding_width
=
None
,
embedding_width
=
None
,
embedding_layer
=
None
,
embedding_layer
=
None
,
dict_outputs
=
False
,
dict_outputs
=
False
,
norm_first
=
False
,
**
kwargs
):
**
kwargs
):
# b/164516224
# b/164516224
...
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
...
@@ -120,7 +124,8 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
initializer
=
initializer
,
initializer
=
initializer
,
output_range
=
output_range
,
output_range
=
output_range
,
embedding_width
=
embedding_width
,
embedding_width
=
embedding_width
,
embedding_layer
=
embedding_layer
)
embedding_layer
=
embedding_layer
,
norm_first
=
norm_first
)
self
.
_embedding_layer_instance
=
embedding_layer
self
.
_embedding_layer_instance
=
embedding_layer
...
...
official/nlp/modeling/networks/bert_encoder_test.py
View file @
78c43ef1
...
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
...
@@ -226,7 +226,8 @@ class BertEncoderTest(keras_parameterized.TestCase):
output_range
=-
1
,
output_range
=-
1
,
embedding_width
=
16
,
embedding_width
=
16
,
dict_outputs
=
True
,
dict_outputs
=
True
,
embedding_layer
=
None
)
embedding_layer
=
None
,
norm_first
=
False
)
network
=
bert_encoder
.
BertEncoder
(
**
kwargs
)
network
=
bert_encoder
.
BertEncoder
(
**
kwargs
)
expected_config
=
dict
(
kwargs
)
expected_config
=
dict
(
kwargs
)
expected_config
[
"activation"
]
=
tf
.
keras
.
activations
.
serialize
(
expected_config
[
"activation"
]
=
tf
.
keras
.
activations
.
serialize
(
...
...
official/nlp/modeling/ops/sampling_module.py
View file @
78c43ef1
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
"""Sampling module for top_k, top_p and greedy decoding."""
"""Sampling module for top_k, top_p and greedy decoding."""
import
abc
import
abc
from
typing
import
Any
,
Callable
,
Dict
from
typing
import
Any
,
Callable
,
Dict
,
Optional
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
...
@@ -98,10 +98,10 @@ def sample_top_p(logits, top_p):
],
-
1
)
],
-
1
)
# Scatter sorted indices to original indexes.
# Scatter sorted indices to original indexes.
indices_to_remove
=
scatter_values_on_batch_indices
(
indices_to_remove
=
scatter_values_on_batch_indices
(
sorted_indices_to_remove
,
sorted_indices_to_remove
,
sorted_indices
)
sorted_indices
)
top_p_logits
=
set_tensor_by_indices_to_value
(
top_p_logits
=
set_tensor_by_indices_to_value
(
logits
,
indices_to_remove
,
logits
,
indices_to_remove
,
np
.
NINF
)
np
.
NINF
)
return
top_p_logits
return
top_p_logits
...
@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
...
@@ -121,13 +121,12 @@ def scatter_values_on_batch_indices(values, batch_indices):
tensor_shape
=
decoding_module
.
shape_list
(
batch_indices
)
tensor_shape
=
decoding_module
.
shape_list
(
batch_indices
)
broad_casted_batch_dims
=
tf
.
reshape
(
broad_casted_batch_dims
=
tf
.
reshape
(
tf
.
broadcast_to
(
tf
.
broadcast_to
(
tf
.
expand_dims
(
tf
.
range
(
tensor_shape
[
0
]),
axis
=-
1
),
tf
.
expand_dims
(
tf
.
range
(
tensor_shape
[
0
]),
axis
=-
1
),
tensor_shape
),
tensor_shape
),
[
1
,
-
1
])
[
1
,
-
1
])
pair_indices
=
tf
.
transpose
(
pair_indices
=
tf
.
transpose
(
tf
.
concat
([
broad_casted_batch_dims
,
tf
.
concat
([
broad_casted_batch_dims
,
tf
.
reshape
(
batch_indices
,
[
1
,
-
1
])],
0
))
tf
.
reshape
(
batch_indices
,
[
1
,
-
1
])],
0
))
return
tf
.
scatter_nd
(
pair_indices
,
return
tf
.
scatter_nd
(
pair_indices
,
tf
.
reshape
(
values
,
[
-
1
]),
tensor_shape
)
tf
.
reshape
(
values
,
[
-
1
]),
tensor_shape
)
def
set_tensor_by_indices_to_value
(
input_tensor
,
indices
,
value
):
def
set_tensor_by_indices_to_value
(
input_tensor
,
indices
,
value
):
...
@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
...
@@ -137,6 +136,7 @@ def set_tensor_by_indices_to_value(input_tensor, indices, value):
input_tensor: float (batch_size, dim)
input_tensor: float (batch_size, dim)
indices: bool (batch_size, dim)
indices: bool (batch_size, dim)
value: float scalar
value: float scalar
Returns:
Returns:
output_tensor: same shape as input_tensor.
output_tensor: same shape as input_tensor.
"""
"""
...
@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -150,11 +150,12 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
def
__init__
(
self
,
def
__init__
(
self
,
symbols_to_logits_fn
,
symbols_to_logits_fn
,
length_normalization_fn
:
Callable
[[
int
,
tf
.
DType
],
float
],
vocab_size
:
int
,
vocab_size
:
int
,
max_decode_length
:
int
,
max_decode_length
:
int
,
eos_id
:
int
,
eos_id
:
int
,
padded_decode
:
bool
,
padded_decode
:
bool
,
length_normalization_fn
:
Optional
[
Callable
[[
int
,
tf
.
DType
],
float
]]
=
None
,
top_k
=
0
,
top_k
=
0
,
top_p
=
1.0
,
top_p
=
1.0
,
sample_temperature
=
0.0
,
sample_temperature
=
0.0
,
...
@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -170,8 +171,8 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
self
.
max_decode_length
=
max_decode_length
self
.
max_decode_length
=
max_decode_length
self
.
top_k
=
tf
.
convert_to_tensor
(
top_k
,
dtype
=
tf
.
int32
)
self
.
top_k
=
tf
.
convert_to_tensor
(
top_k
,
dtype
=
tf
.
int32
)
self
.
top_p
=
tf
.
convert_to_tensor
(
top_p
,
dtype
=
tf
.
float32
)
self
.
top_p
=
tf
.
convert_to_tensor
(
top_p
,
dtype
=
tf
.
float32
)
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
sample_temperature
,
self
.
sample_temperature
=
tf
.
convert_to_tensor
(
dtype
=
tf
.
float32
)
sample_temperature
,
dtype
=
tf
.
float32
)
self
.
enable_greedy
=
enable_greedy
self
.
enable_greedy
=
enable_greedy
super
(
SamplingModule
,
self
).
__init__
(
super
(
SamplingModule
,
self
).
__init__
(
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
length_normalization_fn
=
length_normalization_fn
,
dtype
=
dtype
)
...
@@ -330,12 +331,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -330,12 +331,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
return
state
,
state_shape_invariants
return
state
,
state_shape_invariants
def
_get_new_alive_state
(
def
_get_new_alive_state
(
self
,
new_seq
:
tf
.
Tensor
,
new_log_probs
:
tf
.
Tensor
,
self
,
new_finished_flags
:
tf
.
Tensor
,
new_seq
:
tf
.
Tensor
,
new_cache
:
Dict
[
str
,
tf
.
Tensor
])
->
Dict
[
str
,
Any
]:
new_log_probs
:
tf
.
Tensor
,
new_finished_flags
:
tf
.
Tensor
,
new_cache
:
Dict
[
str
,
tf
.
Tensor
])
->
Dict
[
str
,
Any
]:
"""Gather the sequences that are still alive.
"""Gather the sequences that are still alive.
This function resets the sequences in the alive_state that are finished.
This function resets the sequences in the alive_state that are finished.
...
@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -360,9 +358,7 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
new_cache
decoding_module
.
StateKeys
.
ALIVE_CACHE
:
new_cache
}
}
def
_get_new_finished_state
(
self
,
def
_get_new_finished_state
(
self
,
state
:
Dict
[
str
,
Any
],
new_seq
:
tf
.
Tensor
,
state
:
Dict
[
str
,
Any
],
new_seq
:
tf
.
Tensor
,
new_log_probs
:
tf
.
Tensor
,
new_log_probs
:
tf
.
Tensor
,
new_finished_flags
:
tf
.
Tensor
,
new_finished_flags
:
tf
.
Tensor
,
batch_size
:
int
)
->
Dict
[
str
,
tf
.
Tensor
]:
batch_size
:
int
)
->
Dict
[
str
,
tf
.
Tensor
]:
...
@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
...
@@ -421,10 +417,9 @@ class SamplingModule(decoding_module.DecodingModule, metaclass=abc.ABCMeta):
length_norm
=
self
.
length_normalization_fn
(
self
.
max_decode_length
+
1
,
length_norm
=
self
.
length_normalization_fn
(
self
.
max_decode_length
+
1
,
self
.
dtype
)
self
.
dtype
)
alive_log_probs
=
alive_log_probs
/
length_norm
alive_log_probs
=
alive_log_probs
/
length_norm
seq_cond
=
decoding_module
.
expand_to_same_rank
(
seq_cond
=
decoding_module
.
expand_to_same_rank
(
finished_cond
,
finished_seq
)
finished_cond
,
finished_seq
)
score_cond
=
decoding_module
.
expand_to_same_rank
(
finished_cond
,
score_cond
=
decoding_module
.
expand_to_same_rank
(
finished_scores
)
finished_cond
,
finished_scores
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_seq
=
tf
.
where
(
seq_cond
,
finished_seq
,
alive_seq
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
finished_scores
=
tf
.
where
(
score_cond
,
finished_scores
,
alive_log_probs
)
return
finished_seq
,
finished_scores
return
finished_seq
,
finished_scores
...
...
official/nlp/train.py
View file @
78c43ef1
...
@@ -66,4 +66,5 @@ def main(_):
...
@@ -66,4 +66,5 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
tfm_flags
.
define_flags
()
flags
.
mark_flags_as_required
([
'experiment'
,
'mode'
,
'model_dir'
])
app
.
run
(
main
)
app
.
run
(
main
)
official/recommendation/data_pipeline.py
View file @
78c43ef1
...
@@ -29,17 +29,16 @@ import timeit
...
@@ -29,17 +29,16 @@ import timeit
import
traceback
import
traceback
import
typing
import
typing
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
six
from
six.moves
import
queue
from
six.moves
import
queue
import
tensorflow
as
tf
import
tensorflow
as
tf
from
absl
import
logging
from
tensorflow.python.tpu.datasets
import
StreamingFilesDataset
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
movielens
from
official.recommendation
import
movielens
from
official.recommendation
import
popen_helper
from
official.recommendation
import
popen_helper
from
official.recommendation
import
stat_utils
from
official.recommendation
import
stat_utils
from
tensorflow.python.tpu.datasets
import
StreamingFilesDataset
SUMMARY_TEMPLATE
=
"""General:
SUMMARY_TEMPLATE
=
"""General:
{spacer}Num users: {num_users}
{spacer}Num users: {num_users}
...
@@ -119,6 +118,7 @@ class DatasetManager(object):
...
@@ -119,6 +118,7 @@ class DatasetManager(object):
"""Convert NumPy arrays into a TFRecords entry."""
"""Convert NumPy arrays into a TFRecords entry."""
def
create_int_feature
(
values
):
def
create_int_feature
(
values
):
values
=
np
.
squeeze
(
values
)
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
list
(
values
)))
return
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
list
(
values
)))
feature_dict
=
{
feature_dict
=
{
...
...
official/recommendation/data_preprocessing.py
View file @
78c43ef1
...
@@ -23,21 +23,19 @@ import os
...
@@ -23,21 +23,19 @@ import os
import
pickle
import
pickle
import
time
import
time
import
timeit
import
timeit
import
typing
# pylint: disable=wro
ng
-
import
-order
from
typi
ng
import
Dict
,
Text
,
Tuple
from
absl
import
logging
from
absl
import
logging
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
import
tensorflow
as
tf
import
tensorflow
as
tf
import
typing
from
typing
import
Dict
,
Text
,
Tuple
# pylint: enable=wrong-import-order
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
constants
as
rconst
from
official.recommendation
import
data_pipeline
from
official.recommendation
import
data_pipeline
from
official.recommendation
import
movielens
from
official.recommendation
import
movielens
_EXPECTED_CACHE_KEYS
=
(
rconst
.
TRAIN_USER_KEY
,
rconst
.
TRAIN_ITEM_KEY
,
_EXPECTED_CACHE_KEYS
=
(
rconst
.
TRAIN_USER_KEY
,
rconst
.
TRAIN_ITEM_KEY
,
rconst
.
EVAL_USER_KEY
,
rconst
.
EVAL_ITEM_KEY
,
rconst
.
EVAL_USER_KEY
,
rconst
.
EVAL_ITEM_KEY
,
rconst
.
USER_MAP
,
rconst
.
ITEM_MAP
)
rconst
.
USER_MAP
,
rconst
.
ITEM_MAP
)
...
@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
...
@@ -196,7 +194,7 @@ def _filter_index_sort(raw_rating_path: Text,
logging
.
info
(
"Writing raw data cache."
)
logging
.
info
(
"Writing raw data cache."
)
with
tf
.
io
.
gfile
.
GFile
(
cache_path
,
"wb"
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
cache_path
,
"wb"
)
as
f
:
pickle
.
dump
(
data
,
f
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
pickle
.
dump
(
data
,
f
,
protocol
=
4
)
# TODO(robieta): MLPerf cache clear.
# TODO(robieta): MLPerf cache clear.
return
data
,
valid_cache
return
data
,
valid_cache
...
...
official/recommendation/ranking/README.md
View file @
78c43ef1
...
@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
...
@@ -111,6 +111,7 @@ export TPU_NAME=my-dlrm-tpu
export
EXPERIMENT_NAME
=
my_experiment_name
export
EXPERIMENT_NAME
=
my_experiment_name
export
BUCKET_NAME
=
"gs://my_dlrm_bucket"
export
BUCKET_NAME
=
"gs://my_dlrm_bucket"
export
DATA_DIR
=
"
${
BUCKET_NAME
}
/data"
export
DATA_DIR
=
"
${
BUCKET_NAME
}
/data"
export
EMBEDDING_DIM
=
32
python3 models/official/recommendation/ranking/train.py
--mode
=
train_and_eval
\
python3 models/official/recommendation/ranking/train.py
--mode
=
train_and_eval
\
--model_dir
=
${
BUCKET_NAME
}
/model_dirs/
${
EXPERIMENT_NAME
}
--params_override
=
"
--model_dir
=
${
BUCKET_NAME
}
/model_dirs/
${
EXPERIMENT_NAME
}
--params_override
=
"
...
@@ -126,8 +127,8 @@ task:
...
@@ -126,8 +127,8 @@ task:
global_batch_size: 16384
global_batch_size: 16384
model:
model:
num_dense_features: 13
num_dense_features: 13
bottom_mlp: [512,256,
128
]
bottom_mlp: [512,256,
${
EMBEDDING_DIM
}
]
embedding_dim:
128
embedding_dim:
${
EMBEDDING_DIM
}
top_mlp: [1024,1024,512,256,1]
top_mlp: [1024,1024,512,256,1]
interaction: 'dot'
interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
...
@@ -135,8 +136,8 @@ task:
...
@@ -135,8 +136,8 @@ task:
39979771, 25641295, 39664984, 585935, 12972, 108, 36]
39979771, 25641295, 39664984, 585935, 12972, 108, 36]
trainer:
trainer:
use_orbit: true
use_orbit: true
validation_interval:
90000
validation_interval:
85352
checkpoint_interval:
100000
checkpoint_interval:
85352
validation_steps: 5440
validation_steps: 5440
train_steps: 256054
train_steps: 256054
steps_per_loop: 1000
steps_per_loop: 1000
...
@@ -154,7 +155,9 @@ Training on GPUs are similar to TPU training. Only distribution strategy needs
...
@@ -154,7 +155,9 @@ Training on GPUs are similar to TPU training. Only distribution strategy needs
to be updated and number of GPUs provided (for 4 GPUs):
to be updated and number of GPUs provided (for 4 GPUs):
```
shell
```
shell
python3 official/recommendation/ranking/main.py
--mode
=
train_and_eval
\
export
EMBEDDING_DIM
=
8
python3 official/recommendation/ranking/train.py
--mode
=
train_and_eval
\
--model_dir
=
${
BUCKET_NAME
}
/model_dirs/
${
EXPERIMENT_NAME
}
--params_override
=
"
--model_dir
=
${
BUCKET_NAME
}
/model_dirs/
${
EXPERIMENT_NAME
}
--params_override
=
"
runtime:
runtime:
distribution_strategy: 'mirrored'
distribution_strategy: 'mirrored'
...
...
official/
utils/misc/distribution_utils
.py
→
official/
recommendation/ranking/configs/__init__
.py
View file @
78c43ef1
...
@@ -12,6 +12,3 @@
...
@@ -12,6 +12,3 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Helper functions for running models in a distributed setting."""
# pylint: disable=wildcard-import
from
official.common.distribute_utils
import
*
official/recommendation/ranking/configs/config.py
View file @
78c43ef1
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Ranking Model configuration definition."""
"""Ranking Model configuration definition."""
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
,
Union
import
dataclasses
import
dataclasses
from
official.core
import
exp_factory
from
official.core
import
exp_factory
...
@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
...
@@ -59,7 +59,13 @@ class ModelConfig(hyperparams.Config):
num_dense_features: Number of dense features.
num_dense_features: Number of dense features.
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
vocab_sizes: Vocab sizes for each of the sparse features. The order agrees
with the order of the input data.
with the order of the input data.
embedding_dim: Embedding dimension.
embedding_dim: An integer or a list of embedding table dimensions.
If it's an integer then all tables will have the same embedding dimension.
If it's a list then the length should match with `vocab_sizes`.
size_threshold: A threshold for table sizes below which a keras
embedding layer is used, and above which a TPU embedding layer is used.
If it's -1 then only keras embedding layer will be used for all tables,
if it's 0 then only TPU embedding layer will be used.
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
bottom_mlp: The sizes of hidden layers for bottom MLP applied to dense
features.
features.
top_mlp: The sizes of hidden layers for top MLP.
top_mlp: The sizes of hidden layers for top MLP.
...
@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
...
@@ -68,7 +74,8 @@ class ModelConfig(hyperparams.Config):
"""
"""
num_dense_features
:
int
=
13
num_dense_features
:
int
=
13
vocab_sizes
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
vocab_sizes
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
embedding_dim
:
int
=
8
embedding_dim
:
Union
[
int
,
List
[
int
]]
=
8
size_threshold
:
int
=
50_000
bottom_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
bottom_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
top_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
top_mlp
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
interaction
:
str
=
'dot'
interaction
:
str
=
'dot'
...
@@ -188,7 +195,7 @@ def default_config() -> Config:
...
@@ -188,7 +195,7 @@ def default_config() -> Config:
runtime
=
cfg
.
RuntimeConfig
(),
runtime
=
cfg
.
RuntimeConfig
(),
task
=
Task
(
task
=
Task
(
model
=
ModelConfig
(
model
=
ModelConfig
(
embedding_dim
=
4
,
embedding_dim
=
8
,
vocab_sizes
=
vocab_sizes
,
vocab_sizes
=
vocab_sizes
,
bottom_mlp
=
[
64
,
32
,
4
],
bottom_mlp
=
[
64
,
32
,
4
],
top_mlp
=
[
64
,
32
,
1
]),
top_mlp
=
[
64
,
32
,
1
]),
...
...
official/recommendation/ranking/data/__init__.py
0 → 100644
View file @
78c43ef1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/recommendation/ranking/data_pipeline.py
→
official/recommendation/ranking/data
/data
_pipeline.py
View file @
78c43ef1
...
@@ -136,7 +136,7 @@ class CriteoTsvReader:
...
@@ -136,7 +136,7 @@ class CriteoTsvReader:
num_replicas
=
ctx
.
num_replicas_in_sync
if
ctx
else
1
num_replicas
=
ctx
.
num_replicas_in_sync
if
ctx
else
1
if
params
.
is_training
:
if
params
.
is_training
:
dataset_size
=
1000
0
*
batch_size
*
num_replicas
dataset_size
=
1000
*
batch_size
*
num_replicas
else
:
else
:
dataset_size
=
1000
*
batch_size
*
num_replicas
dataset_size
=
1000
*
batch_size
*
num_replicas
dense_tensor
=
tf
.
random
.
uniform
(
dense_tensor
=
tf
.
random
.
uniform
(
...
@@ -169,6 +169,7 @@ class CriteoTsvReader:
...
@@ -169,6 +169,7 @@ class CriteoTsvReader:
'sparse_features'
:
sparse_tensor_elements
},
label_tensor
'sparse_features'
:
sparse_tensor_elements
},
label_tensor
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
input_elem
)
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
input_elem
)
dataset
=
dataset
.
cache
()
if
params
.
is_training
:
if
params
.
is_training
:
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
()
...
...
official/recommendation/ranking/data_pipeline_test.py
→
official/recommendation/ranking/data
/data
_pipeline_test.py
View file @
78c43ef1
...
@@ -17,8 +17,8 @@
...
@@ -17,8 +17,8 @@
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.recommendation.ranking
import
data_pipeline
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.data
import
data_pipeline
class
DataPipelineTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
DataPipelineTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/recommendation/ranking/task.py
View file @
78c43ef1
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
"""Task for the Ranking model."""
"""Task for the Ranking model."""
import
math
import
math
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Union
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_recommenders
as
tfrs
import
tensorflow_recommenders
as
tfrs
...
@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
...
@@ -23,36 +23,49 @@ import tensorflow_recommenders as tfrs
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.recommendation.ranking
import
common
from
official.recommendation.ranking
import
common
from
official.recommendation.ranking
import
data_pipeline
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.configs
import
config
from
official.recommendation.ranking.data
import
data_pipeline
RuntimeConfig
=
config_definitions
.
RuntimeConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
def
_get_tpu_embedding_feature_config
(
def
_get_tpu_embedding_feature_config
(
vocab_sizes
:
List
[
int
],
vocab_sizes
:
List
[
int
],
embedding_dim
:
int
,
embedding_dim
:
Union
[
int
,
List
[
int
]]
,
table_name_prefix
:
str
=
'embedding_table'
table_name_prefix
:
str
=
'embedding_table'
)
->
Dict
[
str
,
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
]:
)
->
Dict
[
str
,
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
]:
"""Returns TPU embedding feature config.
"""Returns TPU embedding feature config.
The i'th table config will have a vocab size of vocab_sizes[i] and an embedding
dimension of embedding_dim (if embedding_dim is an int) or embedding_dim[i] (if
embedding_dim is a list).
Args:
Args:
vocab_sizes: List of sizes of categories/id's in the table.
vocab_sizes: List of sizes of categories/id's in the table.
embedding_dim:
E
mbedding dimension.
embedding_dim:
An integer or a list of e
mbedding
table
dimension
s
.
table_name_prefix: a prefix for embedding tables.
table_name_prefix: a prefix for embedding tables.
Returns:
Returns:
A dictionary of feature_name, FeatureConfig pairs.
A dictionary of feature_name, FeatureConfig pairs.
"""
"""
if
isinstance
(
embedding_dim
,
List
):
if
len
(
vocab_sizes
)
!=
len
(
embedding_dim
):
raise
ValueError
(
f
'length of vocab_sizes:
{
len
(
vocab_sizes
)
}
is not equal to the '
f
'length of embedding_dim:
{
len
(
embedding_dim
)
}
'
)
elif
isinstance
(
embedding_dim
,
int
):
embedding_dim
=
[
embedding_dim
]
*
len
(
vocab_sizes
)
else
:
raise
ValueError
(
'embedding_dim is not either a list or an int, got '
f
'
{
type
(
embedding_dim
)
}
'
)
feature_config
=
{}
feature_config
=
{}
for
i
,
vocab_size
in
enumerate
(
vocab_sizes
):
for
i
,
vocab_size
in
enumerate
(
vocab_sizes
):
table_config
=
tf
.
tpu
.
experimental
.
embedding
.
TableConfig
(
table_config
=
tf
.
tpu
.
experimental
.
embedding
.
TableConfig
(
vocabulary_size
=
vocab_size
,
vocabulary_size
=
vocab_size
,
dim
=
embedding_dim
,
dim
=
embedding_dim
[
i
]
,
combiner
=
'mean'
,
combiner
=
'mean'
,
initializer
=
tf
.
initializers
.
TruncatedNormal
(
initializer
=
tf
.
initializers
.
TruncatedNormal
(
mean
=
0.0
,
stddev
=
1
/
math
.
sqrt
(
embedding_dim
)),
mean
=
0.0
,
stddev
=
1
/
math
.
sqrt
(
embedding_dim
[
i
]
)),
name
=
table_name_prefix
+
'_%s'
%
i
)
name
=
table_name_prefix
+
'_%s'
%
i
)
feature_config
[
str
(
i
)]
=
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
(
feature_config
[
str
(
i
)]
=
tf
.
tpu
.
experimental
.
embedding
.
FeatureConfig
(
table
=
table_config
)
table
=
table_config
)
...
@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
...
@@ -72,7 +85,7 @@ class RankingTask(base_task.Task):
"""Task initialization.
"""Task initialization.
Args:
Args:
params: the Ran
n
kingModel task configuration instance.
params: the RankingModel task configuration instance.
optimizer_config: Optimizer configuration instance.
optimizer_config: Optimizer configuration instance.
logging_dir: a string pointing to where the model, summaries etc. will be
logging_dir: a string pointing to where the model, summaries etc. will be
saved.
saved.
...
@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
...
@@ -125,15 +138,18 @@ class RankingTask(base_task.Task):
self
.
optimizer_config
.
embedding_optimizer
)
self
.
optimizer_config
.
embedding_optimizer
)
embedding_optimizer
.
learning_rate
=
lr_callable
embedding_optimizer
.
learning_rate
=
lr_callable
emb_
feature_config
=
_get_tpu_embedding_feature_config
(
feature_config
=
_get_tpu_embedding_feature_config
(
vocab_sizes
=
self
.
task_config
.
model
.
vocab_sizes
,
embedding_dim
=
self
.
task_config
.
model
.
embedding_dim
,
embedding_dim
=
self
.
task_config
.
model
.
embedding_dim
)
vocab_sizes
=
self
.
task_config
.
model
.
vocab_sizes
)
tpu_embedding
=
tfrs
.
layers
.
embedding
.
TPUEmbedding
(
embedding_layer
=
tfrs
.
experimental
.
layers
.
embedding
.
PartialTPUEmbedding
(
emb_feature_config
,
embedding_optimizer
)
feature_config
=
feature_config
,
optimizer
=
embedding_optimizer
,
size_threshold
=
self
.
task_config
.
model
.
size_threshold
)
if
self
.
task_config
.
model
.
interaction
==
'dot'
:
if
self
.
task_config
.
model
.
interaction
==
'dot'
:
feature_interaction
=
tfrs
.
layers
.
feature_interaction
.
DotInteraction
()
feature_interaction
=
tfrs
.
layers
.
feature_interaction
.
DotInteraction
(
skip_gather
=
True
)
elif
self
.
task_config
.
model
.
interaction
==
'cross'
:
elif
self
.
task_config
.
model
.
interaction
==
'cross'
:
feature_interaction
=
tf
.
keras
.
Sequential
([
feature_interaction
=
tf
.
keras
.
Sequential
([
tf
.
keras
.
layers
.
Concatenate
(),
tf
.
keras
.
layers
.
Concatenate
(),
...
@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
...
@@ -145,7 +161,7 @@ class RankingTask(base_task.Task):
f
'is not supported it must be either
\'
dot
\'
or
\'
cross
\'
.'
)
f
'is not supported it must be either
\'
dot
\'
or
\'
cross
\'
.'
)
model
=
tfrs
.
experimental
.
models
.
Ranking
(
model
=
tfrs
.
experimental
.
models
.
Ranking
(
embedding_layer
=
tpu_
embedding
,
embedding_layer
=
embedding
_layer
,
bottom_stack
=
tfrs
.
layers
.
blocks
.
MLP
(
bottom_stack
=
tfrs
.
layers
.
blocks
.
MLP
(
units
=
self
.
task_config
.
model
.
bottom_mlp
,
final_activation
=
'relu'
),
units
=
self
.
task_config
.
model
.
bottom_mlp
,
final_activation
=
'relu'
),
feature_interaction
=
feature_interaction
,
feature_interaction
=
feature_interaction
,
...
@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
...
@@ -184,3 +200,5 @@ class RankingTask(base_task.Task):
@
property
@
property
def
optimizer_config
(
self
)
->
config
.
OptimizationConfig
:
def
optimizer_config
(
self
)
->
config
.
OptimizationConfig
:
return
self
.
_optimizer_config
return
self
.
_optimizer_config
Prev
1
2
3
4
5
6
7
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment