Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
cc783998
Commit
cc783998
authored
Aug 18, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 18, 2020
Browse files
Internal change
PiperOrigin-RevId: 327325736
parent
bf6a29e4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
20 deletions
+29
-20
official/nlp/modeling/models/bert_classifier.py
official/nlp/modeling/models/bert_classifier.py
+5
-1
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+1
-2
official/nlp/tasks/sentence_prediction_test.py
official/nlp/tasks/sentence_prediction_test.py
+23
-17
No files found.
official/nlp/modeling/models/bert_classifier.py
View file @
cc783998
...
@@ -97,7 +97,11 @@ class BertClassifier(tf.keras.Model):
...
@@ -97,7 +97,11 @@ class BertClassifier(tf.keras.Model):
@
property
@
property
def
checkpoint_items
(
self
):
def
checkpoint_items
(
self
):
return
dict
(
encoder
=
self
.
_network
)
items
=
dict
(
encoder
=
self
.
_network
)
if
hasattr
(
self
.
classifier
,
'checkpoint_items'
):
for
key
,
item
in
self
.
classifier
.
checkpoint_items
.
items
():
items
[
'.'
.
join
([
self
.
classifier
.
name
,
key
])]
=
item
return
items
def
get_config
(
self
):
def
get_config
(
self
):
return
self
.
_config
return
self
.
_config
...
...
official/nlp/tasks/sentence_prediction.py
View file @
cc783998
...
@@ -215,9 +215,8 @@ class SentencePredictionTask(base_task.Task):
...
@@ -215,9 +215,8 @@ class SentencePredictionTask(base_task.Task):
pretrain2finetune_mapping
=
{
pretrain2finetune_mapping
=
{
'encoder'
:
model
.
checkpoint_items
[
'encoder'
],
'encoder'
:
model
.
checkpoint_items
[
'encoder'
],
}
}
# TODO(b/160251903): Investigate why no pooler dense improves finetuning
# accuracies.
if
self
.
task_config
.
init_cls_pooler
:
if
self
.
task_config
.
init_cls_pooler
:
# This option is valid when use_encoder_pooler is false.
pretrain2finetune_mapping
[
pretrain2finetune_mapping
[
'next_sentence.pooler_dense'
]
=
model
.
checkpoint_items
[
'next_sentence.pooler_dense'
]
=
model
.
checkpoint_items
[
'sentence_prediction.pooler_dense'
]
'sentence_prediction.pooler_dense'
]
...
...
official/nlp/tasks/sentence_prediction_test.py
View file @
cc783998
...
@@ -87,34 +87,40 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -87,34 +87,40 @@ class SentencePredictionTaskTest(tf.test.TestCase, parameterized.TestCase):
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
def
test_task
(
self
):
@
parameterized
.
named_parameters
(
config
=
sentence_prediction
.
SentencePredictionConfig
(
(
"init_cls_pooler"
,
True
),
init_checkpoint
=
self
.
get_temp_dir
(),
(
"init_encoder"
,
False
),
model
=
self
.
get_model_config
(
2
),
)
train_data
=
self
.
_train_data_config
)
def
test_task
(
self
,
init_cls_pooler
):
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
# Saves a checkpoint.
# Saves a checkpoint.
pretrain_cfg
=
bert
.
PretrainerConfig
(
pretrain_cfg
=
bert
.
PretrainerConfig
(
encoder
=
encoders
.
EncoderConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
1
)),
cls_heads
=
[
cls_heads
=
[
bert
.
ClsHeadConfig
(
bert
.
ClsHeadConfig
(
inner_dim
=
10
,
num_classes
=
3
,
name
=
"next_sentence"
)
inner_dim
=
768
,
num_classes
=
2
,
name
=
"next_sentence"
)
])
])
pretrain_model
=
masked_lm
.
MaskedLMTask
(
None
).
build_model
(
pretrain_cfg
)
pretrain_model
=
masked_lm
.
MaskedLMTask
(
None
).
build_model
(
pretrain_cfg
)
ckpt
=
tf
.
train
.
Checkpoint
(
ckpt
=
tf
.
train
.
Checkpoint
(
model
=
pretrain_model
,
**
pretrain_model
.
checkpoint_items
)
model
=
pretrain_model
,
**
pretrain_model
.
checkpoint_items
)
ckpt
.
save
(
config
.
init_checkpoint
)
init_path
=
ckpt
.
save
(
self
.
get_temp_dir
())
# Creates the task.
config
=
sentence_prediction
.
SentencePredictionConfig
(
init_checkpoint
=
init_path
,
model
=
self
.
get_model_config
(
num_classes
=
2
),
train_data
=
self
.
_train_data_config
,
init_cls_pooler
=
init_cls_pooler
)
task
=
sentence_prediction
.
SentencePredictionTask
(
config
)
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
dataset
=
task
.
build_inputs
(
config
.
train_data
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
initialize
(
model
)
task
.
initialize
(
model
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
task
.
validation_step
(
next
(
iterator
),
model
,
metrics
=
metrics
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment