ModelZoo / ResNet50_tensorflow / Commits / 31ca3b97

Commit 31ca3b97 authored Jul 23, 2020 by Kaushik Shivakumar

    resolve merge conflicts

Parents: 3e9d886d, 7fcd7cba
Changes: 392 files in this commit. Showing 20 changed files with 1468 additions and 475 deletions (+1468 −475) on this page (1 of 20).
official/modeling/tf_utils.py  +0 −1
official/nlp/albert/run_classifier.py  +47 −4
official/nlp/bert/export_tfhub.py  +1 −1
official/nlp/bert/model_saving_utils.py  +0 −4
official/nlp/bert/model_training_utils.py  +3 −2
official/nlp/bert/run_classifier.py  +4 −1
official/nlp/configs/bert.py  +2 −82
official/nlp/configs/bert_test.py  +7 −6
official/nlp/configs/electra.py  +91 −0
official/nlp/configs/electra_test.py  +49 −0
official/nlp/configs/encoders.py  +42 −4
official/nlp/data/classifier_data_lib.py  +374 −294
official/nlp/data/create_finetuning_data.py  +60 −13
official/nlp/data/create_pretraining_data.py  +233 −58
official/nlp/data/data_loader_factory.py  +59 −0
official/nlp/data/pretrain_dataloader.py  +17 −3
official/nlp/data/question_answering_dataloader.py  +95 −0
official/nlp/data/sentence_prediction_dataloader.py  +19 −1
official/nlp/data/tagging_data_lib.py  +346 −0
official/nlp/data/tagging_data_loader.py  +19 −1
official/modeling/tf_utils.py  (+0 −1)

@@ -88,7 +88,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
 
-# TODO(hongkuny): consider moving custom string-map lookup to keras api.
 def get_activation(identifier):
   """Maps a identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
 ...
official/nlp/albert/run_classifier.py  (+47 −4)

@@ -14,23 +14,61 @@
 # ==============================================================================
 """ALBERT classification finetuning runner in tf2.x."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import json
+import os
 
 from absl import app
 from absl import flags
+from absl import logging
 import tensorflow as tf
 
 from official.nlp.albert import configs as albert_configs
+from official.nlp.bert import bert_models
 from official.nlp.bert import run_classifier as run_classifier_bert
 from official.utils.misc import distribution_utils
 
 FLAGS = flags.FLAGS
 
 
+def predict(strategy, albert_config, input_meta_data, predict_input_fn):
+  """Function outputs both the ground truth predictions as .tsv files."""
+  with strategy.scope():
+    classifier_model = bert_models.classifier_model(
+        albert_config, input_meta_data['num_labels'])[0]
+    checkpoint = tf.train.Checkpoint(model=classifier_model)
+    latest_checkpoint_file = (
+        FLAGS.predict_checkpoint_path or
+        tf.train.latest_checkpoint(FLAGS.model_dir))
+    assert latest_checkpoint_file
+    logging.info('Checkpoint file %s found and restoring from '
+                 'checkpoint', latest_checkpoint_file)
+    checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
+    preds, ground_truth = run_classifier_bert.get_predictions_and_labels(
+        strategy, classifier_model, predict_input_fn, return_probs=True)
+    output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+    with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+      logging.info('***** Predict results *****')
+      for probabilities in preds:
+        output_line = '\t'.join(
+            str(class_probability)
+            for class_probability in probabilities) + '\n'
+        writer.write(output_line)
+    ground_truth_labels_file = os.path.join(FLAGS.model_dir,
+                                            'output_labels.tsv')
+    with tf.io.gfile.GFile(ground_truth_labels_file, 'w') as writer:
+      logging.info('***** Ground truth results *****')
+      for label in ground_truth:
+        output_line = '\t'.join(str(label)) + '\n'
+        writer.write(output_line)
+  return
+
+
 def main(_):
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
 ...

@@ -56,9 +94,14 @@ def main(_):
   albert_config = albert_configs.AlbertConfig.from_json_file(
       FLAGS.bert_config_file)
-  run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
-                               train_input_fn, eval_input_fn)
+  if FLAGS.mode == 'train_and_eval':
+    run_classifier_bert.run_bert(strategy, input_meta_data, albert_config,
+                                 train_input_fn, eval_input_fn)
+  elif FLAGS.mode == 'predict':
+    predict(strategy, albert_config, input_meta_data, eval_input_fn)
+  else:
+    raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+  return
 
 
 if __name__ == '__main__':
   flags.mark_flag_as_required('bert_config_file')
 ...
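For reference, the `test_results.tsv` written by the new `predict()` above holds one line per example with tab-separated class probabilities. A minimal, self-contained sketch of a reader for that layout (not part of this commit):

def read_test_results(path):
  """Parses predict()'s test_results.tsv into per-example probability lists."""
  results = []
  with open(path) as f:
    for line in f:
      # Each line is one example: probabilities separated by tabs.
      results.append([float(p) for p in line.rstrip('\n').split('\t')])
  return results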
official/nlp/bert/export_tfhub.py  (+1 −1)

@@ -79,7 +79,7 @@ def export_bert_tfhub(bert_config: configs.BertConfig,
       do_lower_case, vocab_file)
   core_model, encoder = create_bert_model(bert_config)
   checkpoint = tf.train.Checkpoint(model=encoder)
-  checkpoint.restore(model_checkpoint_path).assert_consumed()
+  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
   core_model.vocab_file = tf.saved_model.Asset(vocab_file)
   core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
   core_model.save(hub_destination, include_optimizer=False, save_format="tf")
 ...
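The only change here swaps `assert_consumed()` for `assert_existing_objects_matched()`: the export now tolerates checkpoint entries (for example optimizer slots) that have no counterpart in the encoder being restored, while still requiring every object in the restoring graph to find a match. A toy illustration of the difference (not this repository's code; the temporary path is a placeholder):

import tensorflow as tf

# Save a checkpoint that contains more objects than the restoring graph.
source = tf.train.Checkpoint(model=tf.Variable(1.0), optimizer=tf.Variable(2.0))
path = source.save('/tmp/toy_ckpt')

status = tf.train.Checkpoint(model=tf.Variable(0.0)).restore(path)
status.assert_existing_objects_matched()  # passes: restoring graph fully matched
# status.assert_consumed() would raise, because `optimizer` in the checkpoint
# was never consumed by the restoring object graph.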
official/nlp/bert/model_saving_utils.py  (+0 −4)

@@ -55,14 +55,10 @@ def export_bert_model(model_export_path: typing.Text,
     raise ValueError('model must be a tf.keras.Model object.')
 
   if checkpoint_dir:
-    # Keras compile/fit() was used to save checkpoint using
-    # model.save_weights().
     if restore_model_using_load_weights:
       model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
       assert tf.io.gfile.exists(model_weight_path)
       model.load_weights(model_weight_path)
-    # tf.train.Checkpoint API was used via custom training loop logic.
     else:
       checkpoint = tf.train.Checkpoint(model=model)
 ...
official/nlp/bert/model_training_utils.py  (+3 −2)

@@ -99,7 +99,9 @@ def write_txt_summary(training_summary, summary_dir):
 @deprecation.deprecated(
-    None, 'This function is deprecated. Please use Keras compile/fit instead.')
+    None, 'This function is deprecated and we do not expect adding new '
+    'functionalities. Please do not have your code depending '
+    'on this library.')
 def run_customized_training_loop(
     # pylint: disable=invalid-name
     _sentinel=None,
 ...

@@ -557,7 +559,6 @@ def run_customized_training_loop(
     for metric in model.metrics:
       training_summary[metric.name] = _float_metric_value(metric)
     if eval_metrics:
-      # TODO(hongkuny): Cleans up summary reporting in text.
       training_summary['last_train_metrics'] = _float_metric_value(
           train_metrics[0])
       training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
 ...
official/nlp/bert/run_classifier.py  (+4 −1)

@@ -343,7 +343,10 @@ def export_classifier(model_export_path, input_meta_data, bert_config,
   # Export uses float32 for now, even if training uses mixed precision.
   tf.keras.mixed_precision.experimental.set_policy('float32')
   classifier_model = bert_models.classifier_model(
-      bert_config, input_meta_data.get('num_labels', 1))[0]
+      bert_config,
+      input_meta_data.get('num_labels', 1),
+      hub_module_url=FLAGS.hub_module_url,
+      hub_module_trainable=False)[0]
 
   model_saving_utils.export_bert_model(
       model_export_path, model=classifier_model, checkpoint_dir=model_dir)
 ...
official/nlp/configs/bert.py  (+2 −82)

@@ -24,7 +24,6 @@ import tensorflow as tf
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
-from official.modeling.hyperparams import config_definitions as cfg
 from official.nlp.configs import encoders
 from official.nlp.modeling import layers
 from official.nlp.modeling.models import bert_pretrainer
 ...

@@ -43,7 +42,6 @@ class ClsHeadConfig(base_config.Config):
 @dataclasses.dataclass
 class BertPretrainerConfig(base_config.Config):
   """BERT encoder configuration."""
-  num_masked_tokens: int = 76
   encoder: encoders.TransformerEncoderConfig = (
       encoders.TransformerEncoderConfig())
   cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
 ...

@@ -56,96 +54,18 @@ def instantiate_classification_heads_from_cfgs(
   ] if cls_head_configs else []
 
 
-def instantiate_bertpretrainer_from_cfg(
+def instantiate_pretrainer_from_cfg(
     config: BertPretrainerConfig,
     encoder_network: Optional[tf.keras.Model] = None
 ) -> bert_pretrainer.BertPretrainerV2:
   """Instantiates a BertPretrainer from the config."""
   encoder_cfg = config.encoder
   if encoder_network is None:
     encoder_network = encoders.instantiate_encoder_from_cfg(encoder_cfg)
   return bert_pretrainer.BertPretrainerV2(
-      config.num_masked_tokens,
       mlm_activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
       mlm_initializer=tf.keras.initializers.TruncatedNormal(
           stddev=encoder_cfg.initializer_range),
       encoder_network=encoder_network,
       classification_heads=instantiate_classification_heads_from_cfgs(
           config.cls_heads))
-
-
-@dataclasses.dataclass
-class BertPretrainDataConfig(cfg.DataConfig):
-  """Data config for BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = True
-  seq_length: int = 512
-  max_predictions_per_seq: int = 76
-  use_next_sentence_label: bool = True
-  use_position_id: bool = False
-
-
-@dataclasses.dataclass
-class BertPretrainEvalDataConfig(BertPretrainDataConfig):
-  """Data config for the eval set in BERT pretraining task (tasks/masked_lm)."""
-  input_path: str = ""
-  global_batch_size: int = 512
-  is_training: bool = False
-
-
-@dataclasses.dataclass
-class SentencePredictionDataConfig(cfg.DataConfig):
-  """Data config for sentence prediction task (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = True
-  seq_length: int = 128
-
-
-@dataclasses.dataclass
-class SentencePredictionDevDataConfig(cfg.DataConfig):
-  """Dev Data config for sentence prediction (tasks/sentence_prediction)."""
-  input_path: str = ""
-  global_batch_size: int = 32
-  is_training: bool = False
-  seq_length: int = 128
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class QADataConfig(cfg.DataConfig):
-  """Data config for question answering task (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class QADevDataConfig(cfg.DataConfig):
-  """Dev Data config for queston answering (tasks/question_answering)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
-
-
-@dataclasses.dataclass
-class TaggingDataConfig(cfg.DataConfig):
-  """Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = True
-  seq_length: int = 384
-
-
-@dataclasses.dataclass
-class TaggingDevDataConfig(cfg.DataConfig):
-  """Dev Data config for tagging (tasks/tagging)."""
-  input_path: str = ""
-  global_batch_size: int = 48
-  is_training: bool = False
-  seq_length: int = 384
-  drop_remainder: bool = False
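With `num_masked_tokens` dropped from `BertPretrainerConfig` and the helper renamed to `instantiate_pretrainer_from_cfg`, building a pretrainer from a config reduces to the pattern exercised in `bert_test.py` below. A sketch (tiny sizes are illustrative only; it assumes the `official` package from this repository is importable):

from official.nlp.configs import bert
from official.nlp.configs import encoders

config = bert.BertPretrainerConfig(
    encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
    cls_heads=[
        bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
    ])
pretrainer = bert.instantiate_pretrainer_from_cfg(config)  # BertPretrainerV2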
official/nlp/configs/bert_test.py  (+7 −6)

@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
   def test_network_invocation(self):
     config = bert.BertPretrainerConfig(
         encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)
 
     # Invokes with classification heads.
     config = bert.BertPretrainerConfig(
 ...

@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)
 
     with self.assertRaises(ValueError):
       config = bert.BertPretrainerConfig(
 ...

@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    _ = bert.instantiate_bertpretrainer_from_cfg(config)
+    _ = bert.instantiate_pretrainer_from_cfg(config)
 
   def test_checkpoint_items(self):
     config = bert.BertPretrainerConfig(
 ...

@@ -56,9 +56,10 @@ class BertModelsTest(tf.test.TestCase):
         bert.ClsHeadConfig(
             inner_dim=10, num_classes=2, name="next_sentence")
     ])
-    encoder = bert.instantiate_bertpretrainer_from_cfg(config)
-    self.assertSameElements(encoder.checkpoint_items.keys(),
-                            ["encoder", "next_sentence.pooler_dense"])
+    encoder = bert.instantiate_pretrainer_from_cfg(config)
+    self.assertSameElements(
+        encoder.checkpoint_items.keys(),
+        ["encoder", "masked_lm", "next_sentence.pooler_dense"])
 
 
 if __name__ == "__main__":
 ...
official/nlp/configs/electra.py  (new file, +91 −0)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
from typing import List, Optional

import dataclasses
import tensorflow as tf

from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
from official.nlp.modeling import layers
from official.nlp.modeling.models import electra_pretrainer


@dataclasses.dataclass
class ELECTRAPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  discriminator_encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)


def instantiate_classification_heads_from_cfgs(
    cls_head_configs: List[bert.ClsHeadConfig]
) -> List[layers.ClassificationHead]:
  if cls_head_configs:
    return [
        layers.ClassificationHead(**cfg.as_dict()) for cfg in cls_head_configs
    ]
  else:
    return []


def instantiate_pretrainer_from_cfg(
    config: ELECTRAPretrainerConfig,
    generator_network: Optional[tf.keras.Model] = None,
    discriminator_network: Optional[tf.keras.Model] = None,
) -> electra_pretrainer.ElectraPretrainer:
  """Instantiates ElectraPretrainer from the config."""
  generator_encoder_cfg = config.generator_encoder
  discriminator_encoder_cfg = config.discriminator_encoder
  # Copy discriminator's embeddings to generator for easier model serialization.
  if discriminator_network is None:
    discriminator_network = encoders.instantiate_encoder_from_cfg(
        discriminator_encoder_cfg)
  if generator_network is None:
    if config.tie_embeddings:
      embedding_layer = discriminator_network.get_embedding_layer()
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg, embedding_layer=embedding_layer)
    else:
      generator_network = encoders.instantiate_encoder_from_cfg(
          generator_encoder_cfg)

  return electra_pretrainer.ElectraPretrainer(
      generator_network=generator_network,
      discriminator_network=discriminator_network,
      vocab_size=config.generator_encoder.vocab_size,
      num_classes=config.num_classes,
      sequence_length=config.sequence_length,
      num_token_predictions=config.num_masked_tokens,
      mlm_activation=tf_utils.get_activation(
          generator_encoder_cfg.hidden_activation),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=generator_encoder_cfg.initializer_range),
      classification_heads=instantiate_classification_heads_from_cfgs(
          config.cls_heads),
      disallow_correct=config.disallow_correct)
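A sketch of how the new ELECTRA config is meant to be used, mirroring `electra_test.py` below (encoder sizes are illustrative; the example assumes this repository's `official` package is importable). With `tie_embeddings=True`, the default, the generator is built on the discriminator's embedding layer as in the function above:

from official.nlp.configs import electra
from official.nlp.configs import encoders

config = electra.ELECTRAPretrainerConfig(
    generator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=3),
    discriminator_encoder=encoders.TransformerEncoderConfig(
        vocab_size=30522, num_layers=12))
pretrainer = electra.instantiate_pretrainer_from_cfg(config)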
official/nlp/configs/electra_test.py  (new file, +49 −0)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2,
                               name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/encoders.py  (+42 −4)

@@ -17,12 +17,13 @@
 Includes configurations and instantiation methods.
 """
+from typing import Optional
 
 import dataclasses
 import tensorflow as tf
 
 from official.modeling import tf_utils
 from official.modeling.hyperparams import base_config
+from official.nlp.modeling import layers
 from official.nlp.modeling import networks
 ...

@@ -40,12 +41,47 @@ class TransformerEncoderConfig(base_config.Config):
   max_position_embeddings: int = 512
   type_vocab_size: int = 2
   initializer_range: float = 0.02
+  embedding_size: Optional[int] = None
 
 
 def instantiate_encoder_from_cfg(
-    config: TransformerEncoderConfig) -> networks.TransformerEncoder:
+    config: TransformerEncoderConfig,
+    encoder_cls=networks.TransformerEncoder,
+    embedding_layer: Optional[layers.OnDeviceEmbedding] = None):
   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
-  encoder_network = networks.TransformerEncoder(
+  if encoder_cls.__name__ == "EncoderScaffold":
+    embedding_cfg = dict(
+        vocab_size=config.vocab_size,
+        type_vocab_size=config.type_vocab_size,
+        hidden_size=config.hidden_size,
+        seq_length=None,
+        max_seq_length=config.max_position_embeddings,
+        initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+        dropout_rate=config.dropout_rate,
+    )
+    hidden_cfg = dict(
+        num_attention_heads=config.num_attention_heads,
+        intermediate_size=config.intermediate_size,
+        intermediate_activation=tf_utils.get_activation(
+            config.hidden_activation),
+        dropout_rate=config.dropout_rate,
+        attention_dropout_rate=config.attention_dropout_rate,
+        kernel_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range),
+    )
+    kwargs = dict(
+        embedding_cfg=embedding_cfg,
+        hidden_cfg=hidden_cfg,
+        num_hidden_instances=config.num_layers,
+        pooled_output_dim=config.hidden_size,
+        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
+            stddev=config.initializer_range))
+    return encoder_cls(**kwargs)
+
+  if encoder_cls.__name__ != "TransformerEncoder":
+    raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
+
+  encoder_network = encoder_cls(
       vocab_size=config.vocab_size,
       hidden_size=config.hidden_size,
       num_layers=config.num_layers,
 ...

@@ -58,5 +94,7 @@ def instantiate_encoder_from_cfg(
       max_sequence_length=config.max_position_embeddings,
       type_vocab_size=config.type_vocab_size,
       initializer=tf.keras.initializers.TruncatedNormal(
-          stddev=config.initializer_range))
+          stddev=config.initializer_range),
+      embedding_width=config.embedding_size,
+      embedding_layer=embedding_layer)
   return encoder_network
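The encoder factory now takes an optional `encoder_cls` and `embedding_layer`, and `TransformerEncoderConfig` gains `embedding_size` (forwarded as `embedding_width`). A sketch of the two paths this enables, the default `TransformerEncoder` build and the embedding-sharing build that `electra.py` relies on above (sizes and the second call are illustrative assumptions):

from official.nlp.configs import encoders

cfg = encoders.TransformerEncoderConfig(
    vocab_size=30522, num_layers=4, embedding_size=128)

# Default path: builds networks.TransformerEncoder with embedding_width=128.
discriminator = encoders.instantiate_encoder_from_cfg(cfg)

# Embedding sharing: reuse the first encoder's embedding table in a second
# encoder, the same call electra.py makes when tie_embeddings is True.
generator = encoders.instantiate_encoder_from_cfg(
    cfg, embedding_layer=discriminator.get_embedding_layer())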
official/nlp/data/classifier_data_lib.py  (+374 −294)

(This diff is collapsed in the page view.)
official/nlp/data/create_finetuning_data.py  (+60 −13)

@@ -32,14 +32,16 @@ from official.nlp.data import sentence_retrieval_lib
 from official.nlp.data import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
 from official.nlp.data import squad_lib_sp
+from official.nlp.data import tagging_data_lib
 
 FLAGS = flags.FLAGS
 
+# TODO(chendouble): consider moving each task to its own binary.
 flags.DEFINE_enum(
     "fine_tuning_task_type", "classification",
-    ["classification", "regression", "squad", "retrieval"],
+    ["classification", "regression", "squad", "retrieval", "tagging"],
     "The name of the BERT fine tuning task for which data "
-    "will be generated..")
+    "will be generated.")
 
 # BERT classification specific flags.
 flags.DEFINE_string(
 ...

@@ -48,30 +50,41 @@ flags.DEFINE_string(
     "for the task.")
 
 flags.DEFINE_enum("classification_task_name", "MNLI",
-                  ["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
-                   "PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
+                  ["COLA", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+                   "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI",
+                   "XTREME-PAWS-X"],
                   "The name of the task to train BERT classifier. The "
                   "difference between XTREME-XNLI and XNLI is: 1. the format "
                   "of input tsv files; 2. the dev set for XTREME is english "
                   "only and for XNLI is all languages combined. Same for "
                   "PAWS-X.")
 
-flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
-                  "The name of sentence retrieval task for scoring")
+# MNLI task-specific flag.
+flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
+                  "The type of MNLI dataset.")
 
-# XNLI task specific flag.
+# XNLI task-specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
-    "Language of training data for XNIL task. If the value is 'all', the data "
+    "Language of training data for XNLI task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# PAWS-X task specific flag.
+# PAWS-X task-specific flag.
 flags.DEFINE_string(
     "pawsx_language", "en",
-    "Language of trainig data for PAWS-X task. If the value is 'all', the data "
+    "Language of training data for PAWS-X task. If the value is 'all', the data "
     "of all languages will be used for training.")
 
-# BERT Squad task specific flags.
+# Retrieval task-specific flags.
+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+                  "The name of sentence retrieval task for scoring")
+
+# Tagging task-specific flags.
+flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
+                  "The name of BERT tagging (token classification) task.")
+
+# BERT Squad task-specific flags.
 flags.DEFINE_string(
     "squad_data_file", None,
     "The input data file in for generating training data for BERT squad task.")
 ...

@@ -171,7 +184,8 @@ def generate_classifier_dataset():
       "cola":
           classifier_data_lib.ColaProcessor,
       "mnli":
-          classifier_data_lib.MnliProcessor,
+          functools.partial(classifier_data_lib.MnliProcessor,
+                            mnli_type=FLAGS.mnli_type),
       "mrpc":
           classifier_data_lib.MrpcProcessor,
       "qnli":
 ...

@@ -180,6 +194,8 @@ def generate_classifier_dataset():
       "rte": classifier_data_lib.RteProcessor,
       "sst-2":
           classifier_data_lib.SstProcessor,
+      "sts-b":
+          classifier_data_lib.StsBProcessor,
       "xnli":
           functools.partial(classifier_data_lib.XnliProcessor,
                             language=FLAGS.xnli_language),
 ...

@@ -284,6 +300,34 @@ def generate_retrieval_dataset():
       FLAGS.max_seq_length)
 
 
+def generate_tagging_dataset():
+  """Generates tagging dataset."""
+  processors = {
+      "panx": tagging_data_lib.PanxProcessor,
+      "udpos": tagging_data_lib.UdposProcessor,
+  }
+  task_name = FLAGS.tagging_task_name.lower()
+  if task_name not in processors:
+    raise ValueError("Task not found: %s" % task_name)
+
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  elif FLAGS.tokenizer_impl == "sentence_piece":
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+  else:
+    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
+
+  processor = processors[task_name]()
+  return tagging_data_lib.generate_tf_record_from_data_file(
+      processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
+      FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
+      FLAGS.test_data_output_path, processor_text_fn)
+
+
 def main(_):
   if FLAGS.tokenizer_impl == "word_piece":
     if not FLAGS.vocab_file:
 ...

@@ -304,8 +348,11 @@ def main(_):
     input_meta_data = generate_regression_dataset()
   elif FLAGS.fine_tuning_task_type == "retrieval":
     input_meta_data = generate_retrieval_dataset()
-  else:
+  elif FLAGS.fine_tuning_task_type == "squad":
     input_meta_data = generate_squad_dataset()
+  else:
+    assert FLAGS.fine_tuning_task_type == "tagging"
+    input_meta_data = generate_tagging_dataset()
 
   tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
   with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
 ...
official/nlp/data/create_pretraining_data.py  (+233 −58)

@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import itertools
 import random
 
 from absl import app
 ...

@@ -48,6 +49,12 @@ flags.DEFINE_bool(
     "do_whole_word_mask", False,
     "Whether to use whole word masking rather than per-WordPiece masking.")
 
+flags.DEFINE_integer(
+    "max_ngram_size", None,
+    "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+    "weighting scheme to favor shorter n-grams. "
+    "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
 flags.DEFINE_bool(
     "gzip_compress", False,
     "Whether to use `GZIP` compress option to get compressed TFRecord files.")
 ...

@@ -192,7 +199,8 @@ def create_training_instances(input_files,
                               masked_lm_prob,
                               max_predictions_per_seq,
                               rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
   """Create `TrainingInstance`s from raw text."""
   all_documents = [[]]
 ...

@@ -229,7 +237,7 @@ def create_training_instances(input_files,
           create_instances_from_document(
               all_documents, document_index, max_seq_length, short_seq_prob,
               masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))
 
   rng.shuffle(instances)
   return instances
 ...

@@ -238,7 +246,8 @@ def create_training_instances(input_files,
 def create_instances_from_document(
     all_documents, document_index, max_seq_length, short_seq_prob,
     masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
   """Creates `TrainingInstance`s for a single document."""
   document = all_documents[document_index]
 ...

@@ -337,7 +346,7 @@ def create_instances_from_document(
         (tokens, masked_lm_positions,
          masked_lm_labels) = create_masked_lm_predictions(
              tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-             do_whole_word_mask)
+             do_whole_word_mask, max_ngram_size)
         instance = TrainingInstance(
             tokens=tokens,
             segment_ids=segment_ids,
 ...

@@ -355,72 +364,238 @@ def create_instances_from_document(
 MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                           ["index", "label"])
 
+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Arguments:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Arguments:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+    E.g.,
+      _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+      _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extention of 'whole word masking' to mask multiple, contiguous
+  words such as (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible outputs masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams; (0,4)
+
+  Output masks will not overlap and contain less than `max_masked_tokens` total
+  tokens.  E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Arguments:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cummulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarentee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue
+
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue
+
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue
+
+    # Found a usable n-gram!  Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitue grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2),          [2,         4),  [4,5) , [5, 6)]
+
+  Arguments:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
+    else:
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams
+
+
 def create_masked_lm_predictions(tokens, masked_lm_prob,
                                  max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
   """Creates the predictions for the masked LM objective."""
-
-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
-    else:
-      cand_indexes.append([i])
-
-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)
-
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]
+
   num_to_predict = min(max_predictions_per_seq,
                        max(1, int(round(len(tokens) * masked_lm_prob))))
 
+  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking. Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
   masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
+
   assert len(masked_lms) <= num_to_predict
   masked_lms = sorted(masked_lms, key=lambda x: x.index)
 ...

@@ -467,7 +642,7 @@ def main(_):
   instances = create_training_instances(
       input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
       FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
 
   output_files = FLAGS.output_file.split(",")
   logging.info("*** Writing to output files ***")
 ...
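The n-gram lengths in `_masking_ngrams` are drawn with zipf-like weights 1, 1/2, 1/3, ... through `random.choices(cum_weights=...)`. A self-contained toy (standard library only, separate from the code above) that shows the resulting length distribution:

import collections
import itertools
import random

max_ngram_size = 3
# Cumulative weights 1, 1.5, 1.833..., matching the scheme in _masking_ngrams.
cum_weights = list(
    itertools.accumulate(1. / n for n in range(1, max_ngram_size + 1)))
draws = [
    random.choices(range(1, max_ngram_size + 1), cum_weights=cum_weights)[0]
    for _ in range(10000)
]
# Sizes 1, 2 and 3 come out roughly in the ratio 6:3:2 (weights 1 : 1/2 : 1/3).
print(collections.Counter(draws))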
official/nlp/data/data_loader_factory.py  (new file, +59 −0)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A global factory to access NLP registered data loaders."""

from official.utils import registry

_REGISTERED_DATA_LOADER_CLS = {}


def register_data_loader_cls(data_config_cls):
  """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.

  This decorator supports registration of data loaders as follows:

  ```
  @dataclasses.dataclass
  class MyDataConfig(DataConfig):
    # Add fields here.
    pass

  @register_data_loader_cls(MyDataConfig)
  class MyDataLoader:
    # Inherits def __init__(self, data_config).
    pass

  my_data_config = MyDataConfig()

  # Returns MyDataLoader(my_data_config).
  my_loader = get_data_loader(my_data_config)
  ```

  Args:
    data_config_cls: a subclass of DataConfig (*not* an instance
      of DataConfig).

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of data_config_cls.
  """
  return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)


def get_data_loader(data_config):
  """Creates a data_loader from data_config."""
  return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
      data_config)
official/nlp/data/pretrain_dataloader.py  (+17 −3)

@@ -16,11 +16,27 @@
 """Loads dataset for the BERT pretraining task."""
 from typing import Mapping, Optional
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+  """Data config for BERT pretraining task (tasks/masked_lm)."""
+  input_path: str = ''
+  global_batch_size: int = 512
+  is_training: bool = True
+  seq_length: int = 512
+  max_predictions_per_seq: int = 76
+  use_next_sentence_label: bool = True
+  use_position_id: bool = False
 
 
+@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
 class BertPretrainDataLoader:
   """A class to load dataset for bert pretraining task."""
 ...

@@ -91,7 +107,5 @@ class BertPretrainDataLoader:
   def load(self, input_context: Optional[tf.distribute.InputContext] = None):
     """Returns a tf.dataset.Dataset."""
     reader = input_reader.InputReader(
         params=self._params,
         decoder_fn=self._decode, parser_fn=self._parse)
     return reader.read(input_context)
official/nlp/data/question_answering_dataloader.py  (new file, +95 −0)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the question answering (e.g, SQuAD) task."""
from typing import Mapping, Optional

import dataclasses
import tensorflow as tf

from official.core import input_reader
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.data import data_loader_factory


@dataclasses.dataclass
class QADataConfig(cfg.DataConfig):
  """Data config for question answering task (tasks/question_answering)."""
  input_path: str = ''
  global_batch_size: int = 48
  is_training: bool = True
  seq_length: int = 384
  # Settings below are question answering specific.
  version_2_with_negative: bool = False
  # Settings below are only used for eval mode.
  input_preprocessed_data_path: str = ''
  doc_stride: int = 128
  query_length: int = 64
  vocab_file: str = ''
  tokenization: str = 'WordPiece'  # WordPiece or SentencePiece
  do_lower_case: bool = True


@data_loader_factory.register_data_loader_cls(QADataConfig)
class QuestionAnsweringDataLoader:
  """A class to load dataset for sentence prediction (classification) task."""

  def __init__(self, params):
    self._params = params
    self._seq_length = params.seq_length
    self._is_training = params.is_training

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    if self._is_training:
      name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
      name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
    else:
      name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in example:
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x, y = {}, {}
    for name, tensor in record.items():
      if name in ('start_positions', 'end_positions'):
        y[name] = tensor
      elif name == 'input_ids':
        x['input_word_ids'] = tensor
      elif name == 'segment_ids':
        x['input_type_ids'] = tensor
      else:
        x[name] = tensor
    return (x, y)

  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)
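A sketch of loading SQuAD eval data with the new loader (the path is a placeholder); with `is_training=False` the `_decode` method above expects `unique_ids` rather than start/end positions:

from official.nlp.data import data_loader_factory
from official.nlp.data import question_answering_dataloader

qa_config = question_answering_dataloader.QADataConfig(
    input_path='/tmp/squad_eval.tf_record',  # placeholder
    is_training=False,
    seq_length=384)
dataset = data_loader_factory.get_data_loader(qa_config).load()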
official/nlp/data/sentence_prediction_dataloader.py  (+19 −1)

@@ -15,11 +15,28 @@
 # ==============================================================================
 """Loads dataset for the sentence prediction (classification) task."""
 from typing import Mapping, Optional
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
+@dataclasses.dataclass
+class SentencePredictionDataConfig(cfg.DataConfig):
+  """Data config for sentence prediction task (tasks/sentence_prediction)."""
+  input_path: str = ''
+  global_batch_size: int = 32
+  is_training: bool = True
+  seq_length: int = 128
+  label_type: str = 'int'
 
 
+@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
 class SentencePredictionDataLoader:
   """A class to load dataset for sentence prediction (classification) task."""
 ...

@@ -29,11 +46,12 @@ class SentencePredictionDataLoader:
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
+    label_type = LABEL_TYPES_MAP[self._params.label_type]
     name_to_features = {
         'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'label_ids': tf.io.FixedLenFeature([], tf.int64),
+        'label_ids': tf.io.FixedLenFeature([], label_type),
     }
     example = tf.io.parse_single_example(record, name_to_features)
 ...
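The new `label_type` field lets regression data (for example STS-B, whose processor is added in `create_finetuning_data.py` above) carry float labels; `'float'` maps to `tf.float32` through `LABEL_TYPES_MAP`. A sketch (the path is a placeholder):

from official.nlp.data import sentence_prediction_dataloader

config = sentence_prediction_dataloader.SentencePredictionDataConfig(
    input_path='/tmp/sts_b_train.tf_record',  # placeholder
    seq_length=128,
    label_type='float')  # decode label_ids as tf.float32
dataset = sentence_prediction_dataloader.SentencePredictionDataLoader(
    config).load()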
official/nlp/data/tagging_data_lib.py  (new file, +346 −0)

(This diff is collapsed in the page view.)
official/nlp/data/tagging_data_loader.py  (+19 −1)

@@ -15,17 +15,30 @@
 # ==============================================================================
 """Loads dataset for the tagging (e.g., NER/POS) task."""
 from typing import Mapping, Optional
+import dataclasses
 import tensorflow as tf
 
 from official.core import input_reader
+from official.modeling.hyperparams import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class TaggingDataConfig(cfg.DataConfig):
+  """Data config for tagging (tasks/tagging)."""
+  is_training: bool = True
+  seq_length: int = 128
+  include_sentence_id: bool = False
 
 
+@data_loader_factory.register_data_loader_cls(TaggingDataConfig)
 class TaggingDataLoader:
   """A class to load dataset for tagging (e.g., NER and POS) task."""
 
-  def __init__(self, params):
+  def __init__(self, params: TaggingDataConfig):
     self._params = params
     self._seq_length = params.seq_length
+    self._include_sentence_id = params.include_sentence_id
 
   def _decode(self, record: tf.Tensor):
     """Decodes a serialized tf.Example."""
 ...

@@ -35,6 +48,9 @@ class TaggingDataLoader:
         'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
+    if self._include_sentence_id:
+      name_to_features['sentence_id'] = tf.io.FixedLenFeature([], tf.int64)
+
     example = tf.io.parse_single_example(record, name_to_features)
 
     # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
 ...

@@ -54,6 +70,8 @@ class TaggingDataLoader:
         'input_mask': record['input_mask'],
         'input_type_ids': record['segment_ids']
     }
+    if self._include_sentence_id:
+      x['sentence_id'] = record['sentence_id']
 
     y = record['label_ids']
     return (x, y)
 ...
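A sketch of the new sentence-id pass-through (the path is a placeholder, and `input_path`/`global_batch_size` are assumed to come from the base `cfg.DataConfig`): with `include_sentence_id=True`, `sentence_id` is decoded and forwarded in the feature dict so downstream evaluation can regroup token predictions by sentence.

from official.nlp.data import tagging_data_loader

config = tagging_data_loader.TaggingDataConfig(
    input_path='/tmp/panx_eval.tf_record',  # placeholder
    is_training=False,
    seq_length=128,
    include_sentence_id=True)
dataset = tagging_data_loader.TaggingDataLoader(config).load()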