ModelZoo / ResNet50_tensorflow

Commit 44a545b6
Authored Nov 17, 2020 by Hongkun Yu
Committed by A. Unique TensorFlower, Nov 17, 2020

Internal change

PiperOrigin-RevId: 342972267
Parent: 9e0cd251
Showing 3 changed files with 35 additions and 27 deletions.

    official/nlp/modeling/models/bert_pretrainer.py        +26  -22
    official/nlp/modeling/models/bert_pretrainer_test.py    +7   -5
    official/nlp/tasks/sentence_prediction_test.py          +2   -0
official/nlp/modeling/models/bert_pretrainer.py
@@ -18,6 +18,7 @@ import collections
 import copy
 from typing import List, Optional
+from absl import logging
 import gin
 import tensorflow as tf
@@ -164,7 +165,6 @@ class BertPretrainer(tf.keras.Model):
 class BertPretrainerV2(tf.keras.Model):
   """BERT pretraining model V2.
-  (Experimental).

   Adds the masked language model head and optional classification heads upon the
   transformer encoder.
@@ -198,7 +198,7 @@ class BertPretrainerV2(tf.keras.Model):
                customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
                name: str = 'bert',
                **kwargs):
-    self._self_setattr_tracking = False
+    super().__init__(self, name=name, **kwargs)
     self._config = {
         'encoder_network': encoder_network,
         'mlm_initializer': mlm_initializer,
@@ -207,6 +207,28 @@ class BertPretrainerV2(tf.keras.Model):
     }
     self.encoder_network = encoder_network
     inputs = copy.copy(self.encoder_network.inputs)
+    self.classification_heads = classification_heads or []
+    if len(set([cls.name for cls in self.classification_heads])) != len(
+        self.classification_heads):
+      raise ValueError('Classification heads should have unique names.')
+
+    self.masked_lm = customized_masked_lm or layers.MaskedLM(
+        embedding_table=self.encoder_network.get_embedding_table(),
+        activation=mlm_activation,
+        initializer=mlm_initializer,
+        name='cls/predictions')
+
+    masked_lm_positions = tf.keras.layers.Input(
+        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
+    inputs.append(masked_lm_positions)
+    self.inputs = inputs
+
+  def call(self, inputs):
+    if isinstance(inputs, list):
+      logging.warning('List inputs to BertPretrainer are discouraged.')
+      inputs = dict([
+          (ref.name, tensor) for ref, tensor in zip(self.inputs, inputs)
+      ])
+
     outputs = dict()
     encoder_network_outputs = self.encoder_network(inputs)
     if isinstance(encoder_network_outputs, list):
@@ -224,31 +246,13 @@ class BertPretrainerV2(tf.keras.Model):
     else:
       raise ValueError('encoder_network\'s output should be either a list '
                        'or a dict, but got %s' % encoder_network_outputs)
     sequence_output = outputs['sequence_output']
-    self.classification_heads = classification_heads or []
-    if len(set([cls.name for cls in self.classification_heads])) != len(
-        self.classification_heads):
-      raise ValueError('Classification heads should have unique names.')
-    if customized_masked_lm is not None:
-      self.masked_lm = customized_masked_lm
-    else:
-      self.masked_lm = layers.MaskedLM(
-          embedding_table=self.encoder_network.get_embedding_table(),
-          activation=mlm_activation,
-          initializer=mlm_initializer,
-          name='cls/predictions')
-    masked_lm_positions = tf.keras.layers.Input(
-        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
-    inputs.append(masked_lm_positions)
+    masked_lm_positions = inputs['masked_lm_positions']
     outputs['mlm_logits'] = self.masked_lm(
         sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       outputs[cls_head.name] = cls_head(sequence_output)
-    super(BertPretrainerV2, self).__init__(
-        inputs=inputs, outputs=outputs, name=name, **kwargs)
+    return outputs

   @property
   def checkpoint_items(self):
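With this change, BertPretrainerV2 constructs its heads eagerly in __init__ but runs the forward pass in a subclassed call method that expects a dict of features keyed by input name; list inputs still work via re-keying against self.inputs but now log a warning. A minimal usage sketch of the new calling convention — the BertEncoder sizes below are illustrative, not part of this commit:

    import numpy as np
    from official.nlp.modeling import models, networks

    # Illustrative encoder configuration (not from this commit).
    encoder = networks.BertEncoder(vocab_size=100, num_layers=2)
    pretrainer = models.BertPretrainerV2(encoder_network=encoder)

    batch, seq_len, num_predictions = 2, 16, 4
    features = {
        'input_word_ids': np.ones((batch, seq_len), dtype=np.int32),
        'input_mask': np.ones((batch, seq_len), dtype=np.int32),
        'input_type_ids': np.zeros((batch, seq_len), dtype=np.int32),
        # Indices of the masked tokens whose logits we want.
        'masked_lm_positions': np.zeros((batch, num_predictions), dtype=np.int32),
    }

    # Dict inputs are the supported path; list inputs log a warning and are
    # re-keyed against self.inputs.
    outputs = pretrainer(features)
    print(outputs['mlm_logits'].shape)  # (batch, num_predictions, vocab_size)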
official/nlp/modeling/models/bert_pretrainer_test.py
@@ -142,13 +142,15 @@ class BertPretrainerTest(keras_parameterized.TestCase):
         encoder_network=test_network,
         customized_masked_lm=customized_masked_lm)
     num_token_predictions = 20
     # Create a set of 2-dimensional inputs (the first dimension is implicit).
-    word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
-    lm_mask = tf.keras.Input(shape=(num_token_predictions,), dtype=tf.int32)
+    inputs = dict(
+        input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
+        masked_lm_positions=tf.keras.Input(
+            shape=(num_token_predictions,), dtype=tf.int32))

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    outputs = bert_trainer_model([word_ids, mask, type_ids, lm_mask])
+    outputs = bert_trainer_model(inputs)
     has_encoder_outputs = dict_outputs or return_all_encoder_outputs
     if has_encoder_outputs:
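The rewritten test passes named inputs rather than a positional list. For dict inputs, Keras matches dict keys against each Input tensor's name, so the keys here must be the encoder's canonical feature names. A small, hypothetical model showing the same matching behavior:

    import numpy as np
    import tensorflow as tf

    # Hypothetical two-input model; the Input names must match the dict keys
    # supplied at call time.
    word_ids = tf.keras.Input(shape=(4,), name='input_word_ids', dtype=tf.int32)
    mask = tf.keras.Input(shape=(4,), name='input_mask', dtype=tf.int32)
    summed = tf.keras.layers.Add()([word_ids, mask])
    model = tf.keras.Model(
        inputs={'input_word_ids': word_ids, 'input_mask': mask}, outputs=summed)

    features = {
        'input_word_ids': np.ones((2, 4), dtype=np.int32),
        'input_mask': np.zeros((2, 4), dtype=np.int32),
    }
    print(model(features).shape)  # (2, 4)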
official/nlp/tasks/sentence_prediction_test.py
@@ -103,6 +103,8 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
             inner_dim=768, num_classes=2, name="next_sentence")
     ])
     pretrain_model = masked_lm.MaskedLMTask(None).build_model(pretrain_cfg)
+    # The model variables will be created after the forward call.
+    _ = pretrain_model(pretrain_model.inputs)
     ckpt = tf.train.Checkpoint(
         model=pretrain_model, **pretrain_model.checkpoint_items)
     init_path = ckpt.save(self.get_temp_dir())
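The two added lines account for lazy variable creation: now that the pretrainer is a subclassed model, its variables do not exist until the first forward pass, and a tf.train.Checkpoint saved before that pass would capture nothing useful. A minimal sketch of the same pattern, with a hypothetical TinyModel standing in for the pretrainer:

    import os
    import tempfile

    import tensorflow as tf


    class TinyModel(tf.keras.Model):
      """Subclassed model: variables are created on the first call, not in __init__."""

      def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

      def call(self, x):
        return self.dense(x)


    model = TinyModel()
    print(len(model.variables))  # 0 -- nothing has been built yet.

    # Run one forward pass to build the layers, mirroring
    # `_ = pretrain_model(pretrain_model.inputs)` in the test above.
    _ = model(tf.zeros([1, 8]))
    print(len(model.variables))  # 2 -- kernel and bias now exist.

    ckpt = tf.train.Checkpoint(model=model)
    path = ckpt.save(os.path.join(tempfile.mkdtemp(), 'tiny_ckpt'))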