ModelZoo / ResNet50_tensorflow · Commits

Commit 31ca3b97, authored Jul 23, 2020 by Kaushik Shivakumar
Parents: 3e9d886d, 7fcd7cba

    resovle merge conflicts
Showing 20 changed files, with 1468 additions and 475 deletions (+1468 −475).
Changed files:

official/modeling/tf_utils.py                        +0    −1
official/nlp/albert/run_classifier.py                +47   −4
official/nlp/bert/export_tfhub.py                    +1    −1
official/nlp/bert/model_saving_utils.py              +0    −4
official/nlp/bert/model_training_utils.py            +3    −2
official/nlp/bert/run_classifier.py                  +4    −1
official/nlp/configs/bert.py                         +2    −82
official/nlp/configs/bert_test.py                    +7    −6
official/nlp/configs/electra.py                      +91   −0
official/nlp/configs/electra_test.py                 +49   −0
official/nlp/configs/encoders.py                     +42   −4
official/nlp/data/classifier_data_lib.py             +374  −294
official/nlp/data/create_finetuning_data.py          +60   −13
official/nlp/data/create_pretraining_data.py         +233  −58
official/nlp/data/data_loader_factory.py             +59   −0
official/nlp/data/pretrain_dataloader.py             +17   −3
official/nlp/data/question_answering_dataloader.py   +95   −0
official/nlp/data/sentence_prediction_dataloader.py  +19   −1
official/nlp/data/tagging_data_lib.py                +346  −0
official/nlp/data/tagging_data_loader.py             +19   −1
official/modeling/tf_utils.py  (+0 −1)

@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32


-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
   """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
 ...
official/nlp/albert/run_classifier.py  (+47 −4)

@@ -14,23 +14,61 @@
 # ==============================================================================
 """ALBERT classification finetuning runner in tf2.x."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import json
+import os

 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf

 from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
 from official.utils.misc import distribution_utils

 FLAGS = flags.FLAGS


+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Function outputs both the ground truth predictions as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+      logging.info('***** Predict results *****')
+      for probabilities in preds:
+        output_line = '\t'.join(
+            str(class_probability)
+            for class_probability in probabilities) + '\n'
+        writer.write(output_line)
+    ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                            'output_labels.tsv')
+    with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+      logging.info('***** Ground truth results *****')
+      for label in ground_truth:
+        output_line = '\t'.join(str(label)) + '\n'
+        writer.write(output_line)
+    return
+
+
 def main(_):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
 ...
@@ -56,9 +94,14 @@ def main(_):
   albert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
   return


 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
 ...
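The new predict() above writes plain TSV files into the model directory. As a quick illustration (a sketch of ours, not part of the commit; the directory path is a placeholder), the probability file it produces can be read back like this:

    import os

    model_dir = '/tmp/albert_model_dir'  # placeholder for FLAGS.model_dir
    with open(os.path.join(model_dir, 'test_results.tsv')) as f:
      # Each row is one example's tab-separated per-class probabilities.
      probs = [[float(p) for p in line.rstrip('\n').split('\t')]
               for line in f if line.strip()]
    # Argmax over each row recovers the predicted class index.
    predicted = [row.index(max(row)) for row in probs]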
official/nlp/bert/export_tfhub.py  (+1 −1)

@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
       do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
 ...
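The one-line change above relaxes the restore check: `assert_consumed()` fails whenever the checkpoint contains values that the restored object graph never uses, which is exactly the situation when a full-model checkpoint is loaded into the encoder alone. A minimal sketch of the difference (ours, not part of the diff; standard tf.train.Checkpoint API, placeholder path):

    import tensorflow as tf

    # Checkpoint with one value (`extra`) the restoring graph will not use.
    source = tf.train.Checkpoint(core=tf.Variable(1.0), extra=tf.Variable(2.0))
    path = source.save('/tmp/demo_ckpt')

    target = tf.train.Checkpoint(core=tf.Variable(0.0))
    status = target.restore(path)
    status.assert_existing_objects_matched()  # OK: everything we built was matched
    # status.assert_consumed()                # would raise: `extra` is never used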
official/nlp/bert/model_saving_utils.py  (+0 −4)

@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')

   if checkpoint_dir:
-    # Keras compile/fit() was used to save checkpoint using
-    # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
-    # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
 ...
official/nlp/bert/model_training_utils.py  (+3 −2)

@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
 @deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and we do not expect adding new '
+    'functionalities. Please do not have your code depending '
+    'on this library.')
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
 ...
@@ -557,7 +559,6 @@ def run_customized_training_loop(
     for metric in model.metrics:
       training_summary[metric.name] = _float_metric_value(metric)
     if eval_metrics:
-      # TODO(hongkuny): Cleans up summary reporting in text.
       training_summary['last_train_metrics'] = _float_metric_value(
           train_metrics[0])
       training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
 ...
official/nlp/bert/run_classifier.py  (+4 −1)

@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data.get('num_labels', 1))[0]
+      bert_config,
+      input_meta_data.get('num_labels', 1),
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=False)[0]
   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
 ...
official/nlp/configs/bert.py  (+2 −82)

@@ -24,7 +24,6 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling.models import bert_pretrainer
 ...
@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT encoder configuration."""
-  num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
 ...
@@ -56,96 +54,18 @@ def instantiate_classification_heads_from_cfgs(
   ] if cls_head_configs else []


-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
       classification_heads=instantiate_classification_heads_from_cfgs(
           config.cls_heads))
-
-
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-
-
-@dataclasses.dataclass
-class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = False
-
-
-@dataclasses.dataclass
-class SentencePredictionDataConfig(cfg.DataConfig):
-  """Data config for sentence prediction task (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = True
-  seq_length: int = 128
-
-
-@dataclasses.dataclass
-class SentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = False
-  seq_length: int = 128
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class TaggingDataConfig(cfg.DataConfig):
-  """Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class TaggingDevDataConfig(cfg.DataConfig):
-  """Dev Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
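The factory keeps its behaviour but loses the `num_masked_tokens` argument and the `bert` prefix in its name. A minimal call site (a sketch of ours, mirroring the updated bert_test.py below) looks like:

    from official.nlp.configs import bert
    from official.nlp.configs import encoders

    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
    # Previously: bert.instantiate_bertpretrainer_from_cfg(config)
    pretrainer = bert.instantiate_pretrainer_from_cfg(config)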
official/nlp/configs/bert_test.py  (+7 −6)

@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
 ...
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
 ...
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)

   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
 ...
@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
             bert.ClsHeadConfig(
                 inner_dim=10, num_classes=2, name="next_sentence")
         ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])


 if __name__ == "__main__":
 ...
official/nlp/configs/electra.py  (new file, +91)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer


@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  discriminator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)


def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  if cls_head_configs:
    return [
        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
    ]
  else:
    return []


def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  if generator_network is None:
    if config.tie_embeddings:
      embedding_layer = discriminator_network.get_embedding_layer()
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
official/nlp/configs/electra_test.py  (new file, +49)

# Lint as: python3
# (Apache 2.0 license header, identical to official/nlp/configs/electra.py above.)
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(
                inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/encoders.py  (+42 −4)

@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 import dataclasses
 import tensorflow as tf

 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 ...
@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None


 def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
-  encoder_network = networks.TransformerEncoder(
+  if encoder_cls.__name__ == "EncoderScaffold":
+    embedding_cfg = dict(
+        vocab_size=config.vocab_size,
+        type_vocab_size=config.type_vocab_size,
+        hidden_size=config.hidden_size,
+        seq_length=None,
+        max_seq_length=config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+        dropout_rate=config.dropout_rate,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=config.num_attention_heads,
+        intermediate_size=config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            config.hidden_activation),
+        dropout_rate=config.dropout_rate,
+        attention_dropout_rate=config.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=config.num_layers,
+        pooled_output_dim=config.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range))
+    return encoder_cls(**kwargs)
+
+  if encoder_cls.__name__ != "TransformerEncoder":
+    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
+
+  encoder_network = encoder_cls(
       vocab_size=config.vocab_size,
       hidden_size=config.hidden_size,
       num_layers=config.num_layers,
 ...
@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
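For orientation, the extended factory keeps its old one-argument form and adds two optional hooks. A short sketch of both paths the diff defines (ours; the config values are placeholders):

    from official.nlp.configs import encoders

    cfg = encoders.TransformerEncoderConfig(vocab_size=100, num_layers=2)

    # Default path: returns a networks.TransformerEncoder, as before.
    encoder = encoders.instantiate_encoder_from_cfg(cfg)

    # New path: reuse an existing embedding layer; this is what electra.py
    # above does when tie_embeddings is True.
    tied = encoders.instantiate_encoder_from_cfg(
        cfg, embedding_layer=encoder.get_embedding_layer())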
official/nlp/data/classifier_data_lib.py  (+374 −294)

@@ -31,7 +31,7 @@ from official.nlp.bert import tokenization
 class InputExample(object):
-  """A single training/test example for simple sequence classification."""
+  """A single training/test example for simple seq regression/classification."""

   def __init__(self,
                guid,
 ...
@@ -48,8 +48,9 @@ class InputExample(object):
         sequence tasks, only this sequence must be specified.
       text_b: (Optional) string. The untokenized text of the second sequence.
         Only must be specified for sequence pair tasks.
-      label: (Optional) string. The label of the example. This should be
-        specified for train and dev examples, but not for test examples.
+      label: (Optional) string for classification, float for regression. The
+        label of the example. This should be specified for train and dev
+        examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
       int_iden: (Optional) int. The int identification number of example in the
 ...
@@ -84,10 +85,12 @@ class InputFeatures(object):
 class DataProcessor(object):
-  """Base class for data converters for sequence classification data sets."""
+  """Base class for converters for seq regression/classification datasets."""

   def __init__(self, process_text_fn=tokenization.convert_to_unicode):
     self.process_text_fn = process_text_fn
+    self.is_regression = False
+    self.label_type = None

   def get_train_examples(self, data_dir):
     """Gets a collection of `InputExample`s for the train set."""
 ...
@@ -121,143 +124,158 @@ class DataProcessor(object):
     return lines


-class XnliProcessor(DataProcessor):
-  """Processor for the XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def __init__(self,
-               language="en",
-               process_text_fn=tokenization.convert_to_unicode):
-    super(XnliProcessor, self).__init__(process_text_fn)
-    if language == "all":
-      self.languages = XnliProcessor.supported_languages
-    elif language not in XnliProcessor.supported_languages:
-      raise ValueError("language %s is not supported for XNLI task." % language)
-    else:
-      self.languages = [language]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = []
-    for language in self.languages:
-      # Skips the header.
-      lines.extend(
-          self._read_tsv(
-              os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language))[1:])
-
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      if label == self.process_text_fn("contradictory"):
-        label = self.process_text_fn("contradiction")
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
-    examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "test-%d" % i
-      language = self.process_text_fn(line[0])
-      text_a = self.process_text_fn(line[6])
-      text_b = self.process_text_fn(line[7])
-      label = self.process_text_fn(line[1])
-      examples_by_lang[language].append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XNLI"
-
-
-class XtremeXnliProcessor(DataProcessor):
-  """Processor for the XTREME XNLI data set."""
-  supported_languages = [
-      "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
-      "ur", "vi", "zh"
-  ]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = f"test-{i}"
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "contradiction"
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XTREME-XNLI"
+class ColaProcessor(DataProcessor):
+  """Processor for the CoLA data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "COLA"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      # Only the test set has a header.
+      if set_type == "test" and i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      if set_type == "test":
+        text_a = self.process_text_fn(line[1])
+        label = "0"
+      else:
+        text_a = self.process_text_fn(line[3])
+        label = self.process_text_fn(line[1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+    return examples
+
+
+class MnliProcessor(DataProcessor):
+  """Processor for the MultiNLI data set (GLUE version)."""
+
+  def __init__(self,
+               mnli_type="matched",
+               process_text_fn=tokenization.convert_to_unicode):
+    super(MnliProcessor, self).__init__(process_text_fn)
+    if mnli_type not in ("matched", "mismatched"):
+      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+    self.mnli_type = mnli_type
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["contradiction", "entailment", "neutral"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "MNLI"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
+      text_a = self.process_text_fn(line[8])
+      text_b = self.process_text_fn(line[9])
+      if set_type == "test":
+        label = "contradiction"
+      else:
+        label = self.process_text_fn(line[-1])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples
+
+
+class MrpcProcessor(DataProcessor):
+  """Processor for the MRPC data set (GLUE version)."""
+
+  def get_train_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
+
+  def get_dev_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+
+  def get_test_examples(self, data_dir):
+    """See base class."""
+    return self._create_examples(
+        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
+
+  def get_labels(self):
+    """See base class."""
+    return ["0", "1"]
+
+  @staticmethod
+  def get_processor_name():
+    """See base class."""
+    return "MRPC"
+
+  def _create_examples(self, lines, set_type):
+    """Creates examples for the training/dev/test sets."""
+    examples = []
+    for i, line in enumerate(lines):
+      if i == 0:
+        continue
+      guid = "%s-%s" % (set_type, i)
+      text_a = self.process_text_fn(line[3])
+      text_b = self.process_text_fn(line[4])
+      if set_type == "test":
+        label = "0"
+      else:
+        label = self.process_text_fn(line[0])
+      examples.append(
+          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+    return examples


 class PawsxProcessor(DataProcessor):
 ...
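A hypothetical usage sketch of the reorganized GLUE processors added above (ours; the data_dir paths are placeholders for locally downloaded GLUE data):

    from official.nlp.data import classifier_data_lib

    cola = classifier_data_lib.ColaProcessor()
    print(cola.get_labels())                       # ['0', '1']
    train = cola.get_train_examples('/tmp/glue/CoLA')

    # MNLI now takes an explicit matched/mismatched switch.
    mnli = classifier_data_lib.MnliProcessor(mnli_type='mismatched')
    dev = mnli.get_dev_examples('/tmp/glue/MNLI')
    print(dev[0].guid, dev[0].label)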
@@ -289,7 +307,7 @@ class PawsxProcessor(DataProcessor):
           self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[1])
       text_b = self.process_text_fn(line[2])
 ...
@@ -302,69 +320,15 @@ class PawsxProcessor(DataProcessor):
     """See base class."""
     lines = []
     for lang in PawsxProcessor.supported_languages:
-      lines.extend(self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv")))
-
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    examples_by_lang = {k: [] for k in self.supported_languages}
-    for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
-        guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = self.process_text_fn(line[2])
-        examples_by_lang[lang].append(
-            InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples_by_lang
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "XTREME-PAWS-X"
-
-
-class XtremePawsxProcessor(DataProcessor):
-  """Processor for the XTREME PAWS-X data set."""
-  supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
-    examples = []
-    for (i, line) in enumerate(lines):
-      guid = "train-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+      lines.extend(
+          self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       guid = "dev-%d" % i
-      text_a = self.process_text_fn(line[0])
-      text_b = self.process_text_fn(line[1])
-      label = self.process_text_fn(line[2])
+      text_a = self.process_text_fn(line[1])
+      text_b = self.process_text_fn(line[2])
+      label = self.process_text_fn(line[3])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 ...
@@ -373,12 +337,12 @@ class XtremePawsxProcessor(DataProcessor):
     """See base class."""
     examples_by_lang = {k: [] for k in self.supported_languages}
     for lang in self.supported_languages:
-      lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
-      for (i, line) in enumerate(lines):
+      lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
+      for i, line in enumerate(lines):
         guid = "test-%d" % i
-        text_a = self.process_text_fn(line[0])
-        text_b = self.process_text_fn(line[1])
-        label = "0"
+        text_a = self.process_text_fn(line[1])
+        text_b = self.process_text_fn(line[2])
+        label = self.process_text_fn(line[3])
         examples_by_lang[lang].append(
             InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples_by_lang
 ...
@@ -393,54 +357,8 @@ class XtremePawsxProcessor(DataProcessor):
     return "XTREME-PAWS-X"


-class MnliProcessor(DataProcessor):
-  """Processor for the MultiNLI data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["contradiction", "entailment", "neutral"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "MNLI"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
-      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
-      text_a = self.process_text_fn(line[8])
-      text_b = self.process_text_fn(line[9])
-      if set_type == "test":
-        label = "contradiction"
-      else:
-        label = self.process_text_fn(line[-1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
-    return examples
-
-
-class MrpcProcessor(DataProcessor):
-  """Processor for the MRPC data set (GLUE version)."""
+class QnliProcessor(DataProcessor):
+  """Processor for the QNLI data set (GLUE version)."""

   def get_train_examples(self, data_dir):
     """See base class."""
 ...
@@ -450,7 +368,7 @@ class MrpcProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")

   def get_test_examples(self, data_dir):
     """See base class."""
 ...
@@ -459,26 +377,28 @@ class MrpcProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["0", "1"]
+    return ["entailment", "not_entailment"]

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "MRPC"
+    return "QNLI"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, i)
-      text_a = self.process_text_fn(line[3])
-      text_b = self.process_text_fn(line[4])
+      guid = "%s-%s" % (set_type, 1)
       if set_type == "test":
-        label = "0"
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = "entailment"
       else:
-        label = self.process_text_fn(line[0])
+        text_a = tokenization.convert_to_unicode(line[1])
+        text_b = tokenization.convert_to_unicode(line[2])
+        label = tokenization.convert_to_unicode(line[-1])
       examples.append(
           InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
 ...
@@ -512,9 +432,9 @@ class QqpProcessor(DataProcessor):
     return "QQP"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, line[0])
 ...
@@ -529,52 +449,6 @@ class QqpProcessor(DataProcessor):
     return examples


-class ColaProcessor(DataProcessor):
-  """Processor for the CoLA data set (GLUE version)."""
-
-  def get_train_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
-
-  def get_dev_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
-
-  def get_test_examples(self, data_dir):
-    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
-
-  def get_labels(self):
-    """See base class."""
-    return ["0", "1"]
-
-  @staticmethod
-  def get_processor_name():
-    """See base class."""
-    return "COLA"
-
-  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
-    examples = []
-    for (i, line) in enumerate(lines):
-      # Only the test set has a header
-      if set_type == "test" and i == 0:
-        continue
-      guid = "%s-%s" % (set_type, i)
-      if set_type == "test":
-        text_a = self.process_text_fn(line[1])
-        label = "0"
-      else:
-        text_a = self.process_text_fn(line[3])
-        label = self.process_text_fn(line[1])
-      examples.append(
-          InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
-    return examples
-
-
 class RteProcessor(DataProcessor):
   """Processor for the RTE data set (GLUE version)."""
 ...
@@ -605,7 +479,7 @@ class RteProcessor(DataProcessor):
     return "RTE"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
     for i, line in enumerate(lines):
       if i == 0:
 ...
@@ -650,9 +524,9 @@ class SstProcessor(DataProcessor):
     return "SST-2"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
       guid = "%s-%s" % (set_type, i)
 ...
@@ -667,8 +541,14 @@ class SstProcessor(DataProcessor):
     return examples


-class QnliProcessor(DataProcessor):
-  """Processor for the QNLI data set (GLUE version)."""
+class StsBProcessor(DataProcessor):
+  """Processor for the STS-B data set (GLUE version)."""
+
+  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+    super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
+    self.is_regression = True
+    self.label_type = float
+    self._labels = None

   def get_train_examples(self, data_dir):
     """See base class."""
 ...
@@ -678,7 +558,7 @@ class QnliProcessor(DataProcessor):
   def get_dev_examples(self, data_dir):
     """See base class."""
     return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev_matched")
+        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

   def get_test_examples(self, data_dir):
     """See base class."""
 ...
@@ -687,28 +567,26 @@ class QnliProcessor(DataProcessor):
   def get_labels(self):
     """See base class."""
-    return ["entailment", "not_entailment"]
+    return self._labels

   @staticmethod
   def get_processor_name():
     """See base class."""
-    return "QNLI"
+    return "STS-B"

   def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
       if i == 0:
         continue
-      guid = "%s-%s" % (set_type, 1)
+      guid = "%s-%s" % (set_type, i)
+      text_a = tokenization.convert_to_unicode(line[7])
+      text_b = tokenization.convert_to_unicode(line[8])
       if set_type == "test":
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = "entailment"
+        label = 0.0
       else:
-        text_a = tokenization.convert_to_unicode(line[1])
-        text_b = tokenization.convert_to_unicode(line[2])
-        label = tokenization.convert_to_unicode(line[-1])
+        label = self.label_type(tokenization.convert_to_unicode(line[9]))
       examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
     return examples
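Since StsBProcessor is the first regression processor in this module, its labels come back as floats rather than label strings. A short sketch of ours (placeholder path):

    from official.nlp.data import classifier_data_lib

    stsb = classifier_data_lib.StsBProcessor()
    print(stsb.is_regression, stsb.label_type)   # True, <class 'float'>
    dev = stsb.get_dev_examples('/tmp/glue/STS-B')
    print(type(dev[0].label))                    # float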
 ...
@@ -728,6 +606,8 @@ class TfdsProcessor(DataProcessor):
     tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
     tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
                 "is_regression=true,label_type=float"
+    tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
+                "skip_label=-1"
   Possible parameters (please refer to the documentation of Tensorflow Datasets
   (TFDS) for the meaning of individual parameters):
     dataset: Required dataset name (potentially with subset and version number).
 ...
@@ -745,6 +625,7 @@ class TfdsProcessor(DataProcessor):
     label_type: Type of the label key (defaults to `int`).
     weight_key: Key of the float sample weight (is not used if not provided).
     is_regression: Whether the task is a regression problem (defaults to False).
+    skip_label: Skip examples with given label (defaults to None).
   """

   def __init__(self,
 ...
@@ -784,6 +665,9 @@ class TfdsProcessor(DataProcessor):
     self.label_type = dtype_map[d.get("label_type", "int")]
     self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
     self.weight_key = d.get("weight_key", None)
+    self.skip_label = d.get("skip_label", None)
+    if self.skip_label is not None:
+      self.skip_label = self.label_type(self.skip_label)

   def get_train_examples(self, data_dir):
     assert data_dir is None
 ...
@@ -804,7 +688,7 @@ class TfdsProcessor(DataProcessor):
     return "TFDS_" + self.dataset_name

   def _create_examples(self, split_name, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
     if split_name not in self.dataset:
       raise ValueError("Split {} not available.".format(split_name))
     dataset = self.dataset[split_name].as_numpy_iterator()
 ...
@@ -822,6 +706,8 @@ class TfdsProcessor(DataProcessor):
       if self.text_b_key:
         text_b = self.process_text_fn(example[self.text_b_key])
       label = self.label_type(example[self.label_key])
+      if self.skip_label is not None and label == self.skip_label:
+        continue
       if self.weight_key:
         weight = float(example[self.weight_key])
       examples.append(
...
@@ -862,7 +748,7 @@ class WnliProcessor(DataProcessor):
...
@@ -862,7 +748,7 @@ class WnliProcessor(DataProcessor):
return
"WNLI"
return
"WNLI"
def
_create_examples
(
self
,
lines
,
set_type
):
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training
and dev
sets."""
"""Creates examples for the training
/dev/test
sets."""
examples
=
[]
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
if
i
==
0
:
...
@@ -879,6 +765,200 @@ class WnliProcessor(DataProcessor):
...
@@ -879,6 +765,200 @@ class WnliProcessor(DataProcessor):
return
examples
return
examples
class
XnliProcessor
(
DataProcessor
):
"""Processor for the XNLI data set."""
supported_languages
=
[
"ar"
,
"bg"
,
"de"
,
"el"
,
"en"
,
"es"
,
"fr"
,
"hi"
,
"ru"
,
"sw"
,
"th"
,
"tr"
,
"ur"
,
"vi"
,
"zh"
]
def
__init__
(
self
,
language
=
"en"
,
process_text_fn
=
tokenization
.
convert_to_unicode
):
super
(
XnliProcessor
,
self
).
__init__
(
process_text_fn
)
if
language
==
"all"
:
self
.
languages
=
XnliProcessor
.
supported_languages
elif
language
not
in
XnliProcessor
.
supported_languages
:
raise
ValueError
(
"language %s is not supported for XNLI task."
%
language
)
else
:
self
.
languages
=
[
language
]
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
[]
for
language
in
self
.
languages
:
# Skips the header.
lines
.
extend
(
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"multinli"
,
"multinli.train.%s.tsv"
%
language
))[
1
:])
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
if
label
==
self
.
process_text_fn
(
"contradictory"
):
label
=
self
.
process_text_fn
(
"contradiction"
)
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.dev.tsv"
))
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
6
])
text_b
=
self
.
process_text_fn
(
line
[
7
])
label
=
self
.
process_text_fn
(
line
[
1
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.test.tsv"
))
examples_by_lang
=
{
k
:
[]
for
k
in
XnliProcessor
.
supported_languages
}
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"test-%d"
%
i
language
=
self
.
process_text_fn
(
line
[
0
])
text_a
=
self
.
process_text_fn
(
line
[
6
])
text_b
=
self
.
process_text_fn
(
line
[
7
])
label
=
self
.
process_text_fn
(
line
[
1
])
examples_by_lang
[
language
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples_by_lang
def
get_labels
(
self
):
"""See base class."""
return
[
"contradiction"
,
"entailment"
,
"neutral"
]
@
staticmethod
def
get_processor_name
():
"""See base class."""
return
"XNLI"
class
XtremePawsxProcessor
(
DataProcessor
):
"""Processor for the XTREME PAWS-X data set."""
supported_languages
=
[
"de"
,
"en"
,
"es"
,
"fr"
,
"ja"
,
"ko"
,
"zh"
]
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
"test-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
"0"
examples_by_lang
[
lang
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples_by_lang
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
]
@
staticmethod
def
get_processor_name
():
"""See base class."""
return
"XTREME-PAWS-X"
class
XtremeXnliProcessor
(
DataProcessor
):
"""Processor for the XTREME XNLI data set."""
supported_languages
=
[
"ar"
,
"bg"
,
"de"
,
"el"
,
"en"
,
"es"
,
"fr"
,
"hi"
,
"ru"
,
"sw"
,
"th"
,
"tr"
,
"ur"
,
"vi"
,
"zh"
]
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
"contradiction"
examples_by_lang
[
lang
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples_by_lang
def
get_labels
(
self
):
"""See base class."""
return
[
"contradiction"
,
"entailment"
,
"neutral"
]
@
staticmethod
def
get_processor_name
():
"""See base class."""
return
"XTREME-XNLI"
def
convert_single_example
(
ex_index
,
example
,
label_list
,
max_seq_length
,
def
convert_single_example
(
ex_index
,
example
,
label_list
,
max_seq_length
,
tokenizer
):
tokenizer
):
"""Converts a single `InputExample` into a single `InputFeatures`."""
"""Converts a single `InputExample` into a single `InputFeatures`."""
...
@@ -989,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
...
@@ -989,7 +1069,7 @@ def file_based_convert_examples_to_features(examples,
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
output_file
))
writer
=
tf
.
io
.
TFRecordWriter
(
output_file
)
writer
=
tf
.
io
.
TFRecordWriter
(
output_file
)
for
(
ex_index
,
example
)
in
enumerate
(
examples
):
for
ex_index
,
example
in
enumerate
(
examples
):
if
ex_index
%
10000
==
0
:
if
ex_index
%
10000
==
0
:
logging
.
info
(
"Writing example %d of %d"
,
ex_index
,
len
(
examples
))
logging
.
info
(
"Writing example %d of %d"
,
ex_index
,
len
(
examples
))
...
...
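Usage sketch (not part of the commit; the data directory below is a placeholder) showing how the newly added XTREME processors are driven: get_train_examples/get_dev_examples return flat lists of InputExamples, while get_test_examples returns a per-language dict.

from official.nlp.data import classifier_data_lib

processor = classifier_data_lib.XtremeXnliProcessor()
train_examples = processor.get_train_examples("/tmp/xtreme_xnli")  # list of InputExample
test_by_lang = processor.get_test_examples("/tmp/xtreme_xnli")     # dict: language -> examples
print(processor.get_labels(), len(train_examples), sorted(test_by_lang))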
official/nlp/data/create_finetuning_data.py  View file @ 31ca3b97
...
@@ -32,14 +32,16 @@ from official.nlp.data import sentence_retrieval_lib
 from official.nlp.data import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.data import squad_lib_sp
+from official.nlp.data import tagging_data_lib

 FLAGS = flags.FLAGS

+# TODO(chendouble): consider moving each task to its own binary.
 flags.DEFINE_enum(
     "fine_tuning_task_type", "classification",
-    ["classification", "regression", "squad", "retrieval"],
+    ["classification", "regression", "squad", "retrieval", "tagging"],
     "The name of the BERT fine tuning task for which data "
-    "will be generated..")
+    "will be generated.")

 # BERT classification specific flags.
 flags.DEFINE_string(
...
@@ -48,30 +50,41 @@ flags.DEFINE_string(
     "for the task.")

 flags.DEFINE_enum(
     "classification_task_name", "MNLI",
-    ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
-     "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+    ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+     "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X"],
     "The name of the task to train BERT classifier. The "
     "difference between XTREME-XNLI and XNLI is: 1. the format "
     "of input tsv files; 2. the dev set for XTREME is english "
     "only and for XNLI is all languages combined. Same for "
     "PAWS-X.")

-flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
-                  "The name of sentence retrieval task for scoring")
+# MNLI task-specific flag.
+flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
+                  "The type of MNLI dataset.")

-# XNLI task specific flag.
+# XNLI task-specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
     "of all languages will be used for training.")

-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
     "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")

-# BERT Squad task specific flags.
+# Retrieval task-specific flags.
+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+                  "The name of sentence retrieval task for scoring")
+
+# Tagging task-specific flags.
+flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
+                  "The name of BERT tagging (token classification) task.")
+
+# BERT Squad task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file in for generating training data for BERT squad task.")
...
@@ -171,7 +184,8 @@ def generate_classifier_dataset():
       "cola":
           classifier_data_lib.ColaProcessor,
       "mnli":
-          classifier_data_lib.MnliProcessor,
+          functools.partial(classifier_data_lib.MnliProcessor,
+                            mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":
...
@@ -180,6 +194,8 @@ def generate_classifier_dataset():
       "rte": classifier_data_lib.RteProcessor,
       "sst-2":
           classifier_data_lib.SstProcessor,
+      "sts-b":
+          classifier_data_lib.StsBProcessor,
       "xnli":
           functools.partial(classifier_data_lib.XnliProcessor,
                             language=FLAGS.xnli_language),
...
@@ -284,6 +300,34 @@ def generate_retrieval_dataset():
       FLAGS.max_seq_length)


+def generate_tagging_dataset():
+  """Generates tagging dataset."""
+  processors = {
+      "panx": tagging_data_lib.PanxProcessor,
+      "udpos": tagging_data_lib.UdposProcessor,
+  }
+  task_name = FLAGS.tagging_task_name.lower()
+  if task_name not in processors:
+    raise ValueError("Task not found: %s" % task_name)
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  elif FLAGS.tokenizer_impl == "sentence_piece":
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+  else:
+    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
+
+  processor = processors[task_name]()
+  return tagging_data_lib.generate_tf_record_from_data_file(
+      processor,
+      FLAGS.input_data_dir,
+      tokenizer,
+      FLAGS.max_seq_length,
+      FLAGS.train_data_output_path,
+      FLAGS.eval_data_output_path,
+      FLAGS.test_data_output_path,
+      processor_text_fn)
+
+
 def main(_):
   if FLAGS.tokenizer_impl == "word_piece":
     if not FLAGS.vocab_file:
...
@@ -304,8 +348,11 @@ def main(_):
     input_meta_data = generate_regression_dataset()
   elif FLAGS.fine_tuning_task_type == "retrieval":
     input_meta_data = generate_retrieval_dataset()
-  else:
+  elif FLAGS.fine_tuning_task_type == "squad":
     input_meta_data = generate_squad_dataset()
+  else:
+    assert FLAGS.fine_tuning_task_type == "tagging"
+    input_meta_data = generate_tagging_dataset()

   tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
   with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
...
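A small illustration (assuming the Model Garden package is on PYTHONPATH; not part of the commit) of the pattern used in the processor map above, where functools.partial defers construction of MnliProcessor until the new mnli_type flag is known.

import functools
from official.nlp.data import classifier_data_lib

# Equivalent to the "mnli" entry above when --mnli_type=mismatched.
make_mnli = functools.partial(
    classifier_data_lib.MnliProcessor, mnli_type="mismatched")
processor = make_mnli()  # constructed lazily, like any other DataProcessor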
official/nlp/data/create_pretraining_data.py  View file @ 31ca3b97
...
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function

 import collections
+import itertools
 import random

 from absl import app
...
@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")

+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")
...
@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]
...
@@ -229,7 +237,7 @@ def create_training_instances(input_files,
           create_instances_from_document(
               all_documents, document_index, max_seq_length, short_seq_prob,
               masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))

   rng.shuffle(instances)
   return instances
...
@@ -238,7 +246,8 @@ def create_training_instances(input_files,
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]
...
@@ -337,7 +346,7 @@ def create_instances_from_document(
       (tokens, masked_lm_positions,
        masked_lm_labels) = create_masked_lm_predictions(
            tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-           do_whole_word_mask)
+           do_whole_word_mask, max_ngram_size)
       instance = TrainingInstance(
           tokens=tokens,
           segment_ids=segment_ids,
...
@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])

+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+  E.g.,
+    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+    _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extension of 'whole word masking' to mask multiple, contiguous
+  words (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible output masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams: (0,4)
+
+  Output masks will not overlap and contain less than `max_masked_tokens` total
+  tokens. E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cumulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarantee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram! Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitute grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2),  [2,       4),  [4,5),  [5,       6)]
+
+  Arguments:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
+
+
 def create_masked_lm_predictions(tokens, masked_lm_prob,
                                  max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
   """Creates the predictions for the masked LM objective."""
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]

   num_to_predict = min(max_predictions_per_seq,
                        max(1, int(round(len(tokens) * masked_lm_prob))))

+  # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking. Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
   masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

   assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)
...
@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
...
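A toy walk-through (not from the commit; the token list and limits are made up) of the two helpers that implement n-gram masking: _wordpieces_to_grams rebuilds whole-word spans from wordpieces, and _masking_ngrams samples non-overlapping spans under a token budget.

import random
from official.nlp.data import create_pretraining_data as cpd

tokens = ["[CLS]", "the", "red", "boa", "##t", "[SEP]"]
grams = cpd._wordpieces_to_grams(tokens)   # [_Gram(1, 2), _Gram(2, 3), _Gram(3, 5)]
rng = random.Random(12345)
masks = cpd._masking_ngrams(grams, max_ngram_size=3, max_masked_tokens=3, rng=rng)
# `masks` is a list of non-overlapping _Gram spans covering at most 3 tokens;
# depending on the draws it might be [_Gram(1, 3)] ("the red") or [_Gram(3, 5)].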
official/nlp/data/data_loader_factory.py  0 → 100644  View file @ 31ca3b97
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""

from official.utils import registry

_REGISTERED_DATA_LOADER_CLS = {}


def register_data_loader_cls(data_config_cls):
  """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.

  This decorator supports registration of data loaders as follows:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance of DataConfig).

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of data_config_cls.
  """
  return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)


def get_data_loader(data_config):
  """Creates a data_loader from data_config."""
  return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
      data_config)
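A quick, hedged illustration (it relies on the loaders registered later in this commit, on keyword-style construction of the configs, and on the base DataConfig providing input_path) that get_data_loader dispatches purely on the config's class:

from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader
from official.nlp.data import tagging_data_loader

pretrain_cfg = pretrain_dataloader.BertPretrainDataConfig(input_path="a")  # placeholder
tagging_cfg = tagging_data_loader.TaggingDataConfig(input_path="b")        # placeholder
print(type(data_loader_factory.get_data_loader(pretrain_cfg)).__name__)  # BertPretrainDataLoader
print(type(data_loader_factory.get_data_loader(tagging_cfg)).__name__)   # TaggingDataLoader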
official/nlp/data/pretrain_dataloader.py  View file @ 31ca3b97
...
@@ -16,11 +16,27 @@
 """Loads dataset for the BERT pretraining task."""
 from typing import Mapping, Optional

+import dataclasses
 import tensorflow as tf

 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+  """Data config for BERT pretraining task (tasks/masked_lm)."""
+  input_path: str = ''
+  global_batch_size: int = 512
+  is_training: bool = True
+  seq_length: int = 512
+  max_predictions_per_seq: int = 76
+  use_next_sentence_label: bool = True
+  use_position_id: bool = False


+@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
 class BertPretrainDataLoader:
   """A class to load dataset for bert pretraining task."""
...
@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
     reader = input_reader.InputReader(
         params=self._params,
         decoder_fn=self._decode, parser_fn=self._parse)
     return reader.read(input_context)
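A minimal sketch (placeholder file pattern; the TFRecords are assumed to come from create_pretraining_data.py) of wiring the new config through the factory:

from official.nlp.data import data_loader_factory
from official.nlp.data import pretrain_dataloader

config = pretrain_dataloader.BertPretrainDataConfig(
    input_path="/tmp/pretrain/*.tfrecord",  # placeholder
    global_batch_size=32,
    seq_length=128,
    max_predictions_per_seq=20)
dataset = data_loader_factory.get_data_loader(config).load()  # tf.data.Dataset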
official/nlp/data/question_answering_dataloader.py  0 → 100644  View file @ 31ca3b97
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g., SQuAD) task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
  """Data config for question answering task (tasks/question_answering)."""
  input_path: str = ''
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384
  # Settings below are question answering specific.
  version_2_with_negative: bool = False
  # Settings below are only used for eval mode.
  input_preprocessed_data_path: str = ''
  doc_stride: int = 128
  query_length: int = 64
  vocab_file: str = ''
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  do_lower_case: bool = True


@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
  """A class to load dataset for the question answering (e.g., SQuAD) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length
    self._is_training = params.is_training

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._is_training:
      name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
      name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
    else:
      name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x, y = {}, {}
    for name, tensor in record.items():
      if name in ('start_positions', 'end_positions'):
        y[name] = tensor
      elif name == 'input_ids':
        x['input_word_ids'] = tensor
      elif name == 'segment_ids':
        x['input_type_ids'] = tensor
      else:
        x[name] = tensor
    return (x, y)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
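A minimal sketch (placeholder path; keyword-style config construction assumed) of loading SQuAD-style records through the registry:

from official.nlp.data import data_loader_factory
from official.nlp.data import question_answering_dataloader

qa_config = question_answering_dataloader.QADataConfig(
    input_path="/tmp/squad_train.tf_record",  # placeholder
    global_batch_size=8,
    is_training=True,
    seq_length=384)
qa_dataset = data_loader_factory.get_data_loader(qa_config).load()
# Each element is (features, labels); features hold 'input_word_ids',
# 'input_mask', 'input_type_ids', labels hold 'start_positions'/'end_positions'.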
official/nlp/data/sentence_prediction_dataloader.py  View file @ 31ca3b97
...
@@ -15,11 +15,28 @@
 # ==============================================================================
 """Loads dataset for the sentence prediction (classification) task."""
 from typing import Mapping, Optional

+import dataclasses
 import tensorflow as tf

 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
+@dataclasses.dataclass
+class SentencePredictionDataConfig(cfg.DataConfig):
+  """Data config for sentence prediction task (tasks/sentence_prediction)."""
+  input_path: str = ''
+  global_batch_size: int = 32
+  is_training: bool = True
+  seq_length: int = 128
+  label_type: str = 'int'


+@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
 class SentencePredictionDataLoader:
   """A class to load dataset for sentence prediction (classification) task."""
...
@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
...
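A short sketch (placeholder path) of the new label_type field, which lets the same loader parse regression data such as STS-B records written with float labels:

from official.nlp.data import data_loader_factory
from official.nlp.data import sentence_prediction_dataloader

stsb_config = sentence_prediction_dataloader.SentencePredictionDataConfig(
    input_path="/tmp/stsb_train.tf_record",  # placeholder
    seq_length=128,
    label_type="float")  # 'label_ids' parsed as tf.float32 instead of tf.int64
stsb_dataset = data_loader_factory.get_data_loader(stsb_config).load()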
official/nlp/data/tagging_data_lib.py  0 → 100644  View file @ 31ca3b97
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for tagging task such as NER/POS."""
import collections
import os

from absl import logging
import tensorflow as tf

from official.nlp.data import classifier_data_lib

# A negative label id for the padding label, which will not contribute
# to loss/metrics in training.
_PADDING_LABEL_ID = -1

# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"


class InputExample(object):
  """A single training/test example for token classification."""

  def __init__(self, sentence_id, words=None, label_ids=None):
    """Constructs an InputExample."""
    self.sentence_id = sentence_id
    self.words = words if words else []
    self.label_ids = label_ids if label_ids else []

  def add_word_and_label_id(self, word, label_id):
    """Adds word and label_id pair in the example."""
    self.words.append(word)
    self.label_ids.append(label_id)


def _read_one_file(file_name, label_list):
  """Reads one file and returns a list of `InputExample` instances."""
  lines = tf.io.gfile.GFile(file_name, "r").readlines()
  examples = []
  label_id_map = {label: i for i, label in enumerate(label_list)}
  sentence_id = 0
  example = InputExample(sentence_id=0)
  for line in lines:
    line = line.strip("\n")
    if line:
      # The format is: <token>\t<label> for train/dev set and <token> for test.
      items = line.split("\t")
      assert len(items) == 2 or len(items) == 1
      token = items[0].strip()

      # Assign a dummy label_id for test set
      label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
      example.add_word_and_label_id(token, label_id)
    else:
      # Empty line indicates a new sentence.
      if example.words:
        examples.append(example)
        sentence_id += 1
        example = InputExample(sentence_id=sentence_id)

  if example.words:
    examples.append(example)
  return examples


class PanxProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Panx data set."""
  supported_languages = [
      "ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
      "nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
      "es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
      "tr", "et", "fi", "hu"
  ]

  def get_train_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())

  def get_dev_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]

  @staticmethod
  def get_processor_name():
    return "panx"


class UdposProcessor(classifier_data_lib.DataProcessor):
  """Processor for the Udpos data set."""
  supported_languages = [
      "af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
      "ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
  ]

  def get_train_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "train-en.tsv"), self.get_labels())

  def get_dev_examples(self, data_dir):
    return _read_one_file(
        os.path.join(data_dir, "dev-en.tsv"), self.get_labels())

  def get_test_examples(self, data_dir):
    examples_dict = {}
    for language in self.supported_languages:
      examples_dict[language] = _read_one_file(
          os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
    return examples_dict

  def get_labels(self):
    return [
        "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
        "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
    ]

  @staticmethod
  def get_processor_name():
    return "udpos"


def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
  """Tokenizes words and breaks long example into short ones."""
  # Needs additional [CLS] and [SEP] tokens.
  max_length = max_length - 2
  new_examples = []
  new_example = InputExample(sentence_id=example.sentence_id)
  for i, word in enumerate(example.words):
    if any([x < 0 for x in example.label_ids]):
      raise ValueError("Unexpected negative label_id: %s" % example.label_ids)

    if text_preprocessing:
      word = text_preprocessing(word)
    subwords = tokenizer.tokenize(word)
    if (not subwords or len(subwords) > max_length) and word:
      subwords = [_UNK_TOKEN]

    if len(subwords) + len(new_example.words) > max_length:
      # Start a new example.
      new_examples.append(new_example)
      new_example = InputExample(sentence_id=example.sentence_id)

    for j, subword in enumerate(subwords):
      # Use the real label for the first subword, and pad label for
      # the remaining ones.
      subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
      new_example.add_word_and_label_id(subword, subword_label)

  if new_example.words:
    new_examples.append(new_example)
  return new_examples


def _convert_single_example(example, max_seq_length, tokenizer):
  """Converts an `InputExample` instance to a `tf.train.Example` instance."""
  tokens = ["[CLS]"]
  tokens.extend(example.words)
  tokens.append("[SEP]")
  input_ids = tokenizer.convert_tokens_to_ids(tokens)
  label_ids = [_PADDING_LABEL_ID]
  label_ids.extend(example.label_ids)
  label_ids.append(_PADDING_LABEL_ID)

  segment_ids = [0] * len(input_ids)
  input_mask = [1] * len(input_ids)

  # Pad up to the sequence length.
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)
    label_ids.append(_PADDING_LABEL_ID)

  def create_int_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

  features = collections.OrderedDict()
  features["input_ids"] = create_int_feature(input_ids)
  features["input_mask"] = create_int_feature(input_mask)
  features["segment_ids"] = create_int_feature(segment_ids)
  features["label_ids"] = create_int_feature(label_ids)
  features["sentence_id"] = create_int_feature([example.sentence_id])

  tf_example = tf.train.Example(features=tf.train.Features(feature=features))
  return tf_example


def write_example_to_file(examples,
                          tokenizer,
                          max_seq_length,
                          output_file,
                          text_preprocessing=None):
  """Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.

  Note that the words inside each example will be tokenized and be applied by
  `text_preprocessing` if available. Also, if the length of sentence (plus
  special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
  will be broken into multiple short examples. For example:

  Example (text_preprocessing=lowercase, max_seq_length=5)
    words:        ["What", "a", "great", "weekend"]
    labels:       [     7,   5,       9,        10]
    sentence_id:  0
    preprocessed: ["what", "a", "great", "weekend"]
    tokenized:    ["what", "a", "great", "week", "##end"]

  will result in two tf.example protos:

    tokens:      ["[CLS]", "what", "a", "great", "[SEP]"]
    label_ids:   [-1,           7,   5,       9,      -1]
    input_mask:  [ 1,           1,   1,       1,       1]
    segment_ids: [ 0,           0,   0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

    tokens:      ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
    label_ids:   [-1,          10,      -1,      -1,      -1]
    input_mask:  [ 1,           1,       1,       0,       0]
    segment_ids: [ 0,           0,       0,       0,       0]
    input_ids:   [ tokenizer.convert_tokens_to_ids(tokens) ]
    sentence_id: 0

  Note the use of -1 in `label_ids` to indicate that a token should not be
  considered for classification (e.g., trailing ## wordpieces or special
  token). Token classification models should accordingly ignore these when
  calculating loss, metrics, etc...

  Args:
    examples: A list of `InputExample` instances.
    tokenizer: The tokenizer to be applied on the data.
    max_seq_length: Maximum length of generated sequences.
    output_file: The name of the output tfrecord file.
    text_preprocessing: optional preprocessing run on each word prior to
      tokenization.

  Returns:
    The total number of tf.train.Example proto written to file.
  """
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)
  num_tokenized_examples = 0
  for (ex_index, example) in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d to %s", ex_index, len(examples),
                   output_file)

    tokenized_examples = _tokenize_example(example, max_seq_length, tokenizer,
                                           text_preprocessing)
    num_tokenized_examples += len(tokenized_examples)
    for per_tokenized_example in tokenized_examples:
      tf_example = _convert_single_example(per_tokenized_example,
                                           max_seq_length, tokenizer)
      writer.write(tf_example.SerializeToString())

  writer.close()
  return num_tokenized_examples


def token_classification_meta_data(train_data_size,
                                   max_seq_length,
                                   num_labels,
                                   eval_data_size=None,
                                   test_data_size=None,
                                   label_list=None,
                                   processor_type=None):
  """Creates metadata for tagging (token classification) datasets."""
  meta_data = {
      "train_data_size": train_data_size,
      "max_seq_length": max_seq_length,
      "num_labels": num_labels,
      "task_type": "tagging",
      "label_type": "int",
      "label_shape": [max_seq_length],
  }
  if eval_data_size:
    meta_data["eval_data_size"] = eval_data_size
  if test_data_size:
    meta_data["test_data_size"] = test_data_size
  if label_list:
    meta_data["label_list"] = label_list
  if processor_type:
    meta_data["processor_type"] = processor_type

  return meta_data


def generate_tf_record_from_data_file(processor, data_dir, tokenizer,
                                      max_seq_length, train_data_output_path,
                                      eval_data_output_path,
                                      test_data_output_path,
                                      text_preprocessing):
  """Generates tfrecord files from the raw data."""
  common_kwargs = dict(
      tokenizer=tokenizer,
      max_seq_length=max_seq_length,
      text_preprocessing=text_preprocessing)

  train_examples = processor.get_train_examples(data_dir)
  train_data_size = write_example_to_file(
      train_examples, output_file=train_data_output_path, **common_kwargs)

  eval_examples = processor.get_dev_examples(data_dir)
  eval_data_size = write_example_to_file(
      eval_examples, output_file=eval_data_output_path, **common_kwargs)

  test_input_data_examples = processor.get_test_examples(data_dir)
  test_data_size = {}
  for language, examples in test_input_data_examples.items():
    test_data_size[language] = write_example_to_file(
        examples,
        output_file=test_data_output_path.format(language),
        **common_kwargs)

  labels = processor.get_labels()
  meta_data = token_classification_meta_data(
      train_data_size,
      max_seq_length,
      len(labels),
      eval_data_size,
      test_data_size,
      label_list=labels,
      processor_type=processor.get_processor_name())
  return meta_data
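An illustrative sketch (the vocab path is a placeholder; FullTokenizer is the same WordPiece tokenizer create_finetuning_data.py uses above) of how a tagged sentence flows through these helpers: long sentences are split so that each piece plus [CLS]/[SEP] fits the length budget, and subword continuations receive the padding label -1.

from official.nlp.bert import tokenization
from official.nlp.data import tagging_data_lib

tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/vocab.txt", do_lower_case=True)  # placeholder vocab

example = tagging_data_lib.InputExample(sentence_id=0)
for word, label_id in [("What", 7), ("a", 5), ("great", 9), ("weekend", 10)]:
  example.add_word_and_label_id(word, label_id)

pieces = tagging_data_lib._tokenize_example(example, 5, tokenizer)
# With max_seq_length=5 this mirrors the docstring example above: two short
# InputExamples whose first subword keeps the real label, the rest get -1.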
official/nlp/data/tagging_data_loader.py  View file @ 31ca3b97
...
@@ -15,17 +15,30 @@
 # ==============================================================================
 """Loads dataset for the tagging (e.g., NER/POS) task."""
 from typing import Mapping, Optional

+import dataclasses
 import tensorflow as tf

 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class TaggingDataConfig(cfg.DataConfig):
+  """Data config for tagging (tasks/tagging)."""
+  is_training: bool = True
+  seq_length: int = 128
+  include_sentence_id: bool = False


+@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
 class TaggingDataLoader:
   """A class to load dataset for tagging (e.g., NER and POS) task."""

-  def __init__(self, params):
+  def __init__(self, params: TaggingDataConfig):
     self._params = params
     self._seq_length = params.seq_length
+    self._include_sentence_id = params.include_sentence_id

   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
...
@@ -35,6 +48,9 @@ class TaggingDataLoader:
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._include_sentence_id:
+      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
+
     example = tf.io.parse_single_example(record, name_to_features)

     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
...
@@ -54,6 +70,8 @@ class TaggingDataLoader:
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
+    if self._include_sentence_id:
+      x['sentence_id'] = record['sentence_id']
+
     y = record['label_ids']
     return (x, y)
...
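A closing sketch (placeholder path; input_path and global_batch_size are assumed to come from the base DataConfig) of reading the tagging TFRecords written by tagging_data_lib through the new config/loader pair:

from official.nlp.data import data_loader_factory
from official.nlp.data import tagging_data_loader

tagging_config = tagging_data_loader.TaggingDataConfig(
    input_path="/tmp/panx_train.tf_record",  # placeholder
    seq_length=128,
    include_sentence_id=True)
tagging_dataset = data_loader_factory.get_data_loader(tagging_config).load()
# Features include 'sentence_id' because include_sentence_id=True; labels are
# the per-token label_ids, with -1 at [CLS]/[SEP], padding and subword positions.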