ModelZoo / ResNet50_tensorflow / Commits / 31ca3b97

Commit 31ca3b97, authored Jul 23, 2020 by Kaushik Shivakumar
Commit message: "resovle merge conflicts"
Parents: 3e9d886d, 7fcd7cba
Changes: 392

Showing 20 changed files with 1468 additions and 475 deletions (+1468, -475)
official/modeling/tf_utils.py                         +0    -1
official/nlp/albert/run_classifier.py                 +47   -4
official/nlp/bert/export_tfhub.py                     +1    -1
official/nlp/bert/model_saving_utils.py               +0    -4
official/nlp/bert/model_training_utils.py             +3    -2
official/nlp/bert/run_classifier.py                   +4    -1
official/nlp/configs/bert.py                          +2    -82
official/nlp/configs/bert_test.py                     +7    -6
official/nlp/configs/electra.py                       +91   -0
official/nlp/configs/electra_test.py                  +49   -0
official/nlp/configs/encoders.py                      +42   -4
official/nlp/data/classifier_data_lib.py              +374  -294
official/nlp/data/create_finetuning_data.py           +60   -13
official/nlp/data/create_pretraining_data.py          +233  -58
official/nlp/data/data_loader_factory.py              +59   -0
official/nlp/data/pretrain_dataloader.py              +17   -3
official/nlp/data/question_answering_dataloader.py    +95   -0
official/nlp/data/sentence_prediction_dataloader.py   +19   -1
official/nlp/data/tagging_data_lib.py                 +346  -0
official/nlp/data/tagging_data_loader.py              +19   -1
official/modeling/tf_utils.py

```diff
@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32


-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
   """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
```
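For reference, a minimal usage sketch of the helper touched above (not part of this commit); the "gelu" identifier is only an example:

```python
import tensorflow as tf

from official.modeling import tf_utils

# get_activation maps a string identifier to a callable,
# e.g. "relu" => tf.nn.relu, as the docstring above describes.
gelu = tf_utils.get_activation("gelu")
print(gelu(tf.constant([-1.0, 0.0, 1.0])))
```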
official/nlp/albert/run_classifier.py

```diff
@@ -14,23 +14,61 @@
 # ==============================================================================
 """ALBERT classification finetuning runner in tf2.x."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import json
 import os

 from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf

 from official.nlp.albert import configs as albert_configs
 from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
 from official.utils.misc import distribution_utils

 FLAGS = flags.FLAGS


+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Function outputs both the ground truth predictions as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(
+        latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+      logging.info('***** Predict results *****')
+      for probabilities in preds:
+        output_line = '\t'.join(
+            str(class_probability)
+            for class_probability in probabilities) + '\n'
+        writer.write(output_line)
+    ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                            'output_labels.tsv')
+    with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+      logging.info('***** Ground truth results *****')
+      for label in ground_truth:
+        output_line = '\t'.join(str(label)) + '\n'
+        writer.write(output_line)
+  return
+
+
 def main(_):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
@@ -56,9 +94,14 @@ def main(_):
   albert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
   return


 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
```
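The new predict path writes one row of tab-separated class probabilities per example to test_results.tsv. A tiny standalone illustration of that formatting (the probability values are invented):

```python
# Mirrors the '\t'.join(...) formatting used in predict() above.
predictions = [[0.91, 0.09], [0.20, 0.80]]
for probabilities in predictions:
  row = "\t".join(str(p) for p in probabilities) + "\n"
  print(repr(row))
```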
official/nlp/bert/export_tfhub.py

```diff
@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
       do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
```
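The only change here swaps assert_consumed() for the looser assert_existing_objects_matched(). A minimal sketch (not from the commit) of the restore-status pattern involved, using a throwaway Keras model:

```python
import tensorflow as tf

# assert_consumed() fails if the checkpoint holds values (e.g. optimizer slots)
# that the restored object graph never uses; for export only the model weights
# matter, so assert_existing_objects_matched() is sufficient.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
path = tf.train.Checkpoint(model=model).save("/tmp/demo_ckpt")

restored = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
status = tf.train.Checkpoint(model=restored).restore(path)
status.assert_existing_objects_matched()  # every built object found a match
```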
official/nlp/bert/model_saving_utils.py

```diff
@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')

   if checkpoint_dir:
     # Keras compile/fit() was used to save checkpoint using
     # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
     # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
```
official/nlp/bert/model_training_utils.py

```diff
@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
 @deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and we do not expect adding new '
+    'functionalities. Please do not have your code depending '
+    'on this library.')
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
@@ -557,7 +559,6 @@ def run_customized_training_loop(
     for metric in model.metrics:
       training_summary[metric.name] = _float_metric_value(metric)
     if eval_metrics:
-      # TODO(hongkuny): Cleans up summary reporting in text.
       training_summary['last_train_metrics'] = _float_metric_value(
           train_metrics[0])
       training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
```
official/nlp/bert/run_classifier.py

```diff
@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data.get('num_labels', 1))[0]
+      bert_config,
+      input_meta_data.get('num_labels', 1),
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=False)[0]

   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
```
official/nlp/configs/bert.py

```diff
@@ -24,7 +24,6 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling.models import bert_pretrainer
@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT encoder configuration."""
   num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
@@ -56,96 +54,18 @@ def instantiate_classification_heads_from_cfgs(
   ] if cls_head_configs else []


-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
       config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
       classification_heads=instantiate_classification_heads_from_cfgs(
           config.cls_heads))
-
-
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-
-
-@dataclasses.dataclass
-class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = False
-
-
-@dataclasses.dataclass
-class SentencePredictionDataConfig(cfg.DataConfig):
-  """Data config for sentence prediction task (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = True
-  seq_length: int = 128
-
-
-@dataclasses.dataclass
-class SentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = False
-  seq_length: int = 128
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class TaggingDataConfig(cfg.DataConfig):
-  """Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class TaggingDevDataConfig(cfg.DataConfig):
-  """Dev Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
```
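A minimal usage sketch of the renamed factory, mirroring bert_test.py later in this commit (the tiny vocabulary and layer counts are illustrative only):

```python
from official.nlp.configs import bert
from official.nlp.configs import encoders

config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
    ])
pretrainer = bert.instantiate_pretrainer_from_cfg(config)
# Per the updated test, checkpoint_items now also exposes "masked_lm".
print(sorted(pretrainer.checkpoint_items.keys()))
```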
official/nlp/configs/bert_test.py

```diff
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])


 if __name__ == "__main__":
```
official/nlp/configs/electra.py (new file, mode 100644)

```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer


@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  discriminator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)


def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  if cls_head_configs:
    return [
        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
    ]
  else:
    return []


def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  if generator_network is None:
    if config.tie_embeddings:
      embedding_layer = discriminator_network.get_embedding_layer()
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
```
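A usage sketch of the new ELECTRA config (tiny sizes for illustration only); with tie_embeddings=False the generator builds its own embedding table instead of reusing the discriminator's:

```python
from official.nlp.configs import electra
from official.nlp.configs import encoders

config = electra.ELECTRAPretrainerConfig(
    tie_embeddings=False,
    generator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=10, num_layers=1),
    discriminator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=10, num_layers=2))
pretrainer = electra.instantiate_pretrainer_from_cfg(config)
```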
official/nlp/configs/electra_test.py (new file, mode 100644)

```python
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
```
official/nlp/configs/encoders.py

```diff
@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
 from typing import Optional

 import dataclasses
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None


 def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
-  encoder_network = networks.TransformerEncoder(
+  if encoder_cls.__name__ == "EncoderScaffold":
+    embedding_cfg = dict(
+        vocab_size=config.vocab_size,
+        type_vocab_size=config.type_vocab_size,
+        hidden_size=config.hidden_size,
+        seq_length=None,
+        max_seq_length=config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+        dropout_rate=config.dropout_rate,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=config.num_attention_heads,
+        intermediate_size=config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            config.hidden_activation),
+        dropout_rate=config.dropout_rate,
+        attention_dropout_rate=config.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=config.num_layers,
+        pooled_output_dim=config.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range))
+    return encoder_cls(**kwargs)
+
+  if encoder_cls.__name__ != "TransformerEncoder":
+    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
+
+  encoder_network = encoder_cls(
       vocab_size=config.vocab_size,
       hidden_size=config.hidden_size,
       num_layers=config.num_layers,
@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
```
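A hedged sketch of the extended factory (tiny sizes, demo only): the default path still builds a TransformerEncoder, and the new embedding_layer argument lets a second encoder reuse an existing embedding table, which is how the ELECTRA config earlier in this commit ties generator and discriminator embeddings:

```python
from official.nlp.configs import encoders

cfg = encoders.TransformerEncoderConfig(vocab_size=100, num_layers=2)
discriminator = encoders.instantiate_encoder_from_cfg(cfg)
# Reuse the discriminator's embedding table, as electra.py does when
# tie_embeddings is True.
generator = encoders.instantiate_encoder_from_cfg(
    cfg, embedding_layer=discriminator.get_embedding_layer())
```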
official/nlp/data/classifier_data_lib.py

```diff
@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
-  """A single training/test example for simple sequence classification."""
+  """A single training/test example for simple seq regression/classification."""

   def __init__(self,
                guid,
@@ -48,8 +48,9 @@ class InputExample(object):
       sequence tasks, only this sequence must be specified.
     text_b: (Optional) string. The untokenized text of the second sequence.
       Only must be specified for sequence pair tasks.
-    label: (Optional) string. The label of the example. This should be
-      specified for train and dev examples, but not for test examples.
+    label: (Optional) string for classification, float for regression. The
+      label of the example. This should be specified for train and dev
+      examples, but not for test examples.
     weight: (Optional) float. The weight of the example to be used during
       training.
     int_iden: (Optional) int. The int identification number of example in the
@@ -84,10 +85,12 @@ class InputFeatures(object):
 class DataProcessor(object):
-  """Base class for data converters for sequence classification data sets."""
+  """Base class for converters for seq regression/classification datasets."""

   def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     self.process_text_fn = process_text_fn
+    self.is_regression = False
+    self.label_type = None

   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
@@ -121,143 +124,158 @@ class DataProcessor(object):
     return lines


-class XnliProcessor(DataProcessor):
-  """Processor for the XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def __init__(self,
-               language="en",
-               process_text_fn=tokenization.convert_to_unicode):
-    super(XnliProcessor, self).__init__(process_text_fn)
-    if language == "all":
-      self.languages = XnliProcessor.supported_languages
-    elif language not in XnliProcessor.supported_languages:
-      raise ValueError("language %s is not supported for XNLI task." % language)
-    else:
-      self.languages = [language]
+class ColaProcessor(DataProcessor):
+  """Processor for the CoLA data set (GLUE version)."""

   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = []
-    for language in self.languages:
-      # Skips the header.
-      lines.extend(
-          self._read_tsv(
-              os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language))[1:])
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      if label == self.process_text_fn("contradictory"):
-        label = self.process_text_fn("contradiction")
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
-    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "test-%d" % i
-      language = self.process_text_fn(line[0])
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples_by_lang[language].append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

   def get_labels(self):
     """See base class."""
-    return ["contradiction", "entailment", "neutral"]
+    return ["0", "1"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XNLI"
+    return "COLA"

+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
+      if set_type == "test" and i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(line[1])
+        label = "0"
+      else:
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+    return examples

-class XtremeXnliProcessor(DataProcessor):
-  """Processor for the XTREME XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
+
+class MnliProcessor(DataProcessor):
+  """Processor for the MultiNLI data set (GLUE version)."""
+
+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type

   def get_train_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

   def get_dev_examples(self, data_dir):
     """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")

   def get_test_examples(self, data_dir):
     """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")

   def get_labels(self):
     """See base class."""
     return ["contradiction", "entailment", "neutral"]

   @staticmethod
   def get_processor_name():
     """See base class."""
     return "MNLI"

   def _create_examples(self, lines, set_type):
     """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
+      if set_type == "test":
+        label = "contradiction"
+      else:
+        label = self.process_text_fn(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples


+class MrpcProcessor(DataProcessor):
+  """Processor for the MRPC data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

   def get_dev_examples(self, data_dir):
     """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = f"test-{i}"
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "contradiction"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

   def get_labels(self):
     """See base class."""
-    return ["contradiction", "entailment", "neutral"]
+    return ["0", "1"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "XTREME-XNLI"
+    return "MRPC"

+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
+      if set_type == "test":
+        label = "0"
+      else:
+        label = self.process_text_fn(line[0])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples


 class PawsxProcessor(DataProcessor):
@@ -289,7 +307,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])

     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
@@ -302,69 +320,15 @@ class PawsxProcessor(DataProcessor):
     """See base class."""
     lines = []
     for lang in PawsxProcessor.supported_languages:
-      lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv")))
+      lines.extend(
+          self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])

     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
+      text_a = self.process_text_fn(line[1])
+      text_b = self.process_text_fn(line[2])
+      label = self.process_text_fn(line[3])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = self.process_text_fn(line[2])
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XTREME-PAWS-X"
-
-
-class XtremePawsxProcessor(DataProcessor):
-  """Processor for the XTREME PAWS-X data set."""
-  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
@@ -373,12 +337,12 @@ class XtremePawsxProcessor(DataProcessor):
     """See base class."""
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
+      lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
+      for i, line in enumerate(lines):
         guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "0"
+        text_a = self.process_text_fn(line[1])
+        text_b = self.process_text_fn(line[2])
+        label = self.process_text_fn(line[3])
         examples_by_lang[lang].append(
             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples_by_lang
@@ -393,54 +357,8 @@ class XtremePawsxProcessor(DataProcessor):
     return "XTREME-PAWS-X"


-class MnliProcessor(DataProcessor):
-  """Processor for the MultiNLI data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "MNLI"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
-      text_a = self.process_text_fn(line[8])
-      text_b = self.process_text_fn(line[9])
-      if set_type == "test":
-        label = "contradiction"
-      else:
-        label = self.process_text_fn(line[-1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-
-class MrpcProcessor(DataProcessor):
-  """Processor for the MRPC data set (GLUE version)."""
+class QnliProcessor(DataProcessor):
+  """Processor for the QNLI data set (GLUE version)."""

   def get_train_examples(self, data_dir):
     """See base class."""
@@ -450,7 +368,7 @@ class MrpcProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

   def get_test_examples(self, data_dir):
     """See base class."""
@@ -459,26 +377,28 @@ class MrpcProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["0", "1"]
+    return ["entailment", "not_entailment"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "MRPC"
+    return "QNLI"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, i)
-      text_a = self.process_text_fn(line[3])
-      text_b = self.process_text_fn(line[4])
+      guid = "%s-%s" % (set_type, 1)
       if set_type == "test":
-        label = "0"
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = "entailment"
       else:
-        label = self.process_text_fn(line[0])
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = tokenization.convert_to_unicode(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -512,9 +432,9 @@ class QqpProcessor(DataProcessor):
     return "QQP"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, line[0])
@@ -529,52 +449,6 @@ class QqpProcessor(DataProcessor):
     return examples


-class ColaProcessor(DataProcessor):
-  """Processor for the CoLA data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "COLA"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
-      if set_type == "test" and i == 0:
-        continue
-      guid = "%s-%s" % (set_type, i)
-      if set_type == "test":
-        text_a = self.process_text_fn(line[1])
-        label = "0"
-      else:
-        text_a = self.process_text_fn(line[3])
-        label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
-    return examples
-
-
 class RteProcessor(DataProcessor):
   """Processor for the RTE data set (GLUE version)."""
@@ -605,7 +479,7 @@ class RteProcessor(DataProcessor):
     return "RTE"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
@@ -650,9 +524,9 @@ class SstProcessor(DataProcessor):
     return "SST-2"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
@@ -667,8 +541,14 @@ class SstProcessor(DataProcessor):
     return examples


-class QnliProcessor(DataProcessor):
-  """Processor for the QNLI data set (GLUE version)."""
+class StsBProcessor(DataProcessor):
+  """Processor for the STS-B data set (GLUE version)."""
+
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
+    self.is_regression = True
+    self.label_type = float
+    self._labels = None

   def get_train_examples(self, data_dir):
     """See base class."""
@@ -678,7 +558,7 @@ class QnliProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
@@ -687,28 +567,26 @@ class QnliProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["entailment", "not_entailment"]
+    return self._labels

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "QNLI"
+    return "STS-B"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, 1)
+      guid = "%s-%s" % (set_type, i)
+      text_a = tokenization.convert_to_unicode(line[7])
+      text_b = tokenization.convert_to_unicode(line[8])
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = "entailment"
+        label = 0.0
       else:
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.label_type(
+            tokenization.convert_to_unicode(line[9]))
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
@@ -728,6 +606,8 @@ class TfdsProcessor(DataProcessor):
     tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
     tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
       "is_regression=true,label_type=float"
+    tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
+      "skip_label=-1"
   Possible parameters (please refer to the documentation of Tensorflow Datasets
   (TFDS) for the meaning of individual parameters):
     dataset: Required dataset name (potentially with subset and version number).
@@ -745,6 +625,7 @@ class TfdsProcessor(DataProcessor):
     label_type: Type of the label key (defaults to `int`).
     weight_key: Key of the float sample weight (is not used if not provided).
     is_regression: Whether the task is a regression problem (defaults to False).
+    skip_label: Skip examples with given label (defaults to None).
   """

   def __init__(self,
@@ -784,6 +665,9 @@ class TfdsProcessor(DataProcessor):
     self.label_type = dtype_map[d.get("label_type", "int")]
     self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
     self.weight_key = d.get("weight_key", None)
+    self.skip_label = d.get("skip_label", None)
+    if self.skip_label is not None:
+      self.skip_label = self.label_type(self.skip_label)

   def get_train_examples(self, data_dir):
     assert data_dir is None
@@ -804,7 +688,7 @@ class TfdsProcessor(DataProcessor):
     return "TFDS_" + self.dataset_name

   def _create_examples(self, split_name, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     if split_name not in self.dataset:
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
@@ -822,6 +706,8 @@ class TfdsProcessor(DataProcessor):
       if self.text_b_key:
         text_b = self.process_text_fn(example[self.text_b_key])
       label = self.label_type(example[self.label_key])
+      if self.skip_label is not None and label == self.skip_label:
+        continue
       if self.weight_key:
         weight = float(example[self.weight_key])
       examples.append(
@@ -862,7 +748,7 @@ class WnliProcessor(DataProcessor):
     return "WNLI"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
@@ -879,6 +765,200 @@ class WnliProcessor(DataProcessor):
     return examples


+class XnliProcessor(DataProcessor):
+  """Processor for the XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def __init__(self,
+               language="en",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(XnliProcessor, self).__init__(process_text_fn)
+    if language == "all":
+      self.languages = XnliProcessor.supported_languages
+    elif language not in XnliProcessor.supported_languages:
+      raise ValueError("language %s is not supported for XNLI task." % language)
+    else:
+      self.languages = [language]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = []
+    for language in self.languages:
+      # Skips the header.
+      lines.extend(
+          self._read_tsv(
+              os.path.join(data_dir, "multinli",
+                           "multinli.train.%s.tsv" % language))[1:])
+
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      if label == self.process_text_fn("contradictory"):
+        label = self.process_text_fn("contradiction")
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
+    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "test-%d" % i
+      language = self.process_text_fn(line[0])
+      text_a = self.process_text_fn(line[6])
+      text_b = self.process_text_fn(line[7])
+      label = self.process_text_fn(line[1])
+      examples_by_lang[language].append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XNLI"
+
+
+class XtremePawsxProcessor(DataProcessor):
+  """Processor for the XTREME PAWS-X data set."""
+  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for i, line in enumerate(lines):
+        guid = "test-%d" % i
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "0"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-PAWS-X"
+
+
+class XtremeXnliProcessor(DataProcessor):
+  """Processor for the XTREME XNLI data set."""
+  supported_languages = [
+      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+      "ur", "vi", "zh"
+  ]
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "train-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+    examples = []
+    for i, line in enumerate(lines):
+      guid = "dev-%d" % i
+      text_a = self.process_text_fn(line[0])
+      text_b = self.process_text_fn(line[1])
+      label = self.process_text_fn(line[2])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    examples_by_lang = {k: [] for k in self.supported_languages}
+    for lang in self.supported_languages:
+      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+      for i, line in enumerate(lines):
+        guid = f"test-{i}"
+        text_a = self.process_text_fn(line[0])
+        text_b = self.process_text_fn(line[1])
+        label = "contradiction"
+        examples_by_lang[lang].append(
+            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples_by_lang
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "XTREME-XNLI"
+
+
 def convert_single_example(ex_index, example, label_list, max_seq_length,
                            tokenizer):
   """Converts a single `InputExample` into a single `InputFeatures`."""
@@ -989,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
   tf.io.gfile.makedirs(os.path.dirname(output_file))
   writer = tf.io.TFRecordWriter(output_file)

-  for (ex_index, example) in enumerate(examples):
+  for ex_index, example in enumerate(examples):
     if ex_index % 10000 == 0:
       logging.info("Writing example %d of %d", ex_index, len(examples))
```
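A hedged sketch of the new skip_label option on TfdsProcessor, based on the docstring example added above (it assumes the tensorflow_datasets "snli" dataset can be loaded, and that the constructor keyword is named tfds_params as the docstring suggests):

```python
from official.nlp.data import classifier_data_lib

# SNLI marks unlabeled pairs with label -1; skip_label=-1 drops them.
processor = classifier_data_lib.TfdsProcessor(
    tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
    "skip_label=-1")
print(processor.get_processor_name())          # "TFDS_" + dataset name
examples = processor.get_train_examples(None)  # data_dir must be None for TFDS
```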
official/nlp/data/create_finetuning_data.py
View file @
31ca3b97
...
...
@@ -32,14 +32,16 @@ from official.nlp.data import sentence_retrieval_lib
from
official.nlp.data
import
squad_lib
as
squad_lib_wp
# sentence-piece tokenizer based squad_lib
from
official.nlp.data
import
squad_lib_sp
from
official.nlp.data
import
tagging_data_lib
FLAGS
=
flags
.
FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags
.
DEFINE_enum
(
"fine_tuning_task_type"
,
"classification"
,
[
"classification"
,
"regression"
,
"squad"
,
"retrieval"
],
[
"classification"
,
"regression"
,
"squad"
,
"retrieval"
,
"tagging"
],
"The name of the BERT fine tuning task for which data "
"will be generated.
.
"
)
"will be generated."
)
# BERT classification specific flags.
flags
.
DEFINE_string
(
...
...
@@ -48,30 +50,41 @@ flags.DEFINE_string(
"for the task."
)
flags
.
DEFINE_enum
(
"classification_task_name"
,
"MNLI"
,
[
"COLA"
,
"MNLI"
,
"MRPC"
,
"QNLI"
,
"QQP"
,
"SST-2"
,
"XNLI"
,
"PAWS-X"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
],
[
"COLA"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X."
)
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
"The name of sentence retrieval task for scoring"
)
# MNLI task-specific flag.
flags
.
DEFINE_enum
(
"mnli_type"
,
"matched"
,
[
"matched"
,
"mismatched"
],
"The type of MNLI dataset."
)
# XNLI task
specific flag.
# XNLI task
-
specific flag.
flags
.
DEFINE_string
(
"xnli_language"
,
"en"
,
"Language of training data for XN
I
L task. If the value is 'all', the data "
"Language of training data for XNL
I
task. If the value is 'all', the data "
"of all languages will be used for training."
)
# PAWS-X task
specific flag.
# PAWS-X task
-
specific flag.
flags
.
DEFINE_string
(
"pawsx_language"
,
"en"
,
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of traini
n
g data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training."
)
# BERT Squad task specific flags.
# Retrieval task-specific flags.
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
"The name of sentence retrieval task for scoring"
)
# Tagging task-specific flags.
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
"The name of BERT tagging (token classification) task."
)
# BERT Squad task-specific flags.
flags
.
DEFINE_string
(
"squad_data_file"
,
None
,
"The input data file in for generating training data for BERT squad task."
)
...
...
@@ -171,7 +184,8 @@ def generate_classifier_dataset():
        "cola": classifier_data_lib.ColaProcessor,
-        "mnli": classifier_data_lib.MnliProcessor,
        "mnli": functools.partial(classifier_data_lib.MnliProcessor,
                                  mnli_type=FLAGS.mnli_type),
        "mrpc": classifier_data_lib.MrpcProcessor,
        "qnli":
...
...
@@ -180,6 +194,8 @@ def generate_classifier_dataset():
        "rte": classifier_data_lib.RteProcessor,
        "sst-2": classifier_data_lib.SstProcessor,
        "sts-b": classifier_data_lib.StsBProcessor,
        "xnli": functools.partial(classifier_data_lib.XnliProcessor,
                                  language=FLAGS.xnli_language),
...
...
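The switch to `functools.partial` above lets a task-specific option (here `mnli_type`) ride along in the processor map while the map entries still behave like plain classes. A small standalone sketch of the pattern, using a toy stand-in class rather than the real `classifier_data_lib.MnliProcessor`:

import functools


class MnliProcessor:
  """Toy stand-in for classifier_data_lib.MnliProcessor (illustration only)."""

  def __init__(self, mnli_type="matched"):
    self.mnli_type = mnli_type


processors = {
    "mnli": functools.partial(MnliProcessor, mnli_type="mismatched"),
}

# Calling the map entry still looks like instantiating a class, so downstream
# code of the form `processor = processors[task_name]()` keeps working.
processor = processors["mnli"]()
assert processor.mnli_type == "mismatched"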
@@ -284,6 +300,34 @@ def generate_retrieval_dataset():
      FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
      "panx": tagging_data_lib.PanxProcessor,
      "udpos": tagging_data_lib.UdposProcessor,
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)
  if FLAGS.tokenizer_impl == "word_piece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  elif FLAGS.tokenizer_impl == "sentence_piece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, processor_text_fn)


def main(_):
  if FLAGS.tokenizer_impl == "word_piece":
    if not FLAGS.vocab_file:
...
...
@@ -304,8 +348,11 @@ def main(_):
    input_meta_data = generate_regression_dataset()
  elif FLAGS.fine_tuning_task_type == "retrieval":
    input_meta_data = generate_retrieval_dataset()
-  else:
  elif FLAGS.fine_tuning_task_type == "squad":
    input_meta_data = generate_squad_dataset()
  else:
    assert FLAGS.fine_tuning_task_type == "tagging"
    input_meta_data = generate_tagging_dataset()

  tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
  with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
...
...
official/nlp/data/create_pretraining_data.py
View file @
31ca3b97
...
...
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function

import collections
import itertools
import random

from absl import app
...
...
@@ -48,6 +49,12 @@ flags.DEFINE_bool(
    "do_whole_word_mask", False,
    "Whether to use whole word masking rather than per-WordPiece masking.")

flags.DEFINE_integer(
    "max_ngram_size", None,
    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
    "weighting scheme to favor shorter n-grams. "
    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")

flags.DEFINE_bool(
    "gzip_compress", False,
    "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...
...
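The writer code that consumes `gzip_compress` is collapsed in this view; as a rough sketch of what that option amounts to, `tf.io.TFRecordWriter` accepts a compression option that yields GZIP-compressed record files (output path below is illustrative only):

import tensorflow as tf

# "GZIP" is the standard compression option string for tf.io.TFRecordWriter.
with tf.io.TFRecordWriter("/tmp/pretrain.tf_record.gz",
                          options="GZIP") as writer:
  example = tf.train.Example()  # empty proto, just to exercise the writer
  writer.write(example.SerializeToString())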
@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                              masked_lm_prob, max_predictions_per_seq, rng,
-                              do_whole_word_mask=False):
                              do_whole_word_mask=False,
                              max_ngram_size=None):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]
...
...
@@ -229,7 +237,7 @@ def create_training_instances(input_files,
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
              do_whole_word_mask, max_ngram_size))

  rng.shuffle(instances)
  return instances
...
...
@@ -238,7 +246,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
    do_whole_word_mask=False,
    max_ngram_size=None):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]
...
...
@@ -337,7 +346,7 @@ def create_instances_from_document(
        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-             do_whole_word_mask)
             do_whole_word_mask, max_ngram_size)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
...
...
@@ -355,72 +364,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])

# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
#   words:  ["The", "doghouse"]
#   tokens: ["The", "dog", "##house"]
#   grams:  [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])


def _window(iterable, size):
  """Helper to create a sliding window iterator with a given size.

  E.g.,
    input = [1, 2, 3, 4]
    _window(input, 1) => [1], [2], [3], [4]
    _window(input, 2) => [1, 2], [2, 3], [3, 4]
    _window(input, 3) => [1, 2, 3], [2, 3, 4]
    _window(input, 4) => [1, 2, 3, 4]
    _window(input, 5) => None

  Arguments:
    iterable: elements to iterate over.
    size: size of the window.

  Yields:
    Elements of `iterable` batched into a sliding window of length `size`.
  """
  i = iter(iterable)
  window = []
  try:
    for e in range(0, size):
      window.append(next(i))
    yield window
  except StopIteration:
    # handle the case where iterable's length is less than the window size.
    return
  for e in i:
    window = window[1:] + [e]
    yield window


def _contiguous(sorted_grams):
  """Test whether a sequence of grams is contiguous.

  Arguments:
    sorted_grams: _Grams which are sorted in increasing order.

  Returns:
    True if `sorted_grams` are touching each other.

  E.g.,
    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
    _contiguous([(1, 2), (4, 5)]) == False
  """
  for a, b in _window(sorted_grams, 2):
    if a.end != b.begin:
      return False
  return True


def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
  """Create a list of masking {1, ..., n}-grams from a list of one-grams.

  This is an extention of 'whole word masking' to mask multiple, contiguous
  words such as (e.g., "the red boat").

  Each input gram represents the token indices of a single word,
     words:  ["the", "red", "boat"]
     tokens: ["the", "red", "boa", "##t"]
     grams:  [(0,1), (1,2), (2,4)]

  For a `max_ngram_size` of three, possible outputs masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams; (0,4)

  Output masks will not overlap and contain less than `max_masked_tokens` total
  tokens.  E.g., for the example above with `max_masked_tokens` as three,
  valid outputs are,
       [(0,1), (1,2)]  # "the", "red" covering two tokens
       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens

  The length of the selected n-gram follows a zipf weighting to
  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).

  Arguments:
    grams: List of one-grams.
    max_ngram_size: Maximum number of contiguous one-grams combined to create
      an n-gram.
    max_masked_tokens: Maximum total number of tokens to be masked.
    rng: `random.Random` generator.

  Returns:
    A list of n-grams to be used as masks.
  """
  if not grams:
    return None

  grams = sorted(grams)
  num_tokens = grams[-1].end

  # Ensure our grams are valid (i.e., they don't overlap).
  for a, b in _window(grams, 2):
    if a.end > b.begin:
      raise ValueError("overlapping grams: {}".format(grams))

  # Build map from n-gram length to list of n-grams.
  ngrams = {i: [] for i in range(1, max_ngram_size+1)}
  for gram_size in range(1, max_ngram_size+1):
    for g in _window(grams, gram_size):
      if _contiguous(g):
        # Add an n-gram which spans these one-grams.
        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))

  # Shuffle each list of n-grams.
  for v in ngrams.values():
    rng.shuffle(v)

  # Create the weighting for n-gram length selection.
  # Stored cummulatively for `random.choices` below.
  cummulative_weights = list(
      itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))

  output_ngrams = []
  # Keep a bitmask of which tokens have been masked.
  masked_tokens = [False] * num_tokens
  # Loop until we have enough masked tokens or there are no more candidate
  # n-grams of any length.
  # Each code path should ensure one or more elements from `ngrams` are removed
  # to guarentee this loop terminates.
  while (sum(masked_tokens) < max_masked_tokens and
         sum(len(s) for s in ngrams.values())):
    # Pick an n-gram size based on our weights.
    sz = random.choices(range(1, max_ngram_size+1),
                        cum_weights=cummulative_weights)[0]

    # Ensure this size doesn't result in too many masked tokens.
    # E.g., a two-gram contains _at least_ two tokens.
    if sum(masked_tokens) + sz > max_masked_tokens:
      # All n-grams of this length are too long and can be removed from
      # consideration.
      ngrams[sz].clear()
      continue

    # All of the n-grams of this size have been used.
    if not ngrams[sz]:
      continue

    # Choose a random n-gram of the given size.
    gram = ngrams[sz].pop()
    num_gram_tokens = gram.end - gram.begin

    # Check if this would add too many tokens.
    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
      continue

    # Check if any of the tokens in this gram have already been masked.
    if sum(masked_tokens[gram.begin:gram.end]):
      continue

    # Found a usable n-gram!  Mark its tokens as masked and add it to return.
    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
    output_ngrams.append(gram)
  return output_ngrams


def _wordpieces_to_grams(tokens):
  """Reconstitue grams (words) from `tokens`.

  E.g.,
     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
      grams: [          [1,2), [2,         4), [4,5) , [5,       6)]

  Arguments:
    tokens: list of wordpieces

  Returns:
    List of _Grams representing spans of whole words
    (without "[CLS]" and "[SEP]").
  """
  grams = []
  gram_start_pos = None
  for i, token in enumerate(tokens):
    if gram_start_pos is not None and token.startswith("##"):
      continue
    if gram_start_pos is not None:
      grams.append(_Gram(gram_start_pos, i))
    if token not in ["[CLS]", "[SEP]"]:
      gram_start_pos = i
    else:
      gram_start_pos = None
  if gram_start_pos is not None:
    grams.append(_Gram(gram_start_pos, len(tokens)))
  return grams


-def create_masked_lm_predictions(tokens, masked_lm_prob,
-                                 max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
-  """Creates the predictions for the masked LM objective."""
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
-
-  masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng,
                                 do_whole_word_mask,
                                 max_ngram_size=None):
  """Creates the predictions for the masked LM objective."""
  if do_whole_word_mask:
    grams = _wordpieces_to_grams(tokens)
  else:
    # Here we consider each token to be a word to allow for sub-word masking.
    if max_ngram_size:
      raise ValueError("cannot use ngram masking without whole word masking")
    grams = [_Gram(i, i+1) for i in range(0, len(tokens))
             if tokens[i] not in ["[CLS]", "[SEP]"]]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
  # whole word masking or token level masking. Both of these can be treated
  # as the `max_ngram_size=1` case.
  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
                                 num_to_predict, rng)
  masked_lms = []
  output_tokens = list(tokens)
  for gram in masked_grams:
    # 80% of the time, replace all n-gram tokens with [MASK]
    if rng.random() < 0.8:
      replacement_action = lambda idx: "[MASK]"
    else:
      # 10% of the time, keep all the original n-gram tokens.
      if rng.random() < 0.5:
        replacement_action = lambda idx: tokens[idx]
      # 10% of the time, replace each n-gram token with a random word.
      else:
        replacement_action = lambda idx: rng.choice(vocab_words)

    for idx in range(gram.begin, gram.end):
      output_tokens[idx] = replacement_action(idx)
      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)
...
...
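The `_masking_ngrams` docstring above describes a zipf weighting over n-gram sizes (weight 1, 1/2, 1/3, ...). The standalone sketch below (not part of the commit) reproduces just that selection step to show the resulting size distribution for `max_ngram_size=3`:

import itertools
import random

max_ngram_size = 3
# Zipf-style weights 1, 1/2, 1/3 stored cumulatively: [1.0, 1.5, 1.8333...].
cummulative_weights = list(
    itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))

rng = random.Random(12345)
counts = {1: 0, 2: 0, 3: 0}
for _ in range(10000):
  sz = rng.choices(range(1, max_ngram_size + 1),
                   cum_weights=cummulative_weights)[0]
  counts[sz] += 1

# Roughly 6/11, 3/11 and 2/11 of the draws land on sizes 1, 2 and 3, so
# shorter n-grams are favored, as the docstring states.
print(counts)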
@@ -467,7 +642,7 @@ def main(_):
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

  output_files = FLAGS.output_file.split(",")
  logging.info("*** Writing to output files ***")
...
...
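To see what `_wordpieces_to_grams` produces, here is a small sketch using the same tokens as its docstring. It assumes the module is importable as `official.nlp.data.create_pretraining_data` with the repository root on the Python path; the call uses a private helper purely for illustration:

from official.nlp.data import create_pretraining_data as cpd

tokens = ["[CLS]", "That", "lit", "##tle", "blue", "tru", "##ck", "[SEP]"]
grams = cpd._wordpieces_to_grams(tokens)
# -> [_Gram(1, 2), _Gram(2, 4), _Gram(4, 5), _Gram(5, 7)]
#    i.e. "That", "lit ##tle", "blue", "tru ##ck"; [CLS]/[SEP] are skipped.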
official/nlp/data/data_loader_factory.py
0 → 100644
View file @
31ca3b97
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""
from official.utils import registry

_REGISTERED_DATA_LOADER_CLS = {}


def register_data_loader_cls(data_config_cls):
  """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.

  This decorator supports registration of data loaders as follows:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance
      of DataConfig).

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of data_config_cls.
  """
  return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)


def get_data_loader(data_config):
  """Creates a data_loader from data_config."""
  return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
      data_config)
official/nlp/data/pretrain_dataloader.py
View file @
31ca3b97
...
...
@@ -16,11 +16,27 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class BertPretrainDataConfig(cfg.DataConfig):
  """Data config for BERT pretraining task (tasks/masked_lm)."""
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  use_next_sentence_label: bool = True
  use_position_id: bool = False


@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
class BertPretrainDataLoader:
  """A class to load dataset for bert pretraining task."""
...
...
@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
-        params=self._params,
-        decoder_fn=self._decode,
-        parser_fn=self._parse)
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
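Taken together with data_loader_factory.py above, the registration means a task only needs a config instance to obtain its loader. A minimal sketch, assuming the `official` package is importable and that a pretraining TFRecord exists at the illustrative path below:

from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader

data_config = pretrain_dataloader.BertPretrainDataConfig(
    input_path='/tmp/pretrain.tf_record',  # hypothetical path
    global_batch_size=32,
    seq_length=128,
    max_predictions_per_seq=20)

# Importing pretrain_dataloader registers BertPretrainDataConfig, so the
# factory resolves the config class to BertPretrainDataLoader(data_config).
loader = data_loader_factory.get_data_loader(data_config)
dataset = loader.load()  # a tf.data.Dataset of batched pretraining features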
official/nlp/data/question_answering_dataloader.py
0 → 100644
View file @
31ca3b97
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
  """Data config for question answering task (tasks/question_answering)."""
  input_path: str = ''
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384
  # Settings below are question answering specific.
  version_2_with_negative: bool = False
  # Settings below are only used for eval mode.
  input_preprocessed_data_path: str = ''
  doc_stride: int = 128
  query_length: int = 64
  vocab_file: str = ''
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  do_lower_case: bool = True


@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
  """A class to load dataset for sentence prediction (classification) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length
    self._is_training = params.is_training

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._is_training:
      name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
      name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
    else:
      name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x, y = {}, {}
    for name, tensor in record.items():
      if name in ('start_positions', 'end_positions'):
        y[name] = tensor
      elif name == 'input_ids':
        x['input_word_ids'] = tensor
      elif name == 'segment_ids':
        x['input_type_ids'] = tensor
      else:
        x[name] = tensor
    return (x, y)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
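Note how `_parse` renames the serialized features (`input_ids` to `input_word_ids`, `segment_ids` to `input_type_ids`) and routes the span labels into the `y` dict. A minimal usage sketch, assuming a pre-generated SQuAD training TFRecord at the hypothetical path below:

from official.nlp.data import question_answering_dataloader as qa_loader

config = qa_loader.QADataConfig(
    input_path='/tmp/squad_train.tf_record',  # hypothetical path
    is_training=True,
    seq_length=384,
    global_batch_size=8)

dataset = qa_loader.QuestionAnsweringDataLoader(config).load()
features, labels = next(iter(dataset))
# features keys: 'input_word_ids', 'input_mask', 'input_type_ids' (int32)
# labels keys:   'start_positions', 'end_positions' (int32)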
official/nlp/data/sentence_prediction_dataloader.py
View file @
31ca3b97
...
...
@@ -15,11 +15,28 @@
# ==============================================================================
"""Loads dataset for the sentence prediction (classification) task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory

LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}


@dataclasses.dataclass
class SentencePredictionDataConfig(cfg.DataConfig):
  """Data config for sentence prediction task (tasks/sentence_prediction)."""
  input_path: str = ''
  global_batch_size: int = 32
  is_training: bool = True
  seq_length: int = 128
  label_type: str = 'int'


@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
class SentencePredictionDataLoader:
  """A class to load dataset for sentence prediction (classification) task."""
...
...
@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    label_type = LABEL_TYPES_MAP[self._params.label_type]
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
        'label_ids': tf.io.FixedLenFeature([], label_type),
    }
    example = tf.io.parse_single_example(record, name_to_features)
...
...
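The new `label_type` field lets the same loader decode integer labels for classification and float labels for regression via `LABEL_TYPES_MAP`. A brief sketch of both configurations, with illustrative input paths (STS-B is used here as the usual regression example, matching the regression path in create_finetuning_data.py):

from official.nlp.data import sentence_prediction_dataloader as sp_loader

# Classification (e.g. MNLI): integer labels, the default.
cls_config = sp_loader.SentencePredictionDataConfig(
    input_path='/tmp/mnli_train.tf_record',  # hypothetical path
    seq_length=128)

# Regression (e.g. STS-B): float labels, decoded as tf.float32 via
# LABEL_TYPES_MAP['float'].
reg_config = sp_loader.SentencePredictionDataConfig(
    input_path='/tmp/stsb_train.tf_record',  # hypothetical path
    seq_length=128,
    label_type='float')

dataset = sp_loader.SentencePredictionDataLoader(reg_config).load()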
official/nlp/data/tagging_data_lib.py
0 → 100644
View file @
31ca3b97
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for tagging task such as NER/POS."""
import collections
import os

from absl import logging
import tensorflow as tf

from official.nlp.data import classifier_data_lib

# A negative label id for the padding label, which will not contribute
# to loss/metrics in training.
_PADDING_LABEL_ID = -1

# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"


class InputExample(object):
  """A single training/test example for token classification."""

  def __init__(self, sentence_id, words=None, label_ids=None):
    """Constructs an InputExample."""
    self.sentence_id = sentence_id
    self.words = words if words else []
    self.label_ids = label_ids if label_ids else []

  def add_word_and_label_id(self, word, label_id):
    """Adds word and label_id pair in the example."""
    self.words.append(word)
    self.label_ids.append(label_id)


def _read_one_file(file_name, label_list):
  """Reads one file and returns a list of `InputExample` instances."""
  lines = tf.io.gfile.GFile(file_name, "r").readlines()
  examples = []
  label_id_map = {label: i for i, label in enumerate(label_list)}
  sentence_id = 0
  example = InputExample(sentence_id=0)
  for line in lines:
    line = line.strip("\n")
    if line:
      # The format is: <token>\t<label> for train/dev set and <token> for test.
      items = line.split("\t")
      assert len(items) == 2 or len(items) == 1
      token = items[0].strip()

      # Assign a dummy label_id for test set
      label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
      example.add_word_and_label_id(token, label_id)
    else:
      # Empty line indicates a new sentence.
      if example.words:
        examples.append(example)
        sentence_id += 1
        example = InputExample(sentence_id=sentence_id)

  if example.words:
    examples.append(example)
  return examples
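The comment in `_read_one_file` describes the expected one-token-per-line TSV layout; the sketch below (illustration only, toy data and a hypothetical /tmp path) shows a tiny train file and the `InputExample`s it yields:

from official.nlp.data import tagging_data_lib

# <token>\t<label> per line, blank line between sentences.
toy_tsv = (
    "John\tB-PER\n"
    "lives\tO\n"
    "in\tO\n"
    "Paris\tB-LOC\n"
    "\n"
    "Hello\tO\n")

with open("/tmp/train-en.tsv", "w") as f:  # hypothetical path
  f.write(toy_tsv)

labels = ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]
examples = tagging_data_lib._read_one_file("/tmp/train-en.tsv", labels)
# examples[0].words     == ["John", "lives", "in", "Paris"]
# examples[0].label_ids == [1, 0, 0, 3]
# examples[1].words     == ["Hello"]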
class PanxProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Panx data set."""
  supported_languages = [
      "ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
      "nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
      "es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
      "tr", "et", "fi", "hu"
  ]

  def get_train_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())

  def get_dev_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]

  @staticmethod
  def get_processor_name():
    return "panx"


class UdposProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Udpos data set."""
  supported_languages = [
      "af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
      "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
  ]

  def get_train_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())

  def get_dev_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return [
        "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
        "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
    ]

  @staticmethod
  def get_processor_name():
    return "udpos"


def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
  """Tokenizes words and breaks long example into short ones."""
  # Needs additional [CLS] and [SEP] tokens.
  max_length = max_length - 2
  new_examples = []
  new_example = InputExample(sentence_id=example.sentence_id)
  for i, word in enumerate(example.words):
    if any([x < 0 for x in example.label_ids]):
      raise ValueError("Unexpected negative label_id: %s" % example.label_ids)

    if text_preprocessing:
      word = text_preprocessing(word)
    subwords = tokenizer.tokenize(word)
    if (not subwords or len(subwords) > max_length) and word:
      subwords = [_UNK_TOKEN]

    if len(subwords) + len(new_example.words) > max_length:
      # Start a new example.
      new_examples.append(new_example)
      new_example = InputExample(sentence_id=example.sentence_id)

    for j, subword in enumerate(subwords):
      # Use the real label for the first subword, and pad label for
      # the remainings.
      subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
      new_example.add_word_and_label_id(subword, subword_label)

  if new_example.words:
    new_examples.append(new_example)

  return new_examples


def _convert_single_example(example, max_seq_length, tokenizer):
  """Converts an `InputExample` instance to a `tf.train.Example` instance."""
  tokens = ["[CLS]"]
  tokens.extend(example.words)
  tokens.append("[SEP]")
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  label_ids = [_PADDING_LABEL_ID]
  label_ids.extend(example.label_ids)
  label_ids.append(_PADDING_LABEL_ID)

  segment_ids = [0] * len(input_ids)
  input_mask = [1] * len(input_ids)

  # Pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)
    label_ids.append(_PADDING_LABEL_ID)

  def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  features = collections.OrderedDict()
  features["input_ids"] = create_int_feature(input_ids)
  features["input_mask"] = create_int_feature(input_mask)
  features["segment_ids"] = create_int_feature(segment_ids)
  features["label_ids"] = create_int_feature(label_ids)
  features["sentence_id"] = create_int_feature([example.sentence_id])

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  return tf_example


def write_example_to_file(examples,
                          tokenizer,
                          max_seq_length,
                          output_file,
                          text_preprocessing=None):
  """Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.

  Note that the words inside each example will be tokenized and be applied by
  `text_preprocessing` if available. Also, if the length of sentence (plus
  special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
  will be broken into multiple short examples. For example:

  Example (text_preprocessing=lowercase, max_seq_length=5)
    words:        ["What", "a", "great", "weekend"]
    labels:       [     7,   5,       9,        10]
    sentence_id:  0
    preprocessed: ["what", "a", "great", "weekend"]
    tokenized:    ["what", "a", "great", "week", "##end"]

  will result in two tf.example protos:

    tokens:      ["[CLS]", "what", "a", "great", "[SEP]"]
    label_ids:   [     -1,      7,   5,       9,      -1]
    input_mask:  [      1,      1,   1,       1,       1]
    segment_ids: [      0,      0,   0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

    tokens:      ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
    label_ids:   [     -1,     10,      -1,      -1,      -1]
    input_mask:  [      1,      1,       1,       0,       0]
    segment_ids: [      0,      0,       0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

  Note the use of -1 in `label_ids` to indicate that a token should not be
  considered for classification (e.g., trailing ## wordpieces or special
  token). Token classification models should accordingly ignore these when
  calculating loss, metrics, etc...

  Args:
    examples: A list of `InputExample` instances.
    tokenizer: The tokenizer to be applied on the data.
    max_seq_length: Maximum length of generated sequences.
    output_file: The name of the output tfrecord file.
    text_preprocessing: optional preprocessing run on each word prior to
      tokenization.

  Returns:
    The total number of tf.train.Example proto written to file.
  """
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)
  num_tokenized_examples = 0
  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d to %s", ex_index, len(examples),
                   output_file)

    tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
                                           text_preprocessing)
    num_tokenized_examples += len(tokenized_examples)
    for per_tokenized_example in tokenized_examples:
      tf_example = _convert_single_example(per_tokenized_example,
                                           max_seq_length, tokenizer)
      writer.write(tf_example.SerializeToString())

  writer.close()
  return num_tokenized_examples


def token_classification_meta_data(train_data_size,
                                   max_seq_length,
                                   num_labels,
                                   eval_data_size=None,
                                   test_data_size=None,
                                   label_list=None,
                                   processor_type=None):
  """Creates metadata for tagging (token classification) datasets."""
  meta_data = {
      "train_data_size": train_data_size,
      "max_seq_length": max_seq_length,
      "num_labels": num_labels,
      "task_type": "tagging",
      "label_type": "int",
      "label_shape": [max_seq_length],
  }
  if eval_data_size:
    meta_data["eval_data_size"] = eval_data_size
  if test_data_size:
    meta_data["test_data_size"] = test_data_size
  if label_list:
    meta_data["label_list"] = label_list
  if processor_type:
    meta_data["processor_type"] = processor_type

  return meta_data


def generate_tf_record_from_data_file(processor, data_dir, tokenizer,
                                      max_seq_length, train_data_output_path,
                                      eval_data_output_path,
                                      test_data_output_path,
                                      text_preprocessing):
  """Generates tfrecord files from the raw data."""
  common_kwargs = dict(
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      text_preprocessing=text_preprocessing)

  train_examples = processor.get_train_examples(data_dir)
  train_data_size = write_example_to_file(
      train_examples, output_file=train_data_output_path, **common_kwargs)

  eval_examples = processor.get_dev_examples(data_dir)
  eval_data_size = write_example_to_file(
      eval_examples, output_file=eval_data_output_path, **common_kwargs)

  test_input_data_examples = processor.get_test_examples(data_dir)
  test_data_size = {}
  for language, examples in test_input_data_examples.items():
    test_data_size[language] = write_example_to_file(
        examples,
        output_file=test_data_output_path.format(language),
        **common_kwargs)

  labels = processor.get_labels()
  meta_data = token_classification_meta_data(
      train_data_size,
      max_seq_length,
      len(labels),
      eval_data_size,
      test_data_size,
      label_list=labels,
      processor_type=processor.get_processor_name())
  return meta_data
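An end-to-end sketch of this library, mirroring what generate_tagging_dataset() in create_finetuning_data.py does. It assumes a WordPiece vocab file, a data_dir laid out the way PanxProcessor expects (train-en.tsv, dev-en.tsv, test-<lang>.tsv), and that the tokenizer lives at official.nlp.bert.tokenization; all paths are illustrative, not part of the commit:

from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib

tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/vocab.txt", do_lower_case=True)  # hypothetical vocab

meta_data = tagging_data_lib.generate_tf_record_from_data_file(
    processor=tagging_data_lib.PanxProcessor(),
    data_dir="/tmp/panx",                               # hypothetical dir
    tokenizer=tokenizer,
    max_seq_length=128,
    train_data_output_path="/tmp/panx_train.tf_record",
    eval_data_output_path="/tmp/panx_eval.tf_record",
    # Formatted once per language by generate_tf_record_from_data_file.
    test_data_output_path="/tmp/panx_test_{}.tf_record",
    text_preprocessing=tokenization.convert_to_unicode)
# meta_data["num_labels"] == 7 and meta_data["label_shape"] == [128]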
official/nlp/data/tagging_data_loader.py
View file @
31ca3b97
...
...
@@ -15,17 +15,30 @@
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class TaggingDataConfig(cfg.DataConfig):
  """Data config for tagging (tasks/tagging)."""
  is_training: bool = True
  seq_length: int = 128
  include_sentence_id: bool = False


@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
class TaggingDataLoader:
  """A class to load dataset for tagging (e.g., NER and POS) task."""

-  def __init__(self, params):
  def __init__(self, params: TaggingDataConfig):
    self._params = params
    self._seq_length = params.seq_length
    self._include_sentence_id = params.include_sentence_id

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
...
...
@@ -35,6 +48,9 @@ class TaggingDataLoader:
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._include_sentence_id:
      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
...
...
@@ -54,6 +70,8 @@ class TaggingDataLoader:
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    if self._include_sentence_id:
      x['sentence_id'] = record['sentence_id']

    y = record['label_ids']
    return (x, y)
...
...
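A brief sketch of reading the tagging TFRecords produced by tagging_data_lib. It assumes `input_path` and `global_batch_size` are inherited from the base cfg.DataConfig (as the sibling loaders above suggest) and that TaggingDataLoader exposes the same load() entry point as the other loaders; the path is hypothetical:

from official.nlp.data import tagging_data_loader

config = tagging_data_loader.TaggingDataConfig(
    input_path='/tmp/panx_train.tf_record',  # hypothetical path
    global_batch_size=32,
    seq_length=128,
    include_sentence_id=False)

dataset = tagging_data_loader.TaggingDataLoader(config).load()
features, labels = next(iter(dataset))
# features keys: 'input_word_ids', 'input_mask', 'input_type_ids'
# labels: per-token label ids, with -1 (_PADDING_LABEL_ID) marking positions
# that should be ignored in loss/metrics.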