ModelZoo / ResNet50_tensorflow · Commits · 43587c64

Commit 43587c64
Authored Jun 16, 2020 by Hongkun Yu; committed Jun 16, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 316784919
Parent: 85cfe94d
Changes: 2 files, 32 additions and 17 deletions

  official/nlp/tasks/sentence_prediction.py        +14  -16
  official/nlp/tasks/sentence_prediction_test.py   +18   -1
official/nlp/tasks/sentence_prediction.py
@@ -29,9 +29,9 @@ from official.nlp.modeling import losses as loss_lib
 @dataclasses.dataclass
 class SentencePredictionConfig(cfg.TaskConfig):
   """The model config."""
-  # At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
+  # At most one of `init_checkpoint` and `hub_module_url` can
   # be specified.
-  pretrain_checkpoint_dir: str = ''
+  init_checkpoint: str = ''
   hub_module_url: str = ''
   network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
       num_masked_tokens=0,
@@ -52,7 +52,7 @@ class SentencePredictionTask(base_task.Task):
   def __init__(self, params=cfg.TaskConfig):
     super(SentencePredictionTask, self).__init__(params)
-    if params.hub_module_url and params.pretrain_checkpoint_dir:
+    if params.hub_module_url and params.init_checkpoint:
       raise ValueError('At most one of `hub_module_url` and '
                        '`pretrain_checkpoint_dir` can be specified.')
     if params.hub_module_url:
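Taken together with the config hunk above, this check enforces that at most one warm-start source is configured. A toy, self-contained sketch of the same dataclass-plus-validation pattern; the field names mirror the diff, but this is an illustration, not the repo's code:

import dataclasses

@dataclasses.dataclass
class ToyTaskConfig:
  """Toy stand-in for SentencePredictionConfig (illustration only)."""
  # At most one of `init_checkpoint` and `hub_module_url` may be set.
  init_checkpoint: str = ''
  hub_module_url: str = ''

def validate(cfg: ToyTaskConfig):
  # Mirrors the mutual-exclusion check in SentencePredictionTask.__init__.
  if cfg.init_checkpoint and cfg.hub_module_url:
    raise ValueError('At most one of `hub_module_url` and '
                     '`init_checkpoint` can be specified.')

validate(ToyTaskConfig(init_checkpoint='/tmp/pretrain_ckpt'))       # ok
validate(ToyTaskConfig(hub_module_url='https://tfhub.dev/...'))     # ok
# validate(ToyTaskConfig(init_checkpoint='x', hub_module_url='y'))  # raises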
@@ -82,8 +82,8 @@ class SentencePredictionTask(base_task.Task):
   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
     loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
         labels=labels,
-        predictions=tf.nn.log_softmax(model_outputs['sentence_prediction'],
-                                      axis=-1))
+        predictions=tf.nn.log_softmax(
+            model_outputs['sentence_prediction'], axis=-1))
     if aux_losses:
       loss += tf.add_n(aux_losses)
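For context on the loss computed here: `tf.nn.log_softmax` turns the `sentence_prediction` logits into per-class log-probabilities, and the wrapped call computes a sparse categorical cross-entropy over them. A standalone sketch of the same math with plain TensorFlow ops; this is not the `loss_lib` implementation and ignores its weighting options:

import tensorflow as tf

# Toy data standing in for model_outputs['sentence_prediction'] and labels.
logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 1.2, 0.3]])         # [batch, num_classes]
labels = tf.constant([0, 1], dtype=tf.int32)    # [batch]

# Per-class log-probabilities, as in the hunk above.
log_probs = tf.nn.log_softmax(logits, axis=-1)

# Sparse categorical cross-entropy: negative log-probability of the true
# class, averaged over the batch.
per_example = -tf.gather(log_probs, labels, axis=1, batch_dims=1)
loss = tf.reduce_mean(per_example)

# Sanity check against the Keras implementation on raw logits.
keras_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True))
tf.debugging.assert_near(loss, keras_loss)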
@@ -92,6 +92,7 @@ class SentencePredictionTask(base_task.Task):
   def build_inputs(self, params, input_context=None):
     """Returns tf.data.Dataset for sentence_prediction task."""
     if params.input_path == 'dummy':

       def dummy_data(_):
         dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
         x = dict(
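The `dummy` input path builds batches of all-zero features so the task can run without real data. A minimal sketch of that pattern, with a fixed `seq_length` standing in for `params.seq_length`; the feature names, the label, and the dataset plumbing are assumptions, since the hunk cuts off inside `dict(`:

import tensorflow as tf

seq_length = 128  # stand-in for params.seq_length

def dummy_data(_):
  # All-zero features shaped like a single packed example.
  dummy_ids = tf.zeros((1, seq_length), dtype=tf.int32)
  x = dict(
      input_word_ids=dummy_ids,   # assumed feature names; the real keys
      input_mask=dummy_ids,       # are elided in the hunk above
      input_type_ids=dummy_ids)
  y = tf.zeros((1, 1), dtype=tf.int32)  # placeholder label
  return x, y

# A dataset that yields the dummy example forever.
dataset = tf.data.Dataset.range(1).repeat().map(dummy_data)
features, labels = next(iter(dataset))
print(features['input_word_ids'].shape)  # (1, 128)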
@@ -112,9 +113,7 @@ class SentencePredictionTask(base_task.Task):
   def build_metrics(self, training=None):
     del training
-    metrics = [
-        tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
-    ]
+    metrics = [
+        tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
     return metrics

   def process_metrics(self, metrics, labels, model_outputs):
@@ -126,8 +125,10 @@ class SentencePredictionTask(base_task.Task):
   def initialize(self, model):
     """Load a pretrained checkpoint (if exists) and then train from iter 0."""
-    pretrain_ckpt_dir = self.task_config.pretrain_checkpoint_dir
-    if not pretrain_ckpt_dir:
+    ckpt_dir_or_file = self.task_config.init_checkpoint
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
       return
     pretrain2finetune_mapping = {
@@ -137,10 +138,7 @@ class SentencePredictionTask(base_task.Task):
             model.checkpoint_items['sentence_prediction.pooler_dense'],
     }
     ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
-    latest_pretrain_ckpt = tf.train.latest_checkpoint(pretrain_ckpt_dir)
-    if latest_pretrain_ckpt is None:
-      raise FileNotFoundError('Cannot find pretrain checkpoint under {}'
-                              .format(pretrain_ckpt_dir))
-    status = ckpt.restore(latest_pretrain_ckpt)
+    status = ckpt.restore(ckpt_dir_or_file)
     status.expect_partial().assert_existing_objects_matched()
-    logging.info('finished loading pretrained checkpoint.')
+    logging.info('finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
official/nlp/tasks/sentence_prediction_test.py
@@ -43,8 +43,10 @@ class SentencePredictionTaskTest(tf.test.TestCase):
   def test_task(self):
     config = sentence_prediction.SentencePredictionConfig(
+        init_checkpoint=self.get_temp_dir(),
         network=bert.BertPretrainerConfig(
-            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
+            encoder=encoders.TransformerEncoderConfig(
+                vocab_size=30522, num_layers=1),
             num_masked_tokens=0,
             cls_heads=[
                 bert.ClsHeadConfig(
@@ -62,6 +64,21 @@ class SentencePredictionTaskTest(tf.test.TestCase):
     task.train_step(next(iterator), model, optimizer, metrics=metrics)
     task.validation_step(next(iterator), model, metrics=metrics)
+
+    # Saves a checkpoint.
+    pretrain_cfg = bert.BertPretrainerConfig(
+        encoder=encoders.TransformerEncoderConfig(
+            vocab_size=30522, num_layers=1),
+        num_masked_tokens=20,
+        cls_heads=[
+            bert.ClsHeadConfig(
+                inner_dim=10, num_classes=3, name="next_sentence")
+        ])
+    pretrain_model = bert.instantiate_from_cfg(pretrain_cfg)
+    ckpt = tf.train.Checkpoint(model=pretrain_model,
+                               **pretrain_model.checkpoint_items)
+    ckpt.save(config.init_checkpoint)
+    task.initialize(model)

   def _export_bert_tfhub(self):
     bert_config = configs.BertConfig(
         vocab_size=30522,
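The new test saves a pretraining checkpoint under `config.init_checkpoint` and then calls `task.initialize(model)`, relying on `tf.train.Checkpoint`'s named slots: the pretraining side saves objects under certain names, and the fine-tuning side restores different objects under the same names, which is the idea behind `pretrain2finetune_mapping` above. A toy sketch of that round trip; the slot names here are illustrative, not the real `checkpoint_items` keys:

import os
import tempfile

import tensorflow as tf

# Pretraining side: write a checkpoint with named slots, the way the test
# saves tf.train.Checkpoint(model=pretrain_model, **pretrain_model.checkpoint_items).
pretrain_encoder = tf.Module()
pretrain_encoder.kernel = tf.Variable([[1.0, 2.0]])
pretrain_pooler = tf.Module()
pretrain_pooler.bias = tf.Variable([3.0])

ckpt_dir = tempfile.mkdtemp()  # plays the role of self.get_temp_dir()
tf.train.Checkpoint(encoder=pretrain_encoder,
                    pooler=pretrain_pooler).save(os.path.join(ckpt_dir, 'ckpt'))

# Fine-tuning side: a Checkpoint with the *same slot names* but pointing at
# different objects restores the pretrained weights into them.
finetune_encoder = tf.Module()
finetune_encoder.kernel = tf.Variable([[0.0, 0.0]])
finetune_pooler = tf.Module()
finetune_pooler.bias = tf.Variable([0.0])

latest = tf.train.latest_checkpoint(ckpt_dir)  # '<ckpt_dir>/ckpt-1'
status = tf.train.Checkpoint(encoder=finetune_encoder,
                             pooler=finetune_pooler).restore(latest)
status.expect_partial().assert_existing_objects_matched()
print(finetune_encoder.kernel.numpy())  # [[1. 2.]]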