ModelZoo / ResNet50_tensorflow / Commits

Commit b2af7bc2, authored Jun 09, 2020 by A. Unique TensorFlower

Move nlp/tasks/sentence_prediction.py

PiperOrigin-RevId: 315613738

parent d4bb3055
Showing 4 changed files with 336 additions and 0 deletions:
- official/nlp/configs/bert.py (+19, -0)
- official/nlp/data/sentence_prediction_dataloader.py (+64, -0)
- official/nlp/tasks/sentence_prediction.py (+147, -0)
- official/nlp/tasks/sentence_prediction_test.py (+106, -0)
official/nlp/configs/bert.py (view file @ b2af7bc2)

```diff
@@ -98,3 +98,22 @@ class BertPretrainEvalDataConfig(BertPretrainDataConfig):
   input_path: str = ""
   global_batch_size: int = 512
   is_training: bool = False
+
+
+@dataclasses.dataclass
+class BertSentencePredictionDataConfig(cfg.DataConfig):
+  """Data of sentence prediction dataset."""
+  input_path: str = ""
+  global_batch_size: int = 32
+  is_training: bool = True
+  seq_length: int = 128
+
+
+@dataclasses.dataclass
+class BertSentencePredictionDevDataConfig(cfg.DataConfig):
+  """Dev data of MNLI sentence prediction dataset."""
+  input_path: str = ""
+  global_batch_size: int = 32
+  is_training: bool = False
+  seq_length: int = 128
+  drop_remainder: bool = False
```
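For orientation, here is a minimal sketch (not part of the commit) of how these two new configs might be constructed for an MNLI fine-tuning run; the TFRecord paths are placeholders, and the sketch assumes a model-garden checkout at this revision so that `official.nlp.configs.bert` is importable.

```python
# Hypothetical usage sketch; the input paths below are placeholders.
from official.nlp.configs import bert

# Training split: is_training defaults to True for this config.
train_data = bert.BertSentencePredictionDataConfig(
    input_path='/path/to/mnli_train.tf_record',  # placeholder
    global_batch_size=32,
    seq_length=128)

# Dev split: evaluation order, and the partial final batch is kept
# because drop_remainder defaults to False.
dev_data = bert.BertSentencePredictionDevDataConfig(
    input_path='/path/to/mnli_dev.tf_record',  # placeholder
    global_batch_size=32,
    seq_length=128)
```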
official/nlp/data/sentence_prediction_dataloader.py (new file, mode 100644; view file @ b2af7bc2)
```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional

import tensorflow as tf
from official.core import input_reader


class SentencePredictionDataLoader:
  """A class to load dataset for sentence prediction (classification) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'label_ids': tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    y = record['label_ids']
    return (x, y)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.data.Dataset."""
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
```
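A minimal usage sketch for the loader, assuming TFRecords of serialized `tf.Example`s with the four features `_decode` expects exist at the placeholder path, and that `input_reader.InputReader` consumes `params.input_path` as configured above:

```python
# Hypothetical usage sketch; the input path is a placeholder. The records
# must carry 'input_ids', 'input_mask', 'segment_ids' (each a [seq_length]
# int64 list) and a scalar int64 'label_ids'.
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader

params = bert.BertSentencePredictionDataConfig(
    input_path='/path/to/train.tf_record',  # placeholder
    global_batch_size=32,
    seq_length=128)
loader = sentence_prediction_dataloader.SentencePredictionDataLoader(params)
dataset = loader.load()

for features, label_ids in dataset.take(1):
  # _parse renames the features for the model: input_word_ids, input_mask,
  # input_type_ids, each int32 of shape [batch, seq_length].
  print(features['input_word_ids'].shape, label_ids.shape)
```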
official/nlp/tasks/sentence_prediction.py (new file, mode 100644; view file @ b2af7bc2)
```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
import logging

import dataclasses
import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.modeling import losses as loss_lib


@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `pretrain_checkpoint_dir` and `hub_module_url` can
  # be specified.
  pretrain_checkpoint_dir: str = ''
  hub_module_url: str = ''
  network: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
      num_masked_tokens=0,
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=3,
              dropout_rate=0.1,
              name='sentence_prediction')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params=cfg.TaskConfig):
    super(SentencePredictionTask, self).__init__(params)
    if params.hub_module_url and params.pretrain_checkpoint_dir:
      raise ValueError('At most one of `hub_module_url` and '
                       '`pretrain_checkpoint_dir` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

  def build_model(self):
    if self._hub_module:
      input_word_ids = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_word_ids')
      input_mask = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_mask')
      input_type_ids = tf.keras.layers.Input(
          shape=(None,), dtype=tf.int32, name='input_type_ids')
      bert_model = hub.KerasLayer(self._hub_module, trainable=True)
      pooled_output, sequence_output = bert_model(
          [input_word_ids, input_mask, input_type_ids])
      encoder_from_hub = tf.keras.Model(
          inputs=[input_word_ids, input_mask, input_type_ids],
          outputs=[sequence_output, pooled_output])
      return bert.instantiate_from_cfg(
          self.task_config.network, encoder_network=encoder_from_hub)
    else:
      return bert.instantiate_from_cfg(self.task_config.network)

  def build_losses(self, features, model_outputs,
                   aux_losses=None) -> tf.Tensor:
    labels = features
    loss = loss_lib.weighted_sparse_categorical_crossentropy_loss(
        labels=labels,
        predictions=tf.nn.log_softmax(
            model_outputs['sentence_prediction'], axis=-1))

    if aux_losses:
      loss += tf.add_n(aux_losses)
    return loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)
        y = tf.ones((1, 1), dtype=tf.int32)
        return (x, y)

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return sentence_prediction_dataloader.SentencePredictionDataLoader(
        params).load(input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, outputs):
    for metric in metrics:
      metric.update_state(labels, outputs['sentence_prediction'])

  def process_compiled_metrics(self, compiled_metrics, labels, outputs):
    compiled_metrics.update_state(labels, outputs['sentence_prediction'])

  def initialize(self, model):
    """Loads a pretrained checkpoint (if it exists), then trains from iter 0."""
    pretrain_ckpt_dir = self.task_config.pretrain_checkpoint_dir
    if not pretrain_ckpt_dir:
      return
    pretrain2finetune_mapping = {
        'encoder':
            model.checkpoint_items['encoder'],
        'next_sentence.pooler_dense':
            model.checkpoint_items['sentence_prediction.pooler_dense'],
    }
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    latest_pretrain_ckpt = tf.train.latest_checkpoint(pretrain_ckpt_dir)
    if latest_pretrain_ckpt is None:
      raise FileNotFoundError(
          'Cannot find pretrain checkpoint under {}'.format(pretrain_ckpt_dir))
    status = ckpt.restore(latest_pretrain_ckpt)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('finished loading pretrained checkpoint.')
```
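For intuition about `build_losses`: feeding `tf.nn.log_softmax` output into a sparse label lookup is the standard sparse categorical cross-entropy on logits. Below is a self-contained sketch with toy tensors; it illustrates the math only, not `loss_lib`'s implementation (which, as its name suggests, also supports per-example weighting).

```python
# Toy illustration of the cross-entropy computed in build_losses.
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0]])  # [batch, num_classes]
labels = tf.constant([0, 2])             # [batch] sparse class ids

# What build_losses passes as `predictions`: log-probabilities.
log_probs = tf.nn.log_softmax(logits, axis=-1)
# Negative log-likelihood of the true class for each example.
per_example = -tf.gather(log_probs, labels, batch_dims=1)
loss = tf.reduce_mean(per_example)

# Agrees with Keras' sparse categorical cross-entropy on raw logits.
keras_loss = tf.reduce_mean(
    tf.keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True))
tf.debugging.assert_near(loss, keras_loss)
```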
official/nlp/tasks/sentence_prediction_test.py (new file, mode 100644; view file @ b2af7bc2)
```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.tasks.sentence_prediction."""
import os

import orbit
# pylint: disable=g-bad-import-order
import tensorflow as tf

from official.nlp.bert import configs
from official.nlp.bert import export_tfhub
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.tasks import sentence_prediction


class SentencePredictionTaskTest(tf.test.TestCase):

  def _run_task(self, config):
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    metrics = task.build_metrics()

    strategy = tf.distribute.get_strategy()
    dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                   config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

  def test_task(self):
    config = sentence_prediction.SentencePredictionConfig(
        network=bert.BertPretrainerConfig(
            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
            num_masked_tokens=0,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=3, name="sentence_prediction")
            ]),
        train_data=bert.BertSentencePredictionDataConfig(
            input_path="dummy", seq_length=128, global_batch_size=1))
    task = sentence_prediction.SentencePredictionTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(lr=0.1)
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)

  def _export_bert_tfhub(self):
    bert_config = configs.BertConfig(
        vocab_size=30522,
        hidden_size=16,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    _, encoder = export_tfhub.create_bert_model(bert_config)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(model=encoder)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
                                   hub_destination, vocab_file)
    return hub_destination

  def test_task_with_hub(self):
    hub_module_url = self._export_bert_tfhub()
    config = sentence_prediction.SentencePredictionConfig(
        hub_module_url=hub_module_url,
        network=bert.BertPretrainerConfig(
            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
            num_masked_tokens=0,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=3, name="sentence_prediction")
            ]),
        train_data=bert.BertSentencePredictionDataConfig(
            input_path="dummy", seq_length=128, global_batch_size=10))
    self._run_task(config)


if __name__ == "__main__":
  tf.test.main()
```
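The tests above exercise the default and TF-Hub paths of `build_model`; the third option, `pretrain_checkpoint_dir`, flows through `initialize` instead. A hedged sketch of that wiring follows (the checkpoint directory is a placeholder; per the task code above, `initialize` raises `FileNotFoundError` unless the directory holds a pretraining checkpoint whose object graph matches the `encoder`/`next_sentence.pooler_dense` mapping).

```python
# Hypothetical sketch, not part of the commit; the directory is a placeholder.
from official.nlp.tasks import sentence_prediction

config = sentence_prediction.SentencePredictionConfig(
    pretrain_checkpoint_dir='/path/to/pretrain_ckpts')  # placeholder
task = sentence_prediction.SentencePredictionTask(config)
model = task.build_model()
task.initialize(model)  # restores encoder + pooler weights, then training
                        # starts from iteration 0
```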