Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
29120423
Commit
29120423
authored
Nov 30, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Nov 30, 2020
Browse files
Adds a MultiClsHeads layer.
PiperOrigin-RevId: 344926166
parent
2a553e51
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
140 additions
and
4 deletions
+140
-4
official/nlp/modeling/layers/cls_head.py
official/nlp/modeling/layers/cls_head.py
+77
-1
official/nlp/modeling/layers/cls_head_test.py
official/nlp/modeling/layers/cls_head_test.py
+22
-1
official/nlp/modeling/models/bert_pretrainer.py
official/nlp/modeling/models/bert_pretrainer.py
+6
-2
official/nlp/modeling/models/bert_pretrainer_test.py
official/nlp/modeling/models/bert_pretrainer_test.py
+35
-0
No files found.
official/nlp/modeling/layers/cls_head.py
View file @
29120423
...
@@ -42,7 +42,7 @@ class ClassificationHead(tf.keras.layers.Layer):
...
@@ -42,7 +42,7 @@ class ClassificationHead(tf.keras.layers.Layer):
initializer: Initializer for dense layer kernels.
initializer: Initializer for dense layer kernels.
**kwargs: Keyword arguments.
**kwargs: Keyword arguments.
"""
"""
super
(
ClassificationHead
,
self
).
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
dropout_rate
=
dropout_rate
self
.
dropout_rate
=
dropout_rate
self
.
inner_dim
=
inner_dim
self
.
inner_dim
=
inner_dim
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
...
@@ -68,6 +68,7 @@ class ClassificationHead(tf.keras.layers.Layer):
...
@@ -68,6 +68,7 @@ class ClassificationHead(tf.keras.layers.Layer):
def
get_config
(
self
):
def
get_config
(
self
):
config
=
{
config
=
{
"cls_token_idx"
:
self
.
cls_token_idx
,
"dropout_rate"
:
self
.
dropout_rate
,
"dropout_rate"
:
self
.
dropout_rate
,
"num_classes"
:
self
.
num_classes
,
"num_classes"
:
self
.
num_classes
,
"inner_dim"
:
self
.
inner_dim
,
"inner_dim"
:
self
.
inner_dim
,
...
@@ -84,3 +85,78 @@ class ClassificationHead(tf.keras.layers.Layer):
...
@@ -84,3 +85,78 @@ class ClassificationHead(tf.keras.layers.Layer):
@
property
@
property
def
checkpoint_items
(
self
):
def
checkpoint_items
(
self
):
return
{
self
.
dense
.
name
:
self
.
dense
}
return
{
self
.
dense
.
name
:
self
.
dense
}
class MultiClsHeads(tf.keras.layers.Layer):
  """Multiple classification heads sharing the same pooling stem.

  The input sequence is pooled once (a single dense projection of the token at
  `cls_token_idx`, followed by dropout), and each classification problem in
  `cls_list` gets its own output projection on top of that shared stem.
  """

  def __init__(self,
               inner_dim,
               cls_list,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of inner projection layer.
      cls_list: A list of `(classification problem name, number of classes)`
        pairs; one output projection is created per entry.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.cls_list = cls_list
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # Shared pooling stem: one dense projection applied to the pooled token.
    self.dense = tf.keras.layers.Dense(
        units=inner_dim,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    # One (un-activated, logits-producing) output projection per problem.
    self.out_projs = []
    for name, num_classes in cls_list:
      self.out_projs.append(
          tf.keras.layers.Dense(
              units=num_classes,
              kernel_initializer=self.initializer,
              name=name))

  def call(self, features):
    """Computes per-problem logits from a `[batch, seq_len, dim]` sequence.

    Args:
      features: Sequence tensor; the token at `cls_token_idx` is pooled.

    Returns:
      A dict mapping each problem name in `cls_list` to its logits tensor of
      shape `[batch, num_classes]`.
    """
    x = features[:, self.cls_token_idx, :]  # take <CLS> token.
    x = self.dense(x)
    x = self.dropout(x)
    outputs = {}
    for proj_layer in self.out_projs:
      # NOTE(review): keys use the Keras layer name, which Keras uniquifies on
      # name collisions — assumes cls_list names are unique within the model.
      outputs[proj_layer.name] = proj_layer(x)
    return outputs

  def get_config(self):
    config = {
        "dropout_rate": self.dropout_rate,
        "cls_token_idx": self.cls_token_idx,
        "cls_list": self.cls_list,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dict of sub-layers to include in checkpoints."""
    # TODO(hongkuny): add output projects to the checkpoint items.
    return {self.dense.name: self.dense}
official/nlp/modeling/layers/cls_head_test.py
View file @
29120423
...
@@ -20,7 +20,7 @@ import tensorflow as tf
...
@@ -20,7 +20,7 @@ import tensorflow as tf
from
official.nlp.modeling.layers
import
cls_head
from
official.nlp.modeling.layers
import
cls_head
class
ClassificationHead
(
tf
.
test
.
TestCase
):
class
ClassificationHead
Test
(
tf
.
test
.
TestCase
):
def
test_layer_invocation
(
self
):
def
test_layer_invocation
(
self
):
test_layer
=
cls_head
.
ClassificationHead
(
inner_dim
=
5
,
num_classes
=
2
)
test_layer
=
cls_head
.
ClassificationHead
(
inner_dim
=
5
,
num_classes
=
2
)
...
@@ -38,5 +38,26 @@ class ClassificationHead(tf.test.TestCase):
...
@@ -38,5 +38,26 @@ class ClassificationHead(tf.test.TestCase):
self
.
assertAllEqual
(
layer
.
get_config
(),
new_layer
.
get_config
())
self
.
assertAllEqual
(
layer
.
get_config
(),
new_layer
.
get_config
())
class MultiClsHeadsTest(tf.test.TestCase):
  """Tests for the `MultiClsHeads` layer."""

  def test_layer_invocation(self):
    """Calls the layer and checks output shapes/values and checkpoint items."""
    cls_list = [("foo", 2), ("bar", 3)]
    test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=cls_list)
    # All-zero inputs yield all-zero logits (dense layers have zero bias and
    # a zero pre-activation), so exact expected values can be asserted.
    features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
    outputs = test_layer(features)
    self.assertAllClose(outputs["foo"], [[0., 0.], [0., 0.]])
    self.assertAllClose(outputs["bar"], [[0., 0., 0.], [0., 0., 0.]])
    # Only the shared pooling stem is exposed for checkpointing.
    self.assertSameElements(test_layer.checkpoint_items.keys(),
                            ["pooler_dense"])

  def test_layer_serialization(self):
    """Round-trips the layer through get_config/from_config."""
    cls_list = [("foo", 2), ("bar", 3)]
    test_layer = cls_head.MultiClsHeads(inner_dim=5, cls_list=cls_list)
    new_layer = cls_head.MultiClsHeads.from_config(test_layer.get_config())

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
# Run the test cases in this module when executed as a script.
if __name__ == "__main__":
  tf.test.main()
official/nlp/modeling/models/bert_pretrainer.py
View file @
29120423
...
@@ -183,7 +183,7 @@ class BertPretrainerV2(tf.keras.Model):
...
@@ -183,7 +183,7 @@ class BertPretrainerV2(tf.keras.Model):
dictionary.
dictionary.
Outputs: A dictionary of `lm_output`, classification head outputs keyed by
Outputs: A dictionary of `lm_output`, classification head outputs keyed by
head names, and also outputs from `encoder_network`, keyed by
head names, and also outputs from `encoder_network`, keyed by
`pooled_output`,
`sequence_output` and `encoder_outputs` (if any).
`sequence_output` and `encoder_outputs` (if any).
"""
"""
def
__init__
(
def
__init__
(
...
@@ -248,7 +248,11 @@ class BertPretrainerV2(tf.keras.Model):
...
@@ -248,7 +248,11 @@ class BertPretrainerV2(tf.keras.Model):
outputs
[
'mlm_logits'
]
=
self
.
masked_lm
(
outputs
[
'mlm_logits'
]
=
self
.
masked_lm
(
sequence_output
,
masked_positions
=
masked_lm_positions
)
sequence_output
,
masked_positions
=
masked_lm_positions
)
for
cls_head
in
self
.
classification_heads
:
for
cls_head
in
self
.
classification_heads
:
outputs
[
cls_head
.
name
]
=
cls_head
(
sequence_output
)
cls_outputs
=
cls_head
(
sequence_output
)
if
isinstance
(
cls_outputs
,
dict
):
outputs
.
update
(
cls_outputs
)
else
:
outputs
[
cls_head
.
name
]
=
cls_outputs
return
outputs
return
outputs
@
property
@
property
...
...
official/nlp/modeling/models/bert_pretrainer_test.py
View file @
29120423
...
@@ -110,6 +110,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
...
@@ -110,6 +110,9 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self
.
assertAllEqual
(
bert_trainer_model
.
get_config
(),
self
.
assertAllEqual
(
bert_trainer_model
.
get_config
(),
new_bert_trainer_model
.
get_config
())
new_bert_trainer_model
.
get_config
())
class
BertPretrainerV2Test
(
keras_parameterized
.
TestCase
):
@
parameterized
.
parameters
(
itertools
.
product
(
@
parameterized
.
parameters
(
itertools
.
product
(
(
False
,
True
),
(
False
,
True
),
(
False
,
True
),
(
False
,
True
),
...
@@ -175,6 +178,38 @@ class BertPretrainerTest(keras_parameterized.TestCase):
...
@@ -175,6 +178,38 @@ class BertPretrainerTest(keras_parameterized.TestCase):
self
.
assertAllEqual
(
expected_pooled_output_shape
,
self
.
assertAllEqual
(
expected_pooled_output_shape
,
outputs
[
'pooled_output'
].
shape
.
as_list
())
outputs
[
'pooled_output'
].
shape
.
as_list
())
  def test_multiple_cls_outputs(self):
    """Validate that the Keras object can be created."""
    # Build a transformer network to use within the BERT trainer.
    vocab_size = 100
    sequence_length = 512
    hidden_size = 48
    num_layers = 2
    test_network = networks.BertEncoder(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_size=hidden_size,
        max_sequence_length=sequence_length,
        dict_outputs=True)
    # A single MultiClsHeads head produces one output per (name, classes) pair;
    # the pretrainer is expected to merge that dict into its outputs.
    bert_trainer_model = bert_pretrainer.BertPretrainerV2(
        encoder_network=test_network,
        classification_heads=[
            layers.MultiClsHeads(
                inner_dim=5, cls_list=[('foo', 2), ('bar', 3)])
        ])
    num_token_predictions = 20
    # Create a set of 2-dimensional inputs (the first dimension is implicit).
    inputs = dict(
        input_word_ids=tf.keras.Input(
            shape=(sequence_length,), dtype=tf.int32),
        input_mask=tf.keras.Input(
            shape=(sequence_length,), dtype=tf.int32),
        input_type_ids=tf.keras.Input(
            shape=(sequence_length,), dtype=tf.int32),
        masked_lm_positions=tf.keras.Input(
            shape=(num_token_predictions,), dtype=tf.int32))

    # Invoke the trainer model on the inputs. This causes the layer to be built.
    outputs = bert_trainer_model(inputs)
    # Each problem's logits appear under its own key with its own class count.
    self.assertEqual(outputs['foo'].shape.as_list(), [None, 2])
    self.assertEqual(outputs['bar'].shape.as_list(), [None, 3])
def
test_v2_serialize_deserialize
(
self
):
def
test_v2_serialize_deserialize
(
self
):
"""Validate that the BERT trainer can be serialized and deserialized."""
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# Build a transformer network to use within the BERT trainer. (Here, we use
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment