Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6c63efed
Commit
6c63efed
authored
Jul 05, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jul 05, 2020
Browse files
Internal change
PiperOrigin-RevId: 319697990
parent
36e786dc
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
30 additions
and
6 deletions
+30
-6
official/nlp/configs/bert_test.py
official/nlp/configs/bert_test.py
+3
-2
official/nlp/modeling/models/bert_pretrainer.py
official/nlp/modeling/models/bert_pretrainer.py
+1
-1
official/nlp/tasks/masked_lm.py
official/nlp/tasks/masked_lm.py
+16
-0
official/nlp/tasks/masked_lm_test.py
official/nlp/tasks/masked_lm_test.py
+7
-0
official/nlp/tasks/question_answering.py
official/nlp/tasks/question_answering.py
+1
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+1
-1
official/nlp/tasks/tagging.py
official/nlp/tasks/tagging.py
+1
-1
No files found.
official/nlp/configs/bert_test.py
View file @
6c63efed
...
...
@@ -57,8 +57,9 @@ class BertModelsTest(tf.test.TestCase):
inner_dim
=
10
,
num_classes
=
2
,
name
=
"next_sentence"
)
])
encoder
=
bert
.
instantiate_bertpretrainer_from_cfg
(
config
)
self
.
assertSameElements
(
encoder
.
checkpoint_items
.
keys
(),
[
"encoder"
,
"next_sentence.pooler_dense"
])
self
.
assertSameElements
(
encoder
.
checkpoint_items
.
keys
(),
[
"encoder"
,
"masked_lm"
,
"next_sentence.pooler_dense"
])
if
__name__
==
"__main__"
:
...
...
official/nlp/modeling/models/bert_pretrainer.py
View file @
6c63efed
...
...
@@ -217,7 +217,7 @@ class BertPretrainerV2(tf.keras.Model):
@
property
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
items
=
dict
(
encoder
=
self
.
encoder_network
)
items
=
dict
(
encoder
=
self
.
encoder_network
,
masked_lm
=
self
.
masked_lm
)
for
head
in
self
.
classification_heads
:
for
key
,
item
in
head
.
checkpoint_items
.
items
():
items
[
'.'
.
join
([
head
.
name
,
key
])]
=
item
...
...
official/nlp/tasks/masked_lm.py
View file @
6c63efed
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
"""Masked language task."""
from
absl
import
logging
import
dataclasses
import
tensorflow
as
tf
...
...
@@ -26,6 +27,7 @@ from official.nlp.data import data_loader_factory
@
dataclasses
.
dataclass
class
MaskedLMConfig
(
cfg
.
TaskConfig
):
"""The model config."""
init_checkpoint
:
str
=
''
model
:
bert
.
BertPretrainerConfig
=
bert
.
BertPretrainerConfig
(
cls_heads
=
[
bert
.
ClsHeadConfig
(
inner_dim
=
768
,
num_classes
=
2
,
dropout_rate
=
0.1
,
name
=
'next_sentence'
)
...
...
@@ -171,3 +173,17 @@ class MaskedLMTask(base_task.Task):
aux_losses
=
model
.
losses
)
self
.
process_metrics
(
metrics
,
inputs
,
outputs
)
return
{
self
.
loss
:
loss
}
def initialize(self, model: tf.keras.Model):
    """Initializes `model` from the checkpoint named in the task config.

    Reads `self.task_config.init_checkpoint`; when that path is a directory,
    the most recent checkpoint inside it is selected. Silently returns when
    no checkpoint is configured or none can be found.
    """
    path = self.task_config.init_checkpoint
    # A directory means "pick the latest checkpoint file within it".
    if tf.io.gfile.isdir(path):
        path = tf.train.latest_checkpoint(path)
    if not path:
        # Nothing configured (or an empty directory) — leave weights as-is.
        return
    # Restoring all modules defined by the model, e.g. encoder, masked_lm and
    # cls pooler. The best initialization may vary case by case.
    checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
    load_status = checkpoint.read(path)
    # Tolerate extra variables in the checkpoint, but require that every
    # object we asked for was actually matched.
    load_status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s', path)
official/nlp/tasks/masked_lm_test.py
View file @
6c63efed
...
...
@@ -27,6 +27,7 @@ class MLMTaskTest(tf.test.TestCase):
def
test_task
(
self
):
config
=
masked_lm
.
MaskedLMConfig
(
init_checkpoint
=
self
.
get_temp_dir
(),
model
=
bert
.
BertPretrainerConfig
(
encoders
.
TransformerEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
),
num_masked_tokens
=
20
,
...
...
@@ -49,6 +50,12 @@ class MLMTaskTest(tf.test.TestCase):
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
# Saves a checkpoint.
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
model
,
**
model
.
checkpoint_items
)
ckpt
.
save
(
config
.
init_checkpoint
)
task
.
initialize
(
model
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/nlp/tasks/question_answering.py
View file @
6c63efed
...
...
@@ -282,5 +282,5 @@ class QuestionAnsweringTask(base_task.Task):
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'
f
inished loading pretrained checkpoint from %s'
,
logging
.
info
(
'
F
inished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
official/nlp/tasks/sentence_prediction.py
View file @
6c63efed
...
...
@@ -189,5 +189,5 @@ class SentencePredictionTask(base_task.Task):
ckpt
=
tf
.
train
.
Checkpoint
(
**
pretrain2finetune_mapping
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'
f
inished loading pretrained checkpoint from %s'
,
logging
.
info
(
'
F
inished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
official/nlp/tasks/tagging.py
View file @
6c63efed
...
...
@@ -212,5 +212,5 @@ class TaggingTask(base_task.Task):
ckpt
=
tf
.
train
.
Checkpoint
(
**
model
.
checkpoint_items
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'
f
inished loading pretrained checkpoint from %s'
,
logging
.
info
(
'
F
inished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment