ModelZoo / ResNet50_tensorflow / Commits

Commit f16a7b5b (unverified), authored May 04, 2021 by vedanshu, committed via GitHub on May 04, 2021.
Merge pull request #1 from tensorflow/master: new pull
Parents: 8e9296ff, 8f58f396
Changes: 298 files in total; showing 20 changed files with 2793 additions and 454 deletions (+2793 / -454).
official/nlp/configs/bert_test.py                            +0    -66
official/nlp/configs/electra.py                              +8    -54
official/nlp/configs/electra_test.py                         +0    -49
official/nlp/configs/encoders.py                           +291    -41
official/nlp/configs/encoders_test.py                       +42     -0
official/nlp/configs/experiment_configs.py                  +19     -0
official/nlp/configs/experiments/glue_mnli_matched.yaml     +49     -0
official/nlp/configs/experiments/squad_v1.yaml              +50     -0
official/nlp/configs/finetuning_experiments.py             +139     -0
official/nlp/configs/models/bert_en_uncased_base.yaml       +16     -0
official/nlp/configs/pretraining_experiments.py             +82     -0
official/nlp/configs/wmt_transformer_experiments.py        +110     -0
official/nlp/continuous_finetune_lib.py                    +215     -0
official/nlp/continuous_finetune_lib_test.py                +98     -0
official/nlp/data/__init__.py                               +14     -0
official/nlp/data/classifier_data_lib.py                   +426   -102
official/nlp/data/create_finetuning_data.py                +134    -74
official/nlp/data/create_pretraining_data.py               +251    -68
official/nlp/data/create_pretraining_data_test.py          +128     -0
official/nlp/data/create_xlnet_pretraining_data.py         +721     -0
official/nlp/configs/bert_test.py (deleted, 100644 → 0), view file @ 8e9296ff
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import encoders


class BertModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
    _ = bert.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = bert.instantiate_pretrainer_from_cfg(config)

    with self.assertRaises(ValueError):
      config = bert.BertPretrainerConfig(
          encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
          cls_heads=[
              bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence"),
              bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
          ])
      _ = bert.instantiate_pretrainer_from_cfg(config)

  def test_checkpoint_items(self):
    config = bert.BertPretrainerConfig(
        encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    encoder = bert.instantiate_pretrainer_from_cfg(config)
    self.assertSameElements(
        encoder.checkpoint_items.keys(),
        ["encoder", "masked_lm", "next_sentence.pooler_dense"])


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/electra.py (modified), view file @ f16a7b5b
- # Lint as: python3
- # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,71 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ELECTRA model configurations and instantiation methods."""
- from typing import List, Optional
from typing import List

import dataclasses
- import tensorflow as tf
- from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders
- from official.nlp.modeling import layers
- from official.nlp.modeling.models import electra_pretrainer


@dataclasses.dataclass
- class ELECTRAPretrainerConfig(base_config.Config):
class ElectraPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
-   generator_encoder: encoders.TransformerEncoderConfig = (
-       encoders.TransformerEncoderConfig())
-   discriminator_encoder: encoders.TransformerEncoderConfig = (
-       encoders.TransformerEncoderConfig())
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)


- def instantiate_classification_heads_from_cfgs(
-     cls_head_configs: List[bert.ClsHeadConfig]
- ) -> List[layers.ClassificationHead]:
-   if cls_head_configs:
-     return [
-         layers.ClassificationHead(**cfg.as_dict())
-         for cfg in cls_head_configs
-     ]
-   else:
-     return []


- def instantiate_pretrainer_from_cfg(
-     config: ELECTRAPretrainerConfig,
-     generator_network: Optional[tf.keras.Model] = None,
-     discriminator_network: Optional[tf.keras.Model] = None,
- ) -> electra_pretrainer.ElectraPretrainer:
-   """Instantiates ElectraPretrainer from the config."""
-   generator_encoder_cfg = config.generator_encoder
-   discriminator_encoder_cfg = config.discriminator_encoder
-   if generator_network is None:
-     generator_network = encoders.instantiate_encoder_from_cfg(
-         generator_encoder_cfg)
-   if discriminator_network is None:
-     discriminator_network = encoders.instantiate_encoder_from_cfg(
-         discriminator_encoder_cfg)
-   return electra_pretrainer.ElectraPretrainer(
-       generator_network=generator_network,
-       discriminator_network=discriminator_network,
-       vocab_size=config.generator_encoder.vocab_size,
-       num_classes=config.num_classes,
-       sequence_length=config.sequence_length,
-       last_hidden_dim=config.generator_encoder.hidden_size,
-       num_token_predictions=config.num_masked_tokens,
-       mlm_activation=tf_utils.get_activation(
-           generator_encoder_cfg.hidden_activation),
-       mlm_initializer=tf.keras.initializers.TruncatedNormal(
-           stddev=generator_encoder_cfg.initializer_range),
-       classification_heads=instantiate_classification_heads_from_cfgs(
-           config.cls_heads))
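For orientation, a minimal sketch of how the renamed config might now be constructed with the one-of encoder schema introduced in encoders.py below (assuming the official.nlp packages from this commit are importable; the layer counts are illustrative only):

from official.nlp.configs import electra
from official.nlp.configs import encoders

# Generator and discriminator are now expressed through EncoderConfig
# instead of a bare TransformerEncoderConfig.
config = electra.ElectraPretrainerConfig(
    num_masked_tokens=76,
    sequence_length=512,
    generator_encoder=encoders.EncoderConfig(
        type="bert", bert=encoders.BertEncoderConfig(num_layers=3)),
    discriminator_encoder=encoders.EncoderConfig(
        type="bert", bert=encoders.BertEncoderConfig(num_layers=12)))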
official/nlp/configs/electra_test.py (deleted, 100644 → 0), view file @ 8e9296ff
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ELECTRA configurations and models instantiation."""
import tensorflow as tf

from official.nlp.configs import bert
from official.nlp.configs import electra
from official.nlp.configs import encoders


class ELECTRAModelsTest(tf.test.TestCase):

  def test_network_invocation(self):
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
    )
    _ = electra.instantiate_pretrainer_from_cfg(config)

    # Invokes with classification heads.
    config = electra.ELECTRAPretrainerConfig(
        generator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=1),
        discriminator_encoder=encoders.TransformerEncoderConfig(
            vocab_size=10, num_layers=2),
        cls_heads=[
            bert.ClsHeadConfig(inner_dim=10, num_classes=2, name="next_sentence")
        ])
    _ = electra.instantiate_pretrainer_from_cfg(config)


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/encoders.py (modified), view file @ f16a7b5b
- # Lint as: python3
- # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,22 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer Encoders.
Includes configurations and
instantiation
methods.
Includes configurations and
factory
methods.
"""
from typing import Optional

from absl import logging
import dataclasses
import gin
import tensorflow as tf

from official.modeling import hyperparams
from official.modeling import tf_utils
- from official.modeling.hyperparams import base_config
from official.nlp.modeling import networks
from official.nlp.projects.bigbird import encoder as bigbird_encoder


@dataclasses.dataclass
- class TransformerEncoderConfig(base_config.Config):
class BertEncoderConfig(hyperparams.Config):
  """BERT encoder configuration."""
  vocab_size: int = 30522
  hidden_size: int = 768
...
...
@@ -40,56 +43,303 @@ class TransformerEncoderConfig(base_config.Config):
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02
  embedding_size: Optional[int] = None
  output_range: Optional[int] = None
  return_all_encoder_outputs: bool = False


@dataclasses.dataclass
class MobileBertEncoderConfig(hyperparams.Config):
"""MobileBERT encoder configuration.
Attributes:
word_vocab_size: number of words in the vocabulary.
word_embed_size: word embedding size.
type_vocab_size: number of word types.
max_sequence_length: maximum length of input sequence.
num_blocks: number of transformer block in the encoder model.
hidden_size: the hidden size for the transformer block.
num_attention_heads: number of attention heads in the transformer block.
intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
layer.
hidden_activation: the non-linear activation function to apply to the
output of the intermediate/feed-forward layer.
hidden_dropout_prob: dropout probability for the hidden layers.
attention_probs_dropout_prob: dropout probability of the attention
probabilities.
intra_bottleneck_size: the size of bottleneck.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: whether to share linear transformation for keys
and queries.
num_feedforward_networks: number of stacked feed-forward networks.
normalization_type: the type of normalization_type, only 'no_norm' and
'layer_norm' are supported. 'no_norm' represents the element-wise linear
transformation for the student model, as suggested by the original
MobileBERT paper. 'layer_norm' is used for the teacher model.
classifier_activation: if using the tanh activation for the final
representation of the [CLS] token in fine-tuning.
"""
  word_vocab_size: int = 30522
  word_embed_size: int = 128
  type_vocab_size: int = 2
  max_sequence_length: int = 512
  num_blocks: int = 24
  hidden_size: int = 512
  num_attention_heads: int = 4
  intermediate_size: int = 4096
  hidden_activation: str = "gelu"
  hidden_dropout_prob: float = 0.1
  attention_probs_dropout_prob: float = 0.1
  intra_bottleneck_size: int = 1024
  initializer_range: float = 0.02
  use_bottleneck_attention: bool = False
  key_query_shared_bottleneck: bool = False
  num_feedforward_networks: int = 1
  normalization_type: str = "layer_norm"
  classifier_activation: bool = True
  input_mask_dtype: str = "int32"
@dataclasses.dataclass
class AlbertEncoderConfig(hyperparams.Config):
  """ALBERT encoder configuration."""
  vocab_size: int = 30000
  embedding_width: int = 128
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.0
  attention_dropout_rate: float = 0.0
  max_position_embeddings: int = 512
  type_vocab_size: int = 2
  initializer_range: float = 0.02


@dataclasses.dataclass
class BigBirdEncoderConfig(hyperparams.Config):
  """BigBird encoder configuration."""
  vocab_size: int = 50358
  hidden_size: int = 768
  num_layers: int = 12
  num_attention_heads: int = 12
  hidden_activation: str = "gelu"
  intermediate_size: int = 3072
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  max_position_embeddings: int = 4096
  num_rand_blocks: int = 3
  block_size: int = 64
  type_vocab_size: int = 16
  initializer_range: float = 0.02
  embedding_width: Optional[int] = None
  use_gradient_checkpointing: bool = False


@dataclasses.dataclass
class XLNetEncoderConfig(hyperparams.Config):
  """XLNet encoder configuration."""
  vocab_size: int = 32000
  num_layers: int = 24
  hidden_size: int = 1024
  num_attention_heads: int = 16
  head_size: int = 64
  inner_size: int = 4096
  inner_activation: str = "gelu"
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  attention_type: str = "bi"
  bi_data: bool = False
  tie_attention_biases: bool = False
  memory_length: int = 0
  same_length: bool = False
  clamp_length: int = -1
  reuse_length: int = 0
  use_cls_mask: bool = False
  embedding_width: int = 1024
  initializer_range: float = 0.02
  two_stream: bool = False


@dataclasses.dataclass
class EncoderConfig(hyperparams.OneOfConfig):
  """Encoder configuration."""
  type: Optional[str] = "bert"
  albert: AlbertEncoderConfig = AlbertEncoderConfig()
  bert: BertEncoderConfig = BertEncoderConfig()
  bigbird: BigBirdEncoderConfig = BigBirdEncoderConfig()
  mobilebert: MobileBertEncoderConfig = MobileBertEncoderConfig()
  xlnet: XLNetEncoderConfig = XLNetEncoderConfig()


ENCODER_CLS = {
    "bert": networks.BertEncoder,
    "mobilebert": networks.MobileBERTEncoder,
    "albert": networks.AlbertEncoder,
    "bigbird": bigbird_encoder.BigBirdEncoder,
    "xlnet": networks.XLNetBase,
}
@gin.configurable
- def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
-                                  encoder_cls=networks.TransformerEncoder):
-   """Instantiate a Transformer encoder network from TransformerEncoderConfig."""
def build_encoder(config: EncoderConfig,
                  embedding_layer: Optional[tf.keras.layers.Layer] = None,
                  encoder_cls=None,
                  bypass_config: bool = False):
"""Instantiate a Transformer encoder network from EncoderConfig.
Args:
config: the one-of encoder config, which provides encoder parameters of a
chosen encoder.
embedding_layer: an external embedding layer passed to the encoder.
encoder_cls: an external encoder cls not included in the supported encoders,
usually used by gin.configurable.
bypass_config: whether to ignore config instance to create the object with
`encoder_cls`.
Returns:
An encoder instance.
"""
  encoder_type = config.type
  encoder_cfg = config.get()
  encoder_cls = encoder_cls or ENCODER_CLS[encoder_type]
  logging.info("Encoder class: %s to build...", encoder_cls.__name__)

  if bypass_config:
    return encoder_cls()

  if encoder_cls.__name__ == "EncoderScaffold":
    embedding_cfg = dict(
-       vocab_size=config.vocab_size,
-       type_vocab_size=config.type_vocab_size,
-       hidden_size=config.hidden_size,
-       seq_length=None,
-       max_seq_length=config.max_position_embeddings,
        vocab_size=encoder_cfg.vocab_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        max_seq_length=encoder_cfg.max_position_embeddings,
        initializer=tf.keras.initializers.TruncatedNormal(
-           stddev=config.initializer_range),
-       dropout_rate=config.dropout_rate,
            stddev=encoder_cfg.initializer_range),
        dropout_rate=encoder_cfg.dropout_rate,
    )
    hidden_cfg = dict(
-       num_attention_heads=config.num_attention_heads,
-       intermediate_size=config.intermediate_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_activation=tf_utils.get_activation(
-           config.hidden_activation),
-       dropout_rate=config.dropout_rate,
-       attention_dropout_rate=config.attention_dropout_rate,
            encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
-           stddev=config.initializer_range),
            stddev=encoder_cfg.initializer_range),
    )
    kwargs = dict(
        embedding_cfg=embedding_cfg,
        hidden_cfg=hidden_cfg,
-       num_hidden_instances=config.num_layers,
-       pooled_output_dim=config.hidden_size,
        num_hidden_instances=encoder_cfg.num_layers,
        pooled_output_dim=encoder_cfg.hidden_size,
        pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
-           stddev=config.initializer_range))
            stddev=encoder_cfg.initializer_range),
        return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
        dict_outputs=True)
    return encoder_cls(**kwargs)

- if encoder_cls.__name__ != "TransformerEncoder":
-   raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
- encoder_network = encoder_cls(
-     vocab_size=config.vocab_size,
-     hidden_size=config.hidden_size,
-     num_layers=config.num_layers,
-     num_attention_heads=config.num_attention_heads,
-     intermediate_size=config.intermediate_size,
-     activation=tf_utils.get_activation(config.hidden_activation),
-     dropout_rate=config.dropout_rate,
-     attention_dropout_rate=config.attention_dropout_rate,
-     sequence_length=None,
-     max_sequence_length=config.max_position_embeddings,
-     type_vocab_size=config.type_vocab_size,
-     initializer=tf.keras.initializers.TruncatedNormal(
-         stddev=config.initializer_range))
- return encoder_network
  if encoder_type == "mobilebert":
    return encoder_cls(
        word_vocab_size=encoder_cfg.word_vocab_size,
        word_embed_size=encoder_cfg.word_embed_size,
        type_vocab_size=encoder_cfg.type_vocab_size,
        max_sequence_length=encoder_cfg.max_sequence_length,
        num_blocks=encoder_cfg.num_blocks,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        intermediate_act_fn=encoder_cfg.hidden_activation,
        hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
        attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
        intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
        initializer_range=encoder_cfg.initializer_range,
        use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
        key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
        num_feedforward_networks=encoder_cfg.num_feedforward_networks,
        normalization_type=encoder_cfg.normalization_type,
        classifier_activation=encoder_cfg.classifier_activation,
        input_mask_dtype=encoder_cfg.input_mask_dtype)

  if encoder_type == "albert":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        embedding_width=encoder_cfg.embedding_width,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        max_sequence_length=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        dict_outputs=True)

  if encoder_type == "bigbird":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        hidden_size=encoder_cfg.hidden_size,
        num_layers=encoder_cfg.num_layers,
        num_attention_heads=encoder_cfg.num_attention_heads,
        intermediate_size=encoder_cfg.intermediate_size,
        activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        num_rand_blocks=encoder_cfg.num_rand_blocks,
        block_size=encoder_cfg.block_size,
        max_position_embeddings=encoder_cfg.max_position_embeddings,
        type_vocab_size=encoder_cfg.type_vocab_size,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        embedding_width=encoder_cfg.embedding_width,
        use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)

  if encoder_type == "xlnet":
    return encoder_cls(
        vocab_size=encoder_cfg.vocab_size,
        num_layers=encoder_cfg.num_layers,
        hidden_size=encoder_cfg.hidden_size,
        num_attention_heads=encoder_cfg.num_attention_heads,
        head_size=encoder_cfg.head_size,
        inner_size=encoder_cfg.inner_size,
        dropout_rate=encoder_cfg.dropout_rate,
        attention_dropout_rate=encoder_cfg.attention_dropout_rate,
        attention_type=encoder_cfg.attention_type,
        bi_data=encoder_cfg.bi_data,
        two_stream=encoder_cfg.two_stream,
        tie_attention_biases=encoder_cfg.tie_attention_biases,
        memory_length=encoder_cfg.memory_length,
        clamp_length=encoder_cfg.clamp_length,
        reuse_length=encoder_cfg.reuse_length,
        inner_activation=encoder_cfg.inner_activation,
        use_cls_mask=encoder_cfg.use_cls_mask,
        embedding_width=encoder_cfg.embedding_width,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))
# Uses the default BERTEncoder configuration schema to create the encoder.
# If it does not match, please add a switch branch by the encoder type.
  return encoder_cls(
      vocab_size=encoder_cfg.vocab_size,
      hidden_size=encoder_cfg.hidden_size,
      num_layers=encoder_cfg.num_layers,
      num_attention_heads=encoder_cfg.num_attention_heads,
      intermediate_size=encoder_cfg.intermediate_size,
      activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
      dropout_rate=encoder_cfg.dropout_rate,
      attention_dropout_rate=encoder_cfg.attention_dropout_rate,
      max_sequence_length=encoder_cfg.max_position_embeddings,
      type_vocab_size=encoder_cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=encoder_cfg.initializer_range),
      output_range=encoder_cfg.output_range,
      embedding_width=encoder_cfg.embedding_size,
      embedding_layer=embedding_layer,
      return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
      dict_outputs=True)
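A minimal sketch of the new factory in use, mirroring the test file that follows (assuming the official.nlp packages from this commit; the layer count is illustrative):

from official.nlp.configs import encoders

# EncoderConfig is a one-of: `type` selects which sub-config build_encoder reads.
config = encoders.EncoderConfig(
    type="albert", albert=encoders.AlbertEncoderConfig(num_layers=2))
# Dispatches on config.type and returns the corresponding Keras network,
# here the ALBERT encoder registered in ENCODER_CLS.
encoder = encoders.build_encoder(config)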
official/nlp/configs/encoders_test.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.configs.encoders."""
import os

import tensorflow as tf

from official.modeling import hyperparams
from official.nlp.configs import encoders


class EncodersTest(tf.test.TestCase):

  def test_encoder_from_yaml(self):
    config = encoders.EncoderConfig(
        type="bert", bert=encoders.BertEncoderConfig(num_layers=1))
    encoder = encoders.build_encoder(config)
    ckpt = tf.train.Checkpoint(encoder=encoder)
    ckpt_path = ckpt.save(self.get_temp_dir() + "/ckpt")
    params_save_path = os.path.join(self.get_temp_dir(), "params.yaml")
    hyperparams.save_params_dict_to_yaml(config, params_save_path)

    retored_cfg = encoders.EncoderConfig.from_yaml(params_save_path)
    retored_encoder = encoders.build_encoder(retored_cfg)
    status = tf.train.Checkpoint(encoder=retored_encoder).restore(ckpt_path)
    status.assert_consumed()


if __name__ == "__main__":
  tf.test.main()
official/nlp/configs/experiment_configs.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiments definition."""
# pylint: disable=unused-import
from official.nlp.configs import finetuning_experiments
from official.nlp.configs import pretraining_experiments
from official.nlp.configs import wmt_transformer_experiments
official/nlp/configs/experiments/glue_mnli_matched.yaml (new file, 0 → 100644), view file @ f16a7b5b
task:
  hub_module_url: ''
  model:
    num_classes: 3
  init_checkpoint: ''
  metric_type: 'accuracy'
  train_data:
    drop_remainder: true
    global_batch_size: 32
    input_path: ''
    is_training: true
    seq_length: 128
    label_type: 'int'
  validation_data:
    drop_remainder: false
    global_batch_size: 32
    input_path: ''
    is_training: false
    seq_length: 128
    label_type: 'int'
trainer:
  checkpoint_interval: 3000
  optimizer_config:
    learning_rate:
      polynomial:
        # 100% of train_steps.
        decay_steps: 36813
        end_learning_rate: 0.0
        initial_learning_rate: 3.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        # ~10% of train_steps.
        warmup_steps: 3681
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  # Training data size 392,702 examples, 3 epochs.
  train_steps: 36813
  validation_interval: 6135
  # Eval data size = 9815 examples.
  validation_steps: 307
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_eval_metric: 'cls_accuracy'
  best_checkpoint_metric_comp: 'higher'
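The step counts in this YAML follow from the commented dataset sizes; a quick arithmetic check, using only the example counts and batch sizes stated above (small rounding differences are assumed):

# MNLI matched: 392,702 training examples, 3 epochs, batch size 32.
print(392702 * 3 / 32)   # ~36816, configured as train_steps: 36813
print(0.1 * 36813)       # ~3681, the warmup_steps value (~10% of train_steps)
print(9815 / 32)         # ~307, the validation_steps value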
official/nlp/configs/experiments/squad_v1.yaml (new file, 0 → 100644), view file @ f16a7b5b
task:
  hub_module_url: ''
  max_answer_length: 30
  n_best_size: 20
  null_score_diff_threshold: 0.0
  init_checkpoint: ''
  train_data:
    drop_remainder: true
    global_batch_size: 48
    input_path: ''
    is_training: true
    seq_length: 384
  validation_data:
    do_lower_case: true
    doc_stride: 128
    drop_remainder: false
    global_batch_size: 48
    input_path: ''
    is_training: false
    query_length: 64
    seq_length: 384
    tokenization: WordPiece
    version_2_with_negative: false
    vocab_file: ''
trainer:
  checkpoint_interval: 1000
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        decay_steps: 3699
        end_learning_rate: 0.0
        initial_learning_rate: 8.0e-05
        power: 1.0
      type: polynomial
    optimizer:
      type: adamw
    warmup:
      polynomial:
        power: 1
        warmup_steps: 370
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  train_steps: 3699
  validation_interval: 1000
  validation_steps: 226
  best_checkpoint_export_subdir: 'best_ckpt'
  best_checkpoint_eval_metric: 'final_f1'
  best_checkpoint_metric_comp: 'higher'
official/nlp/configs/finetuning_experiments.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import question_answering_dataloader
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.data import tagging_dataloader
from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging


@exp_factory.register_config_factory('bert/sentence_prediction')
def bert_sentence_prediction() -> cfg.ExperimentConfig:
  r"""BERT GLUE."""
  config = cfg.ExperimentConfig(
      task=sentence_prediction.SentencePredictionConfig(
          train_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(),
          validation_data=sentence_prediction_dataloader
          .SentencePredictionDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 3e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  config.task.model.encoder.type = 'bert'
  return config


@exp_factory.register_config_factory('bert/squad')
def bert_squad() -> cfg.ExperimentConfig:
  """BERT Squad V1/V2."""
  config = cfg.ExperimentConfig(
      task=question_answering.QuestionAnsweringConfig(
          train_data=question_answering_dataloader.QADataConfig(),
          validation_data=question_answering_dataloader.QADataConfig()),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 8e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  config.task.model.encoder.type = 'bert'
  return config


@exp_factory.register_config_factory('bert/tagging')
def bert_tagging() -> cfg.ExperimentConfig:
  """BERT tagging task."""
  config = cfg.ExperimentConfig(
      task=tagging.TaggingConfig(
          train_data=tagging_dataloader.TaggingDataConfig(),
          validation_data=tagging_dataloader.TaggingDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.01,
                      'exclude_from_weight_decay':
                          ['LayerNorm', 'layer_norm', 'bias'],
                  }
              },
              'learning_rate': {
                  'type': 'polynomial',
                  'polynomial': {
                      'initial_learning_rate': 8e-5,
                      'end_learning_rate': 0.0,
                  }
              },
              'warmup': {
                  'type': 'polynomial'
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None',
      ])
  return config
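A short sketch of how one of these registered experiments might be looked up and overridden before training, assuming exp_factory.get_exp_config behaves as in the Model Garden core at this snapshot (the checkpoint path and step count are hypothetical):

from official.core import exp_factory
from official.nlp.configs import finetuning_experiments  # registers 'bert/...'

config = exp_factory.get_exp_config('bert/sentence_prediction')
# Experiment configs are plain params objects; fields can be overridden
# before launching training.
config.task.init_checkpoint = '/path/to/pretrained/ckpt'  # hypothetical path
config.trainer.train_steps = 1000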
official/nlp/configs/models/bert_en_uncased_base.yaml (new file, 0 → 100644), view file @ f16a7b5b
task:
  model:
    encoder:
      type: bert
      bert:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        hidden_activation: gelu
        hidden_size: 768
        initializer_range: 0.02
        intermediate_size: 3072
        max_position_embeddings: 512
        num_attention_heads: 12
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
official/nlp/configs/pretraining_experiments.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretraining experiment configurations."""
# pylint: disable=g-doc-return-or-yield,line-too-long
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import pretrain_dataloader
from official.nlp.data import pretrain_dynamic_dataloader
from official.nlp.tasks import masked_lm

_TRAINER = cfg.TrainerConfig(
    train_steps=1000000,
    optimizer_config=optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adamw',
            'adamw': {
                'weight_decay_rate': 0.01,
                'exclude_from_weight_decay':
                    ['LayerNorm', 'layer_norm', 'bias'],
            }
        },
        'learning_rate': {
            'type': 'polynomial',
            'polynomial': {
                'initial_learning_rate': 1e-4,
                'end_learning_rate': 0.0,
            }
        },
        'warmup': {
            'type': 'polynomial'
        }
    }))


@exp_factory.register_config_factory('bert/pretraining')
def bert_pretraining() -> cfg.ExperimentConfig:
  """BERT pretraining experiment."""
  config = cfg.ExperimentConfig(
      task=masked_lm.MaskedLMConfig(
          train_data=pretrain_dataloader.BertPretrainDataConfig(),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              is_training=False)),
      trainer=_TRAINER,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('bert/pretraining_dynamic')
def bert_dynamic() -> cfg.ExperimentConfig:
  """BERT base with dynamic input sequences.

  TPU needs to run with tf.data service with round-robin behavior.
  """
  config = cfg.ExperimentConfig(
      task=masked_lm.MaskedLMConfig(
          train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
          validation_data=pretrain_dataloader.BertPretrainDataConfig(
              is_training=False)),
      trainer=_TRAINER,
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
official/nlp/configs/wmt_transformer_experiments.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=g-doc-return-or-yield,line-too-long
"""WMT translation configurations."""
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import wmt_dataloader
from official.nlp.tasks import translation


@exp_factory.register_config_factory('wmt_transformer/large')
def wmt_transformer_large() -> cfg.ExperimentConfig:
  """WMT Transformer Large.

  Please refer to
  tensorflow_models/official/nlp/data/train_sentencepiece.py
  to generate sentencepiece_model
  and pass
  --params_override=task.sentencepiece_model_path='YOUR_PATH'
  to the train script.
  """
  learning_rate = 2.0
  hidden_size = 1024
  learning_rate *= (hidden_size**-0.5)
  warmup_steps = 16000
  train_steps = 300000
  token_batch_size = 24576
  encdecoder = translation.EncDecoder(
      num_attention_heads=16, intermediate_size=hidden_size * 4)
  config = cfg.ExperimentConfig(
      task=translation.TranslationConfig(
          model=translation.ModelConfig(
              encoder=encdecoder,
              decoder=encdecoder,
              embedding_width=hidden_size,
              padded_decode=True,
              decode_max_length=100),
          train_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='train',
              src_lang='en',
              tgt_lang='de',
              is_training=True,
              global_batch_size=token_batch_size,
              static_batch=True,
              max_seq_length=64),
          validation_data=wmt_dataloader.WMTDataConfig(
              tfds_name='wmt14_translate/de-en',
              tfds_split='test',
              src_lang='en',
              tgt_lang='de',
              is_training=False,
              global_batch_size=32,
              static_batch=True,
              max_seq_length=100,
          ),
          sentencepiece_model_path=None,
      ),
      trainer=cfg.TrainerConfig(
          train_steps=train_steps,
          validation_steps=-1,
          steps_per_loop=1000,
          summary_interval=1000,
          checkpoint_interval=5000,
          validation_interval=5000,
          max_to_keep=1,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adam',
                  'adam': {
                      'beta_2': 0.997,
                      'epsilon': 1e-9,
                  },
              },
              'learning_rate': {
                  'type': 'power',
                  'power': {
                      'initial_learning_rate': learning_rate,
                      'power': -0.5,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': warmup_steps,
                      'warmup_learning_rate': 0.0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.sentencepiece_model_path != None',
      ])
  return config
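The learning-rate arithmetic in wmt_transformer_large reduces to a single number; a quick check of the inverse square-root scaling taken directly from the function body above:

hidden_size = 1024
learning_rate = 2.0 * hidden_size**-0.5
print(learning_rate)  # 2.0 / 32 = 0.0625, the initial LR passed to the 'power' schedule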
official/nlp/continuous_finetune_lib.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM continuous finetuning+eval training driver library."""
import gc
import os
import time
from typing import Any, Mapping, Optional

from absl import logging
import tensorflow as tf

from official.common import distribute_utils
from official.core import config_definitions
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib as multitask_train_lib


def _flatten_dict(xs):
  """Flatten a nested dictionary.

  The nested keys are flattened to a tuple.

  Example::

    xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
    flat_xs = flatten_dict(xs)
    print(flat_xs)
    # {
    #   ('foo',): 1,
    #   ('bar', 'a'): 2,
    # }

  Note that empty dictionaries are ignored and
  will not be restored by `unflatten_dict`.

  Args:
    xs: a nested dictionary

  Returns:
    The flattened dictionary.
  """
  assert isinstance(xs, dict), 'input is not a dict'

  def _flatten(xs, prefix):
    if not isinstance(xs, dict):
      return {prefix: xs}
    result = {}
    for key, value in xs.items():
      path = prefix + (key,)
      result.update(_flatten(value, path))
    return result

  return _flatten(xs, ())


def run_continuous_finetune(
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    pretrain_steps: Optional[int] = None,
) -> Mapping[str, Any]:
  """Run modes with continuous training.

  Currently only supports continuous_train_and_eval.

  Args:
    mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
      checkpoint directory. Once a new checkpoint is discovered, loads the
      checkpoint, finetune the model by training it (probably on another
      dataset or with another task), then evaluate the finetuned model.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    pretrain_steps: Optional, the number of total training steps for the
      pretraining job.

  Returns:
    eval logs: returns eval metrics logs when run_post_eval is set to True,
      othewise, returns {}.
  """
  assert mode == 'continuous_train_and_eval', (
      'Only continuous_train_and_eval is supported by continuous_finetune. '
      'Got mode: {}'.format(mode))

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  retry_times = 0
  while not tf.io.gfile.isdir(params.task.init_checkpoint):
    # Wait for the init_checkpoint directory to be created.
    if retry_times >= 60:
      raise ValueError(
          'ExperimentConfig.task.init_checkpoint must be a directory for '
          'continuous_train_and_eval mode.')
    retry_times += 1
    time.sleep(60)

  summary_writer = tf.summary.create_file_writer(
      os.path.join(model_dir, 'eval'))

  global_step = 0

  def timeout_fn():
    if pretrain_steps and global_step < pretrain_steps:
      # Keeps waiting for another timeout period.
      logging.info(
          'Continue waiting for new checkpoint as current pretrain '
          'global_step=%d and target is %d.', global_step, pretrain_steps)
      return False
    # Quits the loop.
    return True

  for pretrain_ckpt in tf.train.checkpoints_iterator(
      checkpoint_dir=params.task.init_checkpoint,
      min_interval_secs=10,
      timeout=params.trainer.continuous_eval_timeout,
      timeout_fn=timeout_fn):

    # If there are checkpoints, they might be the finetune checkpoint of a
    # different pretrained checkpoint. So we just remove all checkpoints.
    train_utils.remove_ckpts(model_dir)

    with distribution_strategy.scope():
      global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
    # Replaces params.task.init_checkpoint to make sure that we load
    # exactly this pretrain checkpoint.
    if params.trainer.best_checkpoint_export_subdir:
      best_ckpt_subdir = '{}_{}'.format(
          params.trainer.best_checkpoint_export_subdir, global_step)
      params_replaced = params.replace(
          task={'init_checkpoint': pretrain_ckpt},
          trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
    else:
      params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
    params_replaced.lock()
    logging.info('Running finetuning with params: %s', params_replaced)

    with distribution_strategy.scope():
      if isinstance(params, configs.MultiEvalExperimentConfig):
        task = task_factory.get_task(params_replaced.task)
        eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
        (_, eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
            distribution_strategy=distribution_strategy,
            train_task=task,
            eval_tasks=eval_tasks,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
      else:
        task = task_factory.get_task(params_replaced.task, logging_dir=model_dir)
        _, eval_metrics = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='train_and_eval',
            params=params_replaced,
            model_dir=model_dir,
            run_post_eval=True,
            save_summary=False)
      logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
      train_utils.write_json_summary(model_dir, global_step, eval_metrics)

    if not os.path.basename(model_dir):  # if model_dir.endswith('/')
      summary_grp = os.path.dirname(model_dir) + '_' + task.name
    else:
      summary_grp = os.path.basename(model_dir) + '_' + task.name
    summaries = {}
    for name, value in _flatten_dict(eval_metrics).items():
      summaries[summary_grp + '/' + '-'.join(name)] = value
    train_utils.write_summary(summary_writer, global_step, summaries)

    train_utils.remove_ckpts(model_dir)
    # In TF2, the resource life cycle is bound with the python object life
    # cycle. Force trigger python garbage collection here so those resources
    # can be deallocated in time, so it doesn't cause OOM when allocating new
    # objects.
    # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
    # if we need gc here.
    gc.collect()

  if run_post_eval:
    return eval_metrics
  return {}
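The _flatten_dict helper above drives the summary naming (the driver joins each flattened key tuple with '-'). A small sketch restating the docstring example, assuming this module is importable from the commit:

from official.nlp import continuous_finetune_lib

nested = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
flat = continuous_finetune_lib._flatten_dict(nested)
# {('foo',): 1, ('bar', 'a'): 2} -- empty sub-dicts are dropped, and keys like
# ('bar', 'a') become summary names such as '<summary_grp>/bar-a'.
print(flat)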
official/nlp/continuous_finetune_lib_test.py (new file, 0 → 100644), view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import tensorflow as tf

# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.nlp import continuous_finetune_lib

FLAGS = flags.FLAGS

tfm_flags.define_flags()


class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')

  def testContinuousFinetune(self):
    pretrain_steps = 1
    src_model_dir = self.get_temp_dir()
    flags_dict = dict(
        experiment='mock',
        mode='continuous_train_and_eval',
        model_dir=self._model_dir,
        params_override={
            'task': {
                'init_checkpoint': src_model_dir,
            },
            'trainer': {
                'continuous_eval_timeout': 1,
                'steps_per_loop': 1,
                'train_steps': 1,
                'validation_steps': 1,
                'best_checkpoint_export_subdir': 'best_ckpt',
                'best_checkpoint_eval_metric': 'acc',
                'optimizer_config': {
                    'optimizer': {
                        'type': 'sgd'
                    },
                    'learning_rate': {
                        'type': 'constant'
                    }
                }
            }
        })

    with flagsaver.flagsaver(**flags_dict):
      # Train and save some checkpoints.
      params = train_utils.parse_configuration(flags.FLAGS)
      distribution_strategy = tf.distribute.get_strategy()
      with distribution_strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=src_model_dir)
        _ = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='train',
            params=params,
            model_dir=src_model_dir)

      params = train_utils.parse_configuration(FLAGS)
      eval_metrics = continuous_finetune_lib.run_continuous_finetune(
          FLAGS.mode,
          params,
          FLAGS.model_dir,
          run_post_eval=True,
          pretrain_steps=pretrain_steps)
      self.assertIn('best_acc', eval_metrics)

      self.assertFalse(
          tf.io.gfile.exists(os.path.join(FLAGS.model_dir, 'checkpoint')))


if __name__ == '__main__':
  tf.test.main()
official/nlp/data/__init__.py, view file @ f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/nlp/data/classifier_data_lib.py (modified), view file @ f16a7b5b
- # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,16 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for classification task."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
"""BERT library to process data for classification task."""
import
collections
import
csv
import
importlib
import
json
import
os
from
absl
import
logging
...
...
@@ -39,7 +36,7 @@ class InputExample(object):
               text_b=None,
               label=None,
               weight=None,
-              int_iden=None):
               example_id=None):
"""Constructs a InputExample.
Args:
...
...
@@ -53,15 +50,15 @@ class InputExample(object):
examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during
training.
-     int_iden: (Optional) int. The int identification number of example in the
-       corpus.
      example_id: (Optional) int. The int identification number of example in
        the corpus.
    """
    self.guid = guid
    self.text_a = text_a
    self.text_b = text_b
    self.label = label
    self.weight = weight
-   self.int_iden = int_iden
    self.example_id = example_id


class InputFeatures(object):
...
...
@@ -74,14 +71,14 @@ class InputFeatures(object):
               label_id,
               is_real_example=True,
               weight=None,
-              int_iden=None):
               example_id=None):
    self.input_ids = input_ids
    self.input_mask = input_mask
    self.segment_ids = segment_ids
    self.label_id = label_id
    self.is_real_example = is_real_example
    self.weight = weight
-   self.int_iden = int_iden
    self.example_id = example_id


class DataProcessor(object):
...
...
@@ -123,6 +120,63 @@ class DataProcessor(object):
      lines.append(line)
    return lines

  @classmethod
  def _read_jsonl(cls, input_file):
    """Reads a json line file."""
    with tf.io.gfile.GFile(input_file, "r") as f:
      lines = []
      for json_str in f:
        lines.append(json.loads(json_str))
    return lines


class AxProcessor(DataProcessor):
  """Processor for the AX dataset (GLUE diagnostics dataset)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AX"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training/dev/test sets."""
    text_a_index = 1 if set_type == "test" else 8
    text_b_index = 2 if set_type == "test" else 9
    examples = []
    for i, line in enumerate(lines):
      # Skip header.
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
      text_a = self.process_text_fn(line[text_a_index])
      text_b = self.process_text_fn(line[text_b_index])
      if set_type == "test":
        label = "contradiction"
      else:
        label = self.process_text_fn(line[-1])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples


class ColaProcessor(DataProcessor):
  """Processor for the CoLA data set (GLUE version)."""
...
...
@@ -152,10 +206,10 @@ class ColaProcessor(DataProcessor):
    return "COLA"

  def _create_examples(self, lines, set_type):
-   """Creates examples for the training and dev sets."""
    """Creates examples for the training/dev/test sets."""
    examples = []
-   for (i, line) in enumerate(lines):
-     # Only the test set has a header
    for i, line in enumerate(lines):
      # Only the test set has a header.
      if set_type == "test" and i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
...
...
@@ -170,9 +224,55 @@ class ColaProcessor(DataProcessor):
    return examples


class ImdbProcessor(DataProcessor):
  """Processor for the IMDb dataset."""

  def get_labels(self):
    return ["neg", "pos"]

  def get_train_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "train"))

  def get_dev_examples(self, data_dir):
    return self._create_examples(os.path.join(data_dir, "test"))

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "IMDB"

  def _create_examples(self, data_dir):
    """Creates examples."""
    examples = []
    for label in ["neg", "pos"]:
      cur_dir = os.path.join(data_dir, label)
      for filename in tf.io.gfile.listdir(cur_dir):
        if not filename.endswith("txt"):
          continue
        if len(examples) % 1000 == 0:
          logging.info("Loading dev example %d", len(examples))
        path = os.path.join(cur_dir, filename)
        with tf.io.gfile.GFile(path, "r") as f:
          text = f.read().strip().replace("<br />", " ")
        examples.append(
            InputExample(guid="unused_id", text_a=text, text_b=None, label=label))
    return examples


class MnliProcessor(DataProcessor):
  """Processor for the MultiNLI data set (GLUE version)."""

  def __init__(self,
               mnli_type="matched",
               process_text_fn=tokenization.convert_to_unicode):
    super(MnliProcessor, self).__init__(process_text_fn)
    if mnli_type not in ("matched", "mismatched"):
      raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
    self.mnli_type = mnli_type

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
...
...
@@ -180,14 +280,23 @@ class MnliProcessor(DataProcessor):
  def get_dev_examples(self, data_dir):
    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
-        "dev_matched")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
+          "dev_matched")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
+          "dev_mismatched")

  def get_test_examples(self, data_dir):
    """See base class."""
-    return self._create_examples(
-        self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    if self.mnli_type == "matched":
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test")
+    else:
+      return self._create_examples(
+          self._read_tsv(os.path.join(data_dir, "test_mismatched.tsv")),
+          "test")

  def get_labels(self):
    """See base class."""
...
...
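A minimal sketch of the new `mnli_type` switch added above, assuming a local copy of the GLUE MNLI data; the directory path is illustrative only.

# Hypothetical usage sketch (assumed path; not part of this commit).
from official.nlp.data import classifier_data_lib

matched = classifier_data_lib.MnliProcessor(mnli_type="matched")
mismatched = classifier_data_lib.MnliProcessor(mnli_type="mismatched")
# "matched" reads dev_matched.tsv / test_matched.tsv, while "mismatched"
# reads dev_mismatched.tsv / test_mismatched.tsv from the same data_dir.
dev_examples = mismatched.get_dev_examples("/tmp/glue/MNLI")  # assumed dir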
@@ -199,9 +308,9 @@ class MnliProcessor(DataProcessor):
    return "MNLI"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, self.process_text_fn(line[0]))
...
...
@@ -244,9 +353,9 @@ class MrpcProcessor(DataProcessor):
    return "MRPC"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
...
...
@@ -290,7 +399,7 @@ class PawsxProcessor(DataProcessor):
          self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])

    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      guid = "train-%d" % i
      text_a = self.process_text_fn(line[1])
      text_b = self.process_text_fn(line[2])
...
...
@@ -307,7 +416,7 @@ class PawsxProcessor(DataProcessor):
          self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])

    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      guid = "dev-%d" % i
      text_a = self.process_text_fn(line[1])
      text_b = self.process_text_fn(line[2])
...
...
@@ -321,7 +430,7 @@ class PawsxProcessor(DataProcessor):
    examples_by_lang = {k: [] for k in self.supported_languages}
    for lang in self.supported_languages:
      lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
-      for (i, line) in enumerate(lines):
+      for i, line in enumerate(lines):
        guid = "test-%d" % i
        text_a = self.process_text_fn(line[1])
        text_b = self.process_text_fn(line[2])
...
...
@@ -368,9 +477,9 @@ class QnliProcessor(DataProcessor):
    return "QNLI"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, 1)
...
...
@@ -415,18 +524,24 @@ class QqpProcessor(DataProcessor):
    return "QQP"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, line[0])
-      try:
-        text_a = line[3]
-        text_b = line[4]
-        label = line[5]
-      except IndexError:
-        continue
+      if set_type == "test":
+        text_a = line[1]
+        text_b = line[2]
+        label = "0"
+      else:
+        # There appear to be some garbage lines in the train dataset.
+        try:
+          text_a = line[3]
+          text_b = line[4]
+          label = line[5]
+        except IndexError:
+          continue
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
...
...
@@ -462,7 +577,7 @@ class RteProcessor(DataProcessor):
    return "RTE"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
...
...
@@ -507,9 +622,9 @@ class SstProcessor(DataProcessor):
    return "SST-2"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
-    for (i, line) in enumerate(lines):
+    for i, line in enumerate(lines):
      if i == 0:
        continue
      guid = "%s-%s" % (set_type, i)
...
...
@@ -558,7 +673,7 @@ class StsBProcessor(DataProcessor):
    return "STS-B"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
...
...
@@ -671,7 +786,7 @@ class TfdsProcessor(DataProcessor):
    return "TFDS_" + self.dataset_name

  def _create_examples(self, split_name, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    if split_name not in self.dataset:
      raise ValueError("Split {} not available.".format(split_name))
    dataset = self.dataset[split_name].as_numpy_iterator()
...
...
@@ -731,7 +846,7 @@ class WnliProcessor(DataProcessor):
    return "WNLI"

  def _create_examples(self, lines, set_type):
-    """Creates examples for the training and dev sets."""
+    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      if i == 0:
...
...
@@ -777,7 +892,7 @@ class XnliProcessor(DataProcessor):
"multinli.train.%s.tsv"
%
language
))[
1
:])
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
...
...
@@ -792,7 +907,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.dev.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"dev-%d"
%
i
...
...
@@ -807,7 +922,7 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"xnli.test.tsv"
))
examples_by_lang
=
{
k
:
[]
for
k
in
XnliProcessor
.
supported_languages
}
for
(
i
,
line
)
in
enumerate
(
lines
):
for
i
,
line
in
enumerate
(
lines
):
if
i
==
0
:
continue
guid
=
"test-%d"
%
i
...
...
@@ -833,45 +948,104 @@ class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages
=
[
"de"
,
"en"
,
"es"
,
"fr"
,
"ja"
,
"ko"
,
"zh"
]
def
__init__
(
self
,
process_text_fn
=
tokenization
.
convert_to_unicode
,
translated_data_dir
=
None
,
only_use_en_dev
=
True
):
"""See base class.
Args:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training and testing data.
only_use_en_dev: If True, only use english dev data. Otherwise, use dev
data from all languages.
"""
super
(
XtremePawsxProcessor
,
self
).
__init__
(
process_text_fn
)
self
.
translated_data_dir
=
translated_data_dir
self
.
only_use_en_dev
=
only_use_en_dev
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
translated_data_dir
is
None
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
else
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
self
.
translated_data_dir
,
"translate-train"
,
f
"en-
{
lang
}
-translated.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"train-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
3
])
label
=
self
.
process_text_fn
(
line
[
4
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
only_use_en_dev
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
else
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"dev-
{
lang
}
.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"dev-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
examples_by_lang
=
{}
for
lang
in
self
.
supported_languages
:
examples_by_lang
[
lang
]
=
[]
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"test-
%d"
%
i
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
"0"
examples_by_lang
[
lang
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
translated_data_dir
is
not
None
:
for
lang
in
self
.
supported_languages
:
if
lang
==
"en"
:
continue
examples_by_lang
[
f
"
{
lang
}
-en"
]
=
[]
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
self
.
translated_data_dir
,
"translate-test"
,
f
"test-
{
lang
}
-en-translated.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
lang
}
-en-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
3
])
label
=
"0"
examples_by_lang
[
f
"
{
lang
}
-en"
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples_by_lang
def
get_labels
(
self
):
...
...
@@ -891,45 +1065,111 @@ class XtremeXnliProcessor(DataProcessor):
"ur"
,
"vi"
,
"zh"
]
def
__init__
(
self
,
process_text_fn
=
tokenization
.
convert_to_unicode
,
translated_data_dir
=
None
,
only_use_en_dev
=
True
):
"""See base class.
Args:
process_text_fn: See base class.
translated_data_dir: If specified, will also include translated data in
the training data.
only_use_en_dev: If True, only use english dev data. Otherwise, use dev
data from all languages.
"""
super
(
XtremeXnliProcessor
,
self
).
__init__
(
process_text_fn
)
self
.
translated_data_dir
=
translated_data_dir
self
.
only_use_en_dev
=
only_use_en_dev
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"train-en.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
translated_data_dir
is
None
:
for
i
,
line
in
enumerate
(
lines
):
guid
=
"train-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
if
label
==
self
.
process_text_fn
(
"contradictory"
):
label
=
self
.
process_text_fn
(
"contradiction"
)
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
else
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
self
.
translated_data_dir
,
"translate-train"
,
f
"en-
{
lang
}
-translated.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"train-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
3
])
label
=
self
.
process_text_fn
(
line
[
4
])
if
label
==
self
.
process_text_fn
(
"contradictory"
):
label
=
self
.
process_text_fn
(
"contradiction"
)
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
examples
=
[]
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
only_use_en_dev
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
"dev-en.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
"dev-%d"
%
i
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
else
:
for
lang
in
self
.
supported_languages
:
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"dev-
{
lang
}
.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"dev-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
self
.
process_text_fn
(
line
[
2
])
if
label
==
self
.
process_text_fn
(
"contradictory"
):
label
=
self
.
process_text_fn
(
"contradiction"
)
examples
.
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
examples_by_lang
=
{
k
:
[]
for
k
in
self
.
supported_languages
}
examples_by_lang
=
{}
for
lang
in
self
.
supported_languages
:
examples_by_lang
[
lang
]
=
[]
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
data_dir
,
f
"test-
{
lang
}
.tsv"
))
for
(
i
,
line
)
in
enumerate
(
lines
):
guid
=
f
"test-
{
i
}
"
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
lang
}
-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
0
])
text_b
=
self
.
process_text_fn
(
line
[
1
])
label
=
"contradiction"
examples_by_lang
[
lang
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
if
self
.
translated_data_dir
is
not
None
:
for
lang
in
self
.
supported_languages
:
if
lang
==
"en"
:
continue
examples_by_lang
[
f
"
{
lang
}
-en"
]
=
[]
lines
=
self
.
_read_tsv
(
os
.
path
.
join
(
self
.
translated_data_dir
,
"translate-test"
,
f
"test-
{
lang
}
-en-translated.tsv"
))
for
i
,
line
in
enumerate
(
lines
):
guid
=
f
"test-
{
lang
}
-en-
{
i
}
"
text_a
=
self
.
process_text_fn
(
line
[
2
])
text_b
=
self
.
process_text_fn
(
line
[
3
])
label
=
"contradiction"
examples_by_lang
[
f
"
{
lang
}
-en"
].
append
(
InputExample
(
guid
=
guid
,
text_a
=
text_a
,
text_b
=
text_b
,
label
=
label
))
return
examples_by_lang
def
get_labels
(
self
):
...
...
@@ -965,6 +1205,11 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
  if len(tokens_a) > max_seq_length - 2:
    tokens_a = tokens_a[0:(max_seq_length - 2)]

  seg_id_a = 0
  seg_id_b = 1
  seg_id_cls = 0
  seg_id_pad = 0

  # The convention in BERT is:
  # (a) For sequence pairs:
  #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
...
...
@@ -986,19 +1231,19 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
  tokens = []
  segment_ids = []
  tokens.append("[CLS]")
-  segment_ids.append(0)
+  segment_ids.append(seg_id_cls)
  for token in tokens_a:
    tokens.append(token)
-    segment_ids.append(0)
+    segment_ids.append(seg_id_a)
  tokens.append("[SEP]")
-  segment_ids.append(0)
+  segment_ids.append(seg_id_a)

  if tokens_b:
    for token in tokens_b:
      tokens.append(token)
-      segment_ids.append(1)
+      segment_ids.append(seg_id_b)
    tokens.append("[SEP]")
-    segment_ids.append(1)
+    segment_ids.append(seg_id_b)

  input_ids = tokenizer.convert_tokens_to_ids(tokens)
...
...
@@ -1010,7 +1255,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
  while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
-    segment_ids.append(0)
+    segment_ids.append(seg_id_pad)

  assert len(input_ids) == max_seq_length
  assert len(input_mask) == max_seq_length
...
...
@@ -1027,7 +1272,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
    logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
    logging.info("label: %s (id = %s)", example.label, str(label_id))
    logging.info("weight: %s", example.weight)
-    logging.info("int_iden: %s", str(example.int_iden))
+    logging.info("example_id: %s", example.example_id)

  feature = InputFeatures(
      input_ids=input_ids,
...
...
@@ -1036,11 +1281,86 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
      label_id=label_id,
      is_real_example=True,
      weight=example.weight,
-      int_iden=example.int_iden)
+      example_id=example.example_id)

  return feature
class AXgProcessor(DataProcessor):
  """Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "AX-g.jsonl")), "test")

  def get_labels(self):
    """See base class."""
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "AXg"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    for line in lines:
      guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
      text_a = self.process_text_fn(line["premise"])
      text_b = self.process_text_fn(line["hypothesis"])
      label = self.process_text_fn(line["label"])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
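AXgProcessor reads a JSON-lines file, one record per line. The sketch below shows the record shape implied by the keys referenced above ("idx", "premise", "hypothesis", "label"); the field values are invented for illustration.

# Illustrative AX-g.jsonl record (values are made up).
axg_line = {
    "idx": 0,
    "premise": "The dog chased the cat.",
    "hypothesis": "The cat was chased.",
    "label": "entailment",
}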
class SuperGLUERTEProcessor(DataProcessor):
  """Processor for the RTE dataset (SuperGLUE version)."""

  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "train.jsonl")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "val.jsonl")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_jsonl(os.path.join(data_dir, "test.jsonl")), "test")

  def get_labels(self):
    """See base class."""
    # All datasets are converted to 2-class split, where for 3-class datasets
    # we collapse neutral and contradiction into not_entailment.
    return ["entailment", "not_entailment"]

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "RTESuperGLUE"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training/dev/test sets."""
    examples = []
    for i, line in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = self.process_text_fn(line["premise"])
      text_b = self.process_text_fn(line["hypothesis"])
      if set_type == "test":
        label = "entailment"
      else:
        label = self.process_text_fn(line["label"])
      examples.append(
          InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
    return examples
def file_based_convert_examples_to_features(examples,
                                            label_list,
                                            max_seq_length,
...
...
@@ -1052,7 +1372,7 @@ def file_based_convert_examples_to_features(examples,
  tf.io.gfile.makedirs(os.path.dirname(output_file))
  writer = tf.io.TFRecordWriter(output_file)

-  for (ex_index, example) in enumerate(examples):
+  for ex_index, example in enumerate(examples):
    if ex_index % 10000 == 0:
      logging.info("Writing example %d of %d", ex_index, len(examples))
...
...
@@ -1079,8 +1399,10 @@ def file_based_convert_examples_to_features(examples,
        [int(feature.is_real_example)])
    if feature.weight is not None:
      features["weight"] = create_float_feature([feature.weight])
-    if feature.int_iden is not None:
-      features["int_iden"] = create_int_feature([feature.int_iden])
+    if feature.example_id is not None:
+      features["example_id"] = create_int_feature([feature.example_id])
+    else:
+      features["example_id"] = create_int_feature([ex_index])

    tf_example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
...
...
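The writer above now always emits an `example_id` feature, falling back to the running example index when the input feature carries no id. A hedged sketch of reading those extra fields back follows; the record path is an assumption, and only `example_id` and `weight` are taken from this diff.

# Hypothetical read-back sketch (assumed path; feature spec limited to the
# fields shown in this hunk).
import tensorflow as tf

feature_spec = {
    "example_id": tf.io.FixedLenFeature([], tf.int64),
    "weight": tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
dataset = tf.data.TFRecordDataset("/tmp/train.tf_record")  # assumed path
parsed = dataset.map(lambda rec: tf.io.parse_single_example(rec, feature_spec))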
@@ -1113,7 +1435,7 @@ def generate_tf_record_from_data_file(processor,
                                      max_seq_length=128):
  """Generates and saves training data into a tf record file.

-  Arguments:
+  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains train/eval/test data to process.
...
...
@@ -1137,13 +1459,15 @@ def generate_tf_record_from_data_file(processor,
  label_type = getattr(processor, "label_type", None)
  is_regression = getattr(processor, "is_regression", False)
  has_sample_weights = getattr(processor, "weight_key", False)

-  assert train_data_output_path
-
-  train_input_data_examples = processor.get_train_examples(data_dir)
-  file_based_convert_examples_to_features(train_input_data_examples, label_list,
-                                          max_seq_length, tokenizer,
-                                          train_data_output_path, label_type)
-  num_training_data = len(train_input_data_examples)
+  num_training_data = 0
+  if train_data_output_path:
+    train_input_data_examples = processor.get_train_examples(data_dir)
+    file_based_convert_examples_to_features(train_input_data_examples,
+                                            label_list, max_seq_length,
+                                            tokenizer, train_data_output_path,
+                                            label_type)
+    num_training_data = len(train_input_data_examples)

  if eval_data_output_path:
    eval_input_data_examples = processor.get_dev_examples(data_dir)
...
...
official/nlp/data/create_finetuning_data.py
View file @
f16a7b5b
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,17 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT finetuning task dataset generator."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import functools
import json
import os

+# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
...
...
@@ -49,41 +46,60 @@ flags.DEFINE_string(
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task."
)
flags
.
DEFINE_enum
(
"classification_task_name"
,
"MNLI"
,
[
"COLA"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X."
)
# XNLI task specific flag.
flags
.
DEFINE_enum
(
"classification_task_name"
,
"MNLI"
,
[
"AX"
,
"COLA"
,
"IMDB"
,
"MNLI"
,
"MRPC"
,
"PAWS-X"
,
"QNLI"
,
"QQP"
,
"RTE"
,
"SST-2"
,
"STS-B"
,
"WNLI"
,
"XNLI"
,
"XTREME-XNLI"
,
"XTREME-PAWS-X"
,
"AX-g"
,
"SUPERGLUE-RTE"
],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X."
)
# MNLI task-specific flag.
flags
.
DEFINE_enum
(
"mnli_type"
,
"matched"
,
[
"matched"
,
"mismatched"
],
"The type of MNLI dataset."
)
# XNLI task-specific flag.
flags
.
DEFINE_string
(
"xnli_language"
,
"en"
,
"Language of training data for XN
I
L task. If the value is 'all', the data "
"Language of training data for XNL
I
task. If the value is 'all', the data "
"of all languages will be used for training."
)
# PAWS-X task
specific flag.
# PAWS-X task
-
specific flag.
flags
.
DEFINE_string
(
"pawsx_language"
,
"en"
,
"Language of trainig data for PAWS-X task. If the value is 'all', the data "
"Language of traini
n
g data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training."
)
# Retrieva task specific flags
# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
flags
.
DEFINE_string
(
"translated_input_data_dir"
,
None
,
"The translated input data dir. Should contain the .tsv files (or other "
"data files) for the task."
)
# Retrieval task-specific flags.
flags
.
DEFINE_enum
(
"retrieval_task_name"
,
"bucc"
,
[
"bucc"
,
"tatoeba"
],
"The name of sentence retrieval task for scoring"
)
# Tagging task
specific flags
# Tagging task
-
specific flags
.
flags
.
DEFINE_enum
(
"tagging_task_name"
,
"panx"
,
[
"panx"
,
"udpos"
],
"The name of BERT tagging (token classification) task."
)
# BERT Squad task specific flags.
flags
.
DEFINE_bool
(
"tagging_only_use_en_train"
,
True
,
"Whether only use english training data in tagging."
)
# BERT Squad task-specific flags.
flags
.
DEFINE_string
(
"squad_data_file"
,
None
,
"The input data file in for generating training data for BERT squad task."
)
flags
.
DEFINE_string
(
"translated_squad_data_folder"
,
None
,
"The translated data folder for generating training data for BERT squad "
"task."
)
flags
.
DEFINE_integer
(
"doc_stride"
,
128
,
"When splitting up a long document into chunks, how much stride to "
...
...
@@ -98,6 +114,14 @@ flags.DEFINE_bool(
"version_2_with_negative"
,
False
,
"If true, the SQuAD examples contain some that do not have an answer."
)
flags
.
DEFINE_bool
(
"xlnet_format"
,
False
,
"If true, then data will be preprocessed in a paragraph, query, class order"
" instead of the BERT-style class, paragraph, query order."
)
# XTREME specific flags.
flags
.
DEFINE_bool
(
"only_use_en_dev"
,
True
,
"Whether only use english dev data."
)
# Shared flags across BERT fine-tuning tasks.
flags
.
DEFINE_string
(
"vocab_file"
,
None
,
"The vocabulary file that the BERT model was trained on."
)
...
...
@@ -136,36 +160,35 @@ flags.DEFINE_string("sp_model_file", "",
"The path to the model used by sentence piece tokenizer."
)
flags
.
DEFINE_enum
(
"tokeniz
er_impl
"
,
"
w
ord
_p
iece"
,
[
"
w
ord
_p
iece"
,
"
s
entence
_p
iece"
],
"Specifies the tokenizer implementation, i.e., whe
h
ter to use
w
ord
_p
iece "
"or
s
entence
_p
iece tokenizer. Canonical BERT uses
w
ord
_p
iece tokenizer, "
"while ALBERT uses
s
entence
_p
iece tokenizer."
)
"tokeniz
ation
"
,
"
W
ord
P
iece"
,
[
"
W
ord
P
iece"
,
"
S
entence
P
iece"
],
"Specifies the tokenizer implementation, i.e., whet
h
er to use
W
ord
P
iece "
"or
S
entence
P
iece tokenizer. Canonical BERT uses
W
ord
P
iece tokenizer, "
"while ALBERT uses
S
entence
P
iece tokenizer."
)
flags
.
DEFINE_string
(
"tfds_params"
,
""
,
"Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation)."
)
flags
.
DEFINE_string
(
"tfds_params"
,
""
,
"Comma-separated list of TFDS parameter assigments for "
"generic classfication data import (for more details "
"see the TfdsProcessor class documentation)."
)
def
generate_classifier_dataset
():
"""Generates classifier dataset and returns input meta data."""
assert
(
FLAGS
.
input_data_dir
and
FLAGS
.
classification_task_name
or
FLAGS
.
tfds_params
)
assert
(
FLAGS
.
input_data_dir
and
FLAGS
.
classification_task_name
or
FLAGS
.
tfds_params
)
if
FLAGS
.
tokeniz
er_impl
==
"
w
ord
_p
iece"
:
if
FLAGS
.
tokeniz
ation
==
"
W
ord
P
iece"
:
tokenizer
=
tokenization
.
FullTokenizer
(
vocab_file
=
FLAGS
.
vocab_file
,
do_lower_case
=
FLAGS
.
do_lower_case
)
processor_text_fn
=
tokenization
.
convert_to_unicode
else
:
assert
FLAGS
.
tokeniz
er_impl
==
"
s
entence
_p
iece"
assert
FLAGS
.
tokeniz
ation
==
"
S
entence
P
iece"
tokenizer
=
tokenization
.
FullSentencePieceTokenizer
(
FLAGS
.
sp_model_file
)
processor_text_fn
=
functools
.
partial
(
tokenization
.
preprocess_text
,
lower
=
FLAGS
.
do_lower_case
)
if
FLAGS
.
tfds_params
:
processor
=
classifier_data_lib
.
TfdsProcessor
(
tfds_params
=
FLAGS
.
tfds_params
,
process_text_fn
=
processor_text_fn
)
tfds_params
=
FLAGS
.
tfds_params
,
process_text_fn
=
processor_text_fn
)
return
classifier_data_lib
.
generate_tf_record_from_data_file
(
processor
,
None
,
...
...
@@ -176,31 +199,51 @@ def generate_classifier_dataset():
        max_seq_length=FLAGS.max_seq_length)
  else:
    processors = {
        "ax":
            classifier_data_lib.AxProcessor,
        "cola":
            classifier_data_lib.ColaProcessor,
        "imdb":
            classifier_data_lib.ImdbProcessor,
        "mnli":
-            classifier_data_lib.MnliProcessor,
+            functools.partial(
+                classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
        "mrpc":
            classifier_data_lib.MrpcProcessor,
        "qnli":
            classifier_data_lib.QnliProcessor,
        "qqp":
            classifier_data_lib.QqpProcessor,
        "rte":
            classifier_data_lib.RteProcessor,
        "sst-2":
            classifier_data_lib.SstProcessor,
        "sts-b":
            classifier_data_lib.StsBProcessor,
        "xnli":
            functools.partial(
                classifier_data_lib.XnliProcessor,
                language=FLAGS.xnli_language),
        "paws-x":
            functools.partial(
                classifier_data_lib.PawsxProcessor,
                language=FLAGS.pawsx_language),
        "wnli":
            classifier_data_lib.WnliProcessor,
        "xtreme-xnli":
-            functools.partial(classifier_data_lib.XtremeXnliProcessor),
+            functools.partial(
+                classifier_data_lib.XtremeXnliProcessor,
+                translated_data_dir=FLAGS.translated_input_data_dir,
+                only_use_en_dev=FLAGS.only_use_en_dev),
        "xtreme-paws-x":
-            functools.partial(classifier_data_lib.XtremePawsxProcessor)
+            functools.partial(
+                classifier_data_lib.XtremePawsxProcessor,
+                translated_data_dir=FLAGS.translated_input_data_dir,
+                only_use_en_dev=FLAGS.only_use_en_dev),
        "ax-g":
            classifier_data_lib.AXgProcessor,
        "superglue-rte":
            classifier_data_lib.SuperGLUERTEProcessor
    }
    task_name = FLAGS.classification_task_name.lower()
    if task_name not in processors:
...
...
@@ -219,20 +262,19 @@ def generate_classifier_dataset():
def generate_regression_dataset():
  """Generates regression dataset and returns input meta data."""
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)

  if FLAGS.tfds_params:
    processor = classifier_data_lib.TfdsProcessor(
        tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
    return classifier_data_lib.generate_tf_record_from_data_file(
        processor,
        None,
...
...
@@ -248,28 +290,42 @@ def generate_regression_dataset():
def generate_squad_dataset():
  """Generates squad training dataset and returns input meta data."""
  assert FLAGS.squad_data_file
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
    return squad_lib_wp.generate_tf_record_from_json_file(
-        FLAGS.squad_data_file, FLAGS.vocab_file, FLAGS.train_data_output_path,
-        FLAGS.max_seq_length, FLAGS.do_lower_case, FLAGS.max_query_length,
-        FLAGS.doc_stride, FLAGS.version_2_with_negative)
+        input_file_path=FLAGS.squad_data_file,
+        vocab_file_path=FLAGS.vocab_file,
+        output_path=FLAGS.train_data_output_path,
+        translated_input_folder=FLAGS.translated_squad_data_folder,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case,
+        max_query_length=FLAGS.max_query_length,
+        doc_stride=FLAGS.doc_stride,
+        version_2_with_negative=FLAGS.version_2_with_negative,
+        xlnet_format=FLAGS.xlnet_format)
  else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
    return squad_lib_sp.generate_tf_record_from_json_file(
-        FLAGS.squad_data_file, FLAGS.sp_model_file,
-        FLAGS.train_data_output_path, FLAGS.max_seq_length,
-        FLAGS.do_lower_case, FLAGS.max_query_length, FLAGS.doc_stride,
-        FLAGS.version_2_with_negative)
+        input_file_path=FLAGS.squad_data_file,
+        sp_model_file=FLAGS.sp_model_file,
+        output_path=FLAGS.train_data_output_path,
+        translated_input_folder=FLAGS.translated_squad_data_folder,
+        max_seq_length=FLAGS.max_seq_length,
+        do_lower_case=FLAGS.do_lower_case,
+        max_query_length=FLAGS.max_query_length,
+        doc_stride=FLAGS.doc_stride,
+        xlnet_format=FLAGS.xlnet_format,
+        version_2_with_negative=FLAGS.version_2_with_negative)


def generate_retrieval_dataset():
  """Generate retrieval test and dev dataset and returns input meta data."""
  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
  else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
...
...
@@ -286,34 +342,38 @@ def generate_retrieval_dataset():
  processor = processors[task_name](process_text_fn=processor_text_fn)

  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
      processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
      FLAGS.test_data_output_path, FLAGS.max_seq_length)


def generate_tagging_dataset():
  """Generates tagging dataset."""
  processors = {
-      "panx": tagging_data_lib.PanxProcessor,
-      "udpos": tagging_data_lib.UdposProcessor,
+      "panx":
+          functools.partial(
+              tagging_data_lib.PanxProcessor,
+              only_use_en_train=FLAGS.tagging_only_use_en_train,
+              only_use_en_dev=FLAGS.only_use_en_dev),
+      "udpos":
+          functools.partial(
+              tagging_data_lib.UdposProcessor,
+              only_use_en_train=FLAGS.tagging_only_use_en_train,
+              only_use_en_dev=FLAGS.only_use_en_dev),
  }
  task_name = FLAGS.tagging_task_name.lower()
  if task_name not in processors:
    raise ValueError("Task not found: %s" % task_name)
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
    processor_text_fn = tokenization.convert_to_unicode
-  elif FLAGS.tokenizer_impl == "sentence_piece":
+  elif FLAGS.tokenization == "SentencePiece":
    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
    processor_text_fn = functools.partial(
        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
  else:
-    raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
+    raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)

  processor = processors[task_name]()
  return tagging_data_lib.generate_tf_record_from_data_file(
...
...
@@ -323,12 +383,12 @@ def generate_tagging_dataset():
def main(_):
-  if FLAGS.tokenizer_impl == "word_piece":
+  if FLAGS.tokenization == "WordPiece":
    if not FLAGS.vocab_file:
      raise ValueError(
          "FLAG vocab_file for word-piece tokenizer is not specified.")
  else:
-    assert FLAGS.tokenizer_impl == "sentence_piece"
+    assert FLAGS.tokenization == "SentencePiece"
    if not FLAGS.sp_model_file:
      raise ValueError(
          "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
...
...
official/nlp/data/create_pretraining_data.py
View file @
f16a7b5b
-# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create masked LM/next sentence masked_lm TF examples for BERT."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
import collections
+import itertools
import random

# Import libraries
from absl import app
from absl import flags
from absl import logging
...
...
@@ -48,10 +47,20 @@ flags.DEFINE_bool(
"do_whole_word_mask"
,
False
,
"Whether to use whole word masking rather than per-WordPiece masking."
)
flags
.
DEFINE_integer
(
"max_ngram_size"
,
None
,
"Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
"weighting scheme to favor shorter n-grams. "
"Note: `--do_whole_word_mask=True` must also be set when n-gram masking."
)
flags
.
DEFINE_bool
(
"gzip_compress"
,
False
,
"Whether to use `GZIP` compress option to get compressed TFRecord files."
)
flags
.
DEFINE_bool
(
"use_v2_feature_names"
,
False
,
"Whether to use the feature names consistent with the models."
)
flags
.
DEFINE_integer
(
"max_seq_length"
,
128
,
"Maximum sequence length."
)
flags
.
DEFINE_integer
(
"max_predictions_per_seq"
,
20
,
...
...
@@ -101,8 +110,8 @@ class TrainingInstance(object):
def
write_instance_to_example_files
(
instances
,
tokenizer
,
max_seq_length
,
max_predictions_per_seq
,
output_files
,
gzip_compress
):
"""Create TF example files from `TrainingInstance`s."""
gzip_compress
,
use_v2_feature_names
):
"""Create
s
TF example files from `TrainingInstance`s."""
writers
=
[]
for
output_file
in
output_files
:
writers
.
append
(
...
...
@@ -139,9 +148,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
-    features["input_ids"] = create_int_feature(input_ids)
+    if use_v2_feature_names:
+      features["input_word_ids"] = create_int_feature(input_ids)
+      features["input_type_ids"] = create_int_feature(segment_ids)
+    else:
+      features["input_ids"] = create_int_feature(input_ids)
+      features["segment_ids"] = create_int_feature(segment_ids)
+
    features["input_mask"] = create_int_feature(input_mask)
-    features["segment_ids"] = create_int_feature(segment_ids)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
...
...
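The `use_v2_feature_names` switch only renames the two input features; a toy illustration of the two schemes (values invented):

# v1 (default) vs. v2 feature names for the same toy instance.
v1_features = {"input_ids": [101, 2023, 102], "segment_ids": [0, 0, 0]}
v2_features = {"input_word_ids": [101, 2023, 102], "input_type_ids": [0, 0, 0]}
# "input_mask" and the masked-LM features keep their names in both schemes.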
@@ -192,7 +206,8 @@ def create_training_instances(input_files,
                              masked_lm_prob,
                              max_predictions_per_seq,
                              rng,
-                              do_whole_word_mask=False):
+                              do_whole_word_mask=False,
+                              max_ngram_size=None):
  """Create `TrainingInstance`s from raw text."""
  all_documents = [[]]
...
...
@@ -229,7 +244,7 @@ def create_training_instances(input_files,
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-              do_whole_word_mask))
+              do_whole_word_mask, max_ngram_size))

  rng.shuffle(instances)
  return instances
...
...
@@ -238,7 +253,8 @@ def create_training_instances(input_files,
def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-    do_whole_word_mask=False):
+    do_whole_word_mask=False,
+    max_ngram_size=None):
  """Creates `TrainingInstance`s for a single document."""
  document = all_documents[document_index]
...
...
@@ -337,7 +353,7 @@ def create_instances_from_document(
      (tokens, masked_lm_positions,
       masked_lm_labels) = create_masked_lm_predictions(
           tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
-           do_whole_word_mask)
+           do_whole_word_mask, max_ngram_size)
      instance = TrainingInstance(
          tokens=tokens,
          segment_ids=segment_ids,
...
...
@@ -355,72 +371,238 @@ def create_instances_from_document(
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])

+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+#   words:  ["The", "doghouse"]
+#   tokens: ["The", "dog", "##house"]
+#   grams:  [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+  """Helper to create a sliding window iterator with a given size.
+
+  E.g.,
+    input = [1, 2, 3, 4]
+    _window(input, 1) => [1], [2], [3], [4]
+    _window(input, 2) => [1, 2], [2, 3], [3, 4]
+    _window(input, 3) => [1, 2, 3], [2, 3, 4]
+    _window(input, 4) => [1, 2, 3, 4]
+    _window(input, 5) => None
+
+  Args:
+    iterable: elements to iterate over.
+    size: size of the window.
+
+  Yields:
+    Elements of `iterable` batched into a sliding window of length `size`.
+  """
+  i = iter(iterable)
+  window = []
+  try:
+    for e in range(0, size):
+      window.append(next(i))
+    yield window
+  except StopIteration:
+    # handle the case where iterable's length is less than the window size.
+    return
+  for e in i:
+    window = window[1:] + [e]
+    yield window
+
+
+def _contiguous(sorted_grams):
+  """Test whether a sequence of grams is contiguous.
+
+  Args:
+    sorted_grams: _Grams which are sorted in increasing order.
+
+  Returns:
+    True if `sorted_grams` are touching each other.
+
+  E.g.,
+    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+    _contiguous([(1, 2), (4, 5)]) == False
+  """
+  for a, b in _window(sorted_grams, 2):
+    if a.end != b.begin:
+      return False
+  return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+  This is an extention of 'whole word masking' to mask multiple, contiguous
+  words such as (e.g., "the red boat").
+
+  Each input gram represents the token indices of a single word,
+     words:  ["the", "red", "boat"]
+     tokens: ["the", "red", "boa", "##t"]
+     grams:  [(0,1), (1,2), (2,4)]
+
+  For a `max_ngram_size` of three, possible outputs masks include:
+    1-grams: (0,1), (1,2), (2,4)
+    2-grams: (0,2), (1,4)
+    3-grams; (0,4)
+
+  Output masks will not overlap and contain less than `max_masked_tokens` total
+  tokens.  E.g., for the example above with `max_masked_tokens` as three,
+  valid outputs are,
+       [(0,1), (1,2)]  # "the", "red" covering two tokens
+       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
+
+  The length of the selected n-gram follows a zipf weighting to
+  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+  Args:
+    grams: List of one-grams.
+    max_ngram_size: Maximum number of contiguous one-grams combined to create
+      an n-gram.
+    max_masked_tokens: Maximum total number of tokens to be masked.
+    rng: `random.Random` generator.
+
+  Returns:
+    A list of n-grams to be used as masks.
+  """
+  if not grams:
+    return None
+
+  grams = sorted(grams)
+  num_tokens = grams[-1].end
+
+  # Ensure our grams are valid (i.e., they don't overlap).
+  for a, b in _window(grams, 2):
+    if a.end > b.begin:
+      raise ValueError("overlapping grams: {}".format(grams))
+
+  # Build map from n-gram length to list of n-grams.
+  ngrams = {i: [] for i in range(1, max_ngram_size + 1)}
+  for gram_size in range(1, max_ngram_size + 1):
+    for g in _window(grams, gram_size):
+      if _contiguous(g):
+        # Add an n-gram which spans these one-grams.
+        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+  # Shuffle each list of n-grams.
+  for v in ngrams.values():
+    rng.shuffle(v)
+
+  # Create the weighting for n-gram length selection.
+  # Stored cummulatively for `random.choices` below.
+  cummulative_weights = list(
+      itertools.accumulate([1. / n for n in range(1, max_ngram_size + 1)]))
+
+  output_ngrams = []
+  # Keep a bitmask of which tokens have been masked.
+  masked_tokens = [False] * num_tokens
+  # Loop until we have enough masked tokens or there are no more candidate
+  # n-grams of any length.
+  # Each code path should ensure one or more elements from `ngrams` are removed
+  # to guarentee this loop terminates.
+  while (sum(masked_tokens) < max_masked_tokens and
+         sum(len(s) for s in ngrams.values())):
+    # Pick an n-gram size based on our weights.
+    sz = random.choices(range(1, max_ngram_size + 1),
+                        cum_weights=cummulative_weights)[0]
+
+    # Ensure this size doesn't result in too many masked tokens.
+    # E.g., a two-gram contains _at least_ two tokens.
+    if sum(masked_tokens) + sz > max_masked_tokens:
+      # All n-grams of this length are too long and can be removed from
+      # consideration.
+      ngrams[sz].clear()
+      continue

-def create_masked_lm_predictions(tokens, masked_lm_prob,
-                                 max_predictions_per_seq, vocab_words, rng,
-                                 do_whole_word_mask):
-  """Creates the predictions for the masked LM objective."""
+    # All of the n-grams of this size have been used.
+    if not ngrams[sz]:
+      continue
+
+    # Choose a random n-gram of the given size.
+    gram = ngrams[sz].pop()
+    num_gram_tokens = gram.end - gram.begin
+
+    # Check if this would add too many tokens.
+    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+      continue

-  cand_indexes = []
-  for (i, token) in enumerate(tokens):
-    if token == "[CLS]" or token == "[SEP]":
-      continue
+    # Check if any of the tokens in this gram have already been masked.
+    if sum(masked_tokens[gram.begin:gram.end]):
+      continue

-    # Whole Word Masking means that if we mask all of the wordpieces
-    # corresponding to an original word. When a word has been split into
-    # WordPieces, the first token does not have any marker and any subsequence
-    # tokens are prefixed with ##. So whenever we see the ## token, we
-    # append it to the previous set of word indexes.
-    #
-    # Note that Whole Word Masking does *not* change the training code
-    # at all -- we still predict each WordPiece independently, softmaxed
-    # over the entire vocabulary.
-    if (do_whole_word_mask and len(cand_indexes) >= 1 and
-        token.startswith("##")):
-      cand_indexes[-1].append(i)
+    # Found a usable n-gram! Mark its tokens as masked and add it to return.
+    masked_tokens[gram.begin:gram.end] = [True] * (gram.end - gram.begin)
+    output_ngrams.append(gram)
+  return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+  """Reconstitue grams (words) from `tokens`.
+
+  E.g.,
+     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+     grams:  [          [1,2),  [2,       4),  [4,5),  [5,       6)]
+
+  Args:
+    tokens: list of wordpieces
+
+  Returns:
+    List of _Grams representing spans of whole words
+    (without "[CLS]" and "[SEP]").
+  """
+  grams = []
+  gram_start_pos = None
+  for i, token in enumerate(tokens):
+    if gram_start_pos is not None and token.startswith("##"):
+      continue
+    if gram_start_pos is not None:
+      grams.append(_Gram(gram_start_pos, i))
+    if token not in ["[CLS]", "[SEP]"]:
+      gram_start_pos = i
    else:
-      cand_indexes.append([i])
+      gram_start_pos = None
+  if gram_start_pos is not None:
+    grams.append(_Gram(gram_start_pos, len(tokens)))
+  return grams

-  rng.shuffle(cand_indexes)
-
-  output_tokens = list(tokens)

+def create_masked_lm_predictions(tokens, masked_lm_prob,
+                                 max_predictions_per_seq, vocab_words, rng,
+                                 do_whole_word_mask,
+                                 max_ngram_size=None):
+  """Creates the predictions for the masked LM objective."""
+  if do_whole_word_mask:
+    grams = _wordpieces_to_grams(tokens)
+  else:
+    # Here we consider each token to be a word to allow for sub-word masking.
+    if max_ngram_size:
+      raise ValueError("cannot use ngram masking without whole word masking")
+    grams = [_Gram(i, i + 1) for i in range(0, len(tokens))
+             if tokens[i] not in ["[CLS]", "[SEP]"]]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

+  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
+  # whole word masking or token level masking.  Both of these can be treated
+  # as the `max_ngram_size=1` case.
+  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+                                 num_to_predict, rng)
  masked_lms = []
-  covered_indexes = set()
-  for index_set in cand_indexes:
-    if len(masked_lms) >= num_to_predict:
-      break
-    # If adding a whole-word mask would exceed the maximum number of
-    # predictions, then just skip this candidate.
-    if len(masked_lms) + len(index_set) > num_to_predict:
-      continue
-    is_any_index_covered = False
-    for index in index_set:
-      if index in covered_indexes:
-        is_any_index_covered = True
-        break
-    if is_any_index_covered:
-      continue
-    for index in index_set:
-      covered_indexes.add(index)
-
-      masked_token = None
-      # 80% of the time, replace with [MASK]
-      if rng.random() < 0.8:
-        masked_token = "[MASK]"
-      else:
-        # 10% of the time, keep original
-        if rng.random() < 0.5:
-          masked_token = tokens[index]
-        # 10% of the time, replace with random word
-        else:
-          masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
-
-      output_tokens[index] = masked_token
-
-      masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+  output_tokens = list(tokens)
+  for gram in masked_grams:
+    # 80% of the time, replace all n-gram tokens with [MASK]
+    if rng.random() < 0.8:
+      replacement_action = lambda idx: "[MASK]"
+    else:
+      # 10% of the time, keep all the original n-gram tokens.
+      if rng.random() < 0.5:
+        replacement_action = lambda idx: tokens[idx]
+      # 10% of the time, replace each n-gram token with a random word.
+      else:
+        replacement_action = lambda idx: rng.choice(vocab_words)
+
+    for idx in range(gram.begin, gram.end):
+      output_tokens[idx] = replacement_action(idx)
+      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)
...
...
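As a worked example of the n-gram size weighting used by `_masking_ngrams` above, for `max_ngram_size=3` the zipf weights 1, 1/2, 1/3 accumulate as follows:

# Worked example of the cumulative zipf weights for max_ngram_size=3.
import itertools

cummulative_weights = list(
    itertools.accumulate([1. / n for n in range(1, 3 + 1)]))
# -> [1.0, 1.5, 1.8333...]; with random.choices(cum_weights=...) an n-gram
# size of 1 is picked with probability ~0.55, size 2 with ~0.27, size 3 ~0.18.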
@@ -467,7 +649,7 @@ def main(_):
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
-      rng, FLAGS.do_whole_word_mask)
+      rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)

  output_files = FLAGS.output_file.split(",")
  logging.info("*** Writing to output files ***")
...
...
@@ -476,7 +658,8 @@ def main(_):
  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files,
-                                  FLAGS.gzip_compress)
+                                  FLAGS.gzip_compress,
+                                  FLAGS.use_v2_feature_names)


if __name__ == "__main__":
...
...
official/nlp/data/create_pretraining_data_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.nlp.data.create_pretraining_data."""
import
random
import
tensorflow
as
tf
from
official.nlp.data
import
create_pretraining_data
as
cpd
_VOCAB_WORDS
=
[
"vocab_1"
,
"vocab_2"
]
class
CreatePretrainingDataTest
(
tf
.
test
.
TestCase
):
def
assertTokens
(
self
,
input_tokens
,
output_tokens
,
masked_positions
,
masked_labels
):
# Ensure the masked positions are unique.
self
.
assertCountEqual
(
masked_positions
,
set
(
masked_positions
))
# Ensure we can reconstruct the input from the output.
reconstructed_tokens
=
output_tokens
for
pos
,
label
in
zip
(
masked_positions
,
masked_labels
):
reconstructed_tokens
[
pos
]
=
label
self
.
assertEqual
(
input_tokens
,
reconstructed_tokens
)
# Ensure each label is valid.
for
pos
,
label
in
zip
(
masked_positions
,
masked_labels
):
output_token
=
output_tokens
[
pos
]
if
(
output_token
==
"[MASK]"
or
output_token
in
_VOCAB_WORDS
or
output_token
==
input_tokens
[
pos
]):
continue
self
.
fail
(
"invalid mask value: {}"
.
format
(
output_token
))
def
test_wordpieces_to_grams
(
self
):
tests
=
[
([
"That"
,
"cone"
],
[(
0
,
1
),
(
1
,
2
)]),
([
"That"
,
"cone"
,
"##s"
],
[(
0
,
1
),
(
1
,
3
)]),
([
"Swit"
,
"##zer"
,
"##land"
],
[(
0
,
3
)]),
([
"[CLS]"
,
"Up"
,
"##dog"
],
[(
1
,
3
)]),
([
"[CLS]"
,
"Up"
,
"##dog"
,
"[SEP]"
,
"Down"
],
[(
1
,
3
),
(
4
,
5
)]),
]
for
inp
,
expected
in
tests
:
output
=
cpd
.
_wordpieces_to_grams
(
inp
)
self
.
assertEqual
(
expected
,
output
)
def
test_window
(
self
):
input_list
=
[
1
,
2
,
3
,
4
]
window_outputs
=
[
(
1
,
[[
1
],
[
2
],
[
3
],
[
4
]]),
(
2
,
[[
1
,
2
],
[
2
,
3
],
[
3
,
4
]]),
(
3
,
[[
1
,
2
,
3
],
[
2
,
3
,
4
]]),
(
4
,
[[
1
,
2
,
3
,
4
]]),
(
5
,
[]),
]
for
window
,
expected
in
window_outputs
:
output
=
cpd
.
_window
(
input_list
,
window
)
self
.
assertEqual
(
expected
,
list
(
output
))
def
test_create_masked_lm_predictions
(
self
):
tokens
=
[
"[CLS]"
,
"a"
,
"##a"
,
"b"
,
"##b"
,
"c"
,
"##c"
,
"[SEP]"
]
rng
=
random
.
Random
(
123
)
for
_
in
range
(
0
,
5
):
output_tokens
,
masked_positions
,
masked_labels
=
(
cpd
.
create_masked_lm_predictions
(
tokens
=
tokens
,
masked_lm_prob
=
1.0
,
max_predictions_per_seq
=
3
,
vocab_words
=
_VOCAB_WORDS
,
rng
=
rng
,
do_whole_word_mask
=
False
,
max_ngram_size
=
None
))
self
.
assertEqual
(
len
(
masked_positions
),
3
)
self
.
assertEqual
(
len
(
masked_labels
),
3
)
self
.
assertTokens
(
tokens
,
output_tokens
,
masked_positions
,
masked_labels
)
  def test_create_masked_lm_predictions_whole_word(self):
    tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
    rng = random.Random(345)
    for _ in range(0, 5):
      output_tokens, masked_positions, masked_labels = (
          cpd.create_masked_lm_predictions(
              tokens=tokens,
              masked_lm_prob=1.0,
              max_predictions_per_seq=3,
              vocab_words=_VOCAB_WORDS,
              rng=rng,
              do_whole_word_mask=True,
              max_ngram_size=None))
      # Since we can't get exactly three tokens without breaking a word, we
      # only take two.
      self.assertEqual(len(masked_positions), 2)
      self.assertEqual(len(masked_labels), 2)
      self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
      # Ensure that we took an entire word.
      self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
  def test_create_masked_lm_predictions_ngram(self):
    tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
    rng = random.Random(345)
    for _ in range(0, 5):
      output_tokens, masked_positions, masked_labels = (
          cpd.create_masked_lm_predictions(
              tokens=tokens,
              masked_lm_prob=1.0,
              max_predictions_per_seq=76,
              vocab_words=_VOCAB_WORDS,
              rng=rng,
              do_whole_word_mask=True,
              max_ngram_size=3))
      self.assertEqual(len(masked_positions), 76)
      self.assertEqual(len(masked_labels), 76)
      self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
if __name__ == "__main__":
  tf.test.main()
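# --- Illustrative sketch, not part of the original file: the masking budget
# the tests above rely on, assuming the usual BERT-style formula of
# min(per-sequence cap, round(len(tokens) * masked_lm_prob)). With 514 tokens
# and masked_lm_prob=1.0 the budget is capped at max_predictions_per_seq=76,
# which is why the ngram test asserts exactly 76 masked positions.
def _num_to_predict_sketch(num_tokens, masked_lm_prob, max_predictions_per_seq):
  return min(max_predictions_per_seq,
             max(1, int(round(num_tokens * masked_lm_prob))))

assert _num_to_predict_sketch(514, 1.0, 76) == 76
assert _num_to_predict_sketch(8, 1.0, 3) == 3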
official/nlp/data/create_xlnet_pretraining_data.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create LM TF examples for XLNet."""
import json
import math
import os
import random
from typing import Iterable, Mapping, List, Optional, Tuple
import unicodedata

# Import libraries

from absl import app
from absl import flags
from absl import logging
import dataclasses
import numpy as np
import tensorflow as tf

from official.nlp.bert import tokenization

special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}
FLAGS = flags.FLAGS

flags.DEFINE_integer("seq_length", 512,
                     help="Sequence length.")
flags.DEFINE_integer("reuse_length", 256,
                     help="Number of tokens that can be reused as memory. "
                     "Could be half of `seq_len`.")
flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")
flags.DEFINE_string("save_dir", None,
                    "Directory for saving processed data.")
flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")
flags.DEFINE_bool("use_eod_token", True,
                  "Whether or not to include EOD tokens.")
flags.DEFINE_bool("bi_data", True,
                  "Whether or not to use bi-directional data.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
flags.DEFINE_integer("num_cores_per_host", 16,
                     "The number of (TPU) cores per host.")
flags.DEFINE_string("prefix", "", "Filename prefix.")
flags.DEFINE_string("suffix", "", "Filename suffix.")

flags.DEFINE_integer("task_id", None,
                     "The id of the current task.")
flags.DEFINE_integer("num_tasks", None,
                     "The total number of tasks.")
flags.DEFINE_integer("num_passes", 1,
                     "The number of times to run the script.")
@dataclasses.dataclass
class TrainingInstance:
  """Representation of a single XLNet Pretraining instance."""
  data: Iterable[int]
  segment_ids: Iterable[int]
  boundary_indices: Iterable[int]
  label: int

  def to_feature(self) -> Mapping[str, tf.train.Feature]:
    feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
    return dict(
        input_word_ids=feat(self.data),
        input_type_ids=feat(self.segment_ids),
        boundary_indices=feat(self.boundary_indices),
        label=feat([self.label]))

  def to_example(self) -> tf.train.Example:
    return tf.train.Example(
        features=tf.train.Features(feature=self.to_feature()))

  def __str__(self):
    def seq_to_str(seq):
      return " ".join([str(x) for x in seq])

    s = ""
    s += "tokens: %s\n" % seq_to_str(self.data)
    s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
    s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
    s += "label: %s\n" % self.label
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()
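# --- Illustrative sketch, not part of the original file: building one
# TrainingInstance by hand and serializing it the way write_instances_to_tfrecord
# does below. All field values here are made up for illustration only.
_sketch_instance = TrainingInstance(
    data=[10, 11, 4, 12, 13, 4, 3],      # token ids ending in <sep>, <sep>, <cls>
    segment_ids=[0, 0, 0, 1, 1, 1, 2],   # segment A -> 0, segment B -> 1, <cls> -> 2
    boundary_indices=[0, 2, 3, 5, 7],    # whole-word boundaries, ending at the sequence length
    label=1)                             # 1: A and B come from different sentences
_sketch_serialized = _sketch_instance.to_example().SerializeToString()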
def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
  """Preprocesses an individual raw text line.

  This function will:
    - Remove extraneous spaces.
    - Replace `` with ", and '' with ".
    - Replace accents.
    - Apply lower casing.

  Args:
    line: The input line to preprocess.
    do_lower_case: Whether or not to lower case the text.

  Returns:
    The preprocessed line.

  """
  line = " ".join(line.split())
  line = line.replace("``", "\"").replace("''", "\"")

  # Replace accents.
  line = unicodedata.normalize("NFKD", line)
  line = "".join([c for c in line if not unicodedata.combining(c)])

  if do_lower_case:
    line = line.lower()
  return line
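# --- Illustrative sketch, not part of the original file: what the
# normalization above produces for one messy line (whitespace collapsed,
# ``/'' quotes rewritten as ", accents stripped, text lower-cased).
assert _preprocess_line("``Café   au  lait''", do_lower_case=True) == '"cafe au lait"'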
def preprocess_and_tokenize_input_files(
    input_files: Iterable[str],
    tokenizer: tokenization.FullSentencePieceTokenizer,
    use_eod: bool = True,
    do_lower_case: bool = False,
    log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
  """Preprocesses and encodes raw text from input files.

  This function preprocesses raw text and encodes it into tokens using a
  `SentencePieceModel` tokenization method. This also provides the sentence
  indicator for each token.

  Args:
    input_files: The list of input file names.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
      not included.
    do_lower_case: Whether or not to apply lower casing during raw text
      preprocessing.
    log_example_freq: The optional field for how many lines to process before
      emitting an info log.

  Returns:
    The preprocessed list. Each entry in the list is a tuple consisting of
    the token IDs and the sentence IDs.

  """
  all_data = []
  eod_symbol = special_symbols["<eod>"]

  total_number_of_lines = 0

  # Input file format:
  # (1) One sentence per line. These should ideally be actual sentences, not
  # entire paragraphs or arbitrary spans of text. (Because we use the
  # sentence boundaries for the "next sentence prediction" task).
  # (2) Blank lines between documents. Document boundaries are needed so
  # that the "next sentence prediction" task doesn't span between documents.
  for input_file in input_files:
    line_count = 0
    logging.info("Preprocessing %s", input_file)

    all_tokens = []
    all_sentence_ids = []

    sentence_id = True

    with tf.io.gfile.GFile(input_file, "rb") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break

        line_count += 1
        if line_count % log_example_freq == 0:
          logging.info("Loading line %d", line_count)

        line = line.strip()

        if not line:
          if use_eod:
            token_ids = [eod_symbol]
            sentence_id = not sentence_id
          else:
            continue
        else:
          preprocessed_line = _preprocess_line(
              line=line, do_lower_case=do_lower_case)
          token_ids = tokenization.encode_ids(
              sp_model=tokenizer.sp_model, text=preprocessed_line)

        all_tokens.extend(token_ids)
        all_sentence_ids.extend([sentence_id] * len(token_ids))
        sentence_id = not sentence_id
    logging.info("Finished processing %s. Number of lines: %d",
                 input_file, line_count)
    if line_count == 0:
      continue
    total_number_of_lines += line_count
    all_tokens = np.array(all_tokens, dtype=np.int64)
    all_sentence_ids = np.array(all_sentence_ids, dtype=np.bool)
    all_data.append((all_tokens, all_sentence_ids))

  logging.info("Completed text preprocessing. Total number of lines: %d",
               total_number_of_lines)
  return all_data
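# --- Illustrative sketch, not part of the original file: the shape of one
# entry returned above, with made-up token ids. The boolean sentence id flips
# on every new sentence (and on every EOD token when `use_eod` is set), so a
# True/False change always marks a sentence boundary.
_sketch_doc_tokens = np.array([101, 102, 103, 201, 202], dtype=np.int64)
_sketch_doc_sentence_ids = np.array([True, True, True, False, False])
_sketch_all_data = [(_sketch_doc_tokens, _sketch_doc_sentence_ids)]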
def _reshape_to_batch_dimensions(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int) -> Tuple[np.array, np.array]:
  """Truncates and reshapes input data with a batch major dimension.

  Args:
    tokens: The input token ids. This should have the same shape as
      `sentence_ids`.
    sentence_ids: The input sentence ids. This should have the same shape as
      `token_ids`.
    per_host_batch_size: The target per-host batch size.

  Returns:
    The tuple of reshaped tokens and sentence_ids.
  """
  num_steps = len(tokens) // per_host_batch_size
  truncated_data_length = num_steps * per_host_batch_size

  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("num_steps: %d", num_steps)

  def truncate_and_reshape(a):
    return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))

  return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
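# --- Illustrative sketch, not part of the original file: the truncation and
# reshape above on toy data. With 10 tokens and per_host_batch_size=4,
# num_steps = 10 // 4 = 2, so the last two tokens are dropped and both arrays
# come back batch-major with shape (4, 2).
_sketch_stream = np.arange(10, dtype=np.int64)
_sketch_tok, _sketch_sid = _reshape_to_batch_dimensions(
    _sketch_stream, _sketch_stream % 2 == 0, per_host_batch_size=4)
assert _sketch_tok.shape == (4, 2) and _sketch_sid.shape == (4, 2)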
def _create_a_and_b_segments(
    tokens: np.array,
    sentence_ids: np.array,
    begin_index: int,
    total_length: int,
    no_cut_probability: float = 0.5):
  """Splits segments A and B from a single instance of tokens and sentence ids.

  Args:
    tokens: The 1D input token ids. This represents an individual entry within
      a batch.
    sentence_ids: The 1D input sentence ids. This represents an individual
      entry within a batch. This should be the same length as `tokens`.
    begin_index: The reference beginning index to split data.
    total_length: The target combined length of segments A and B.
    no_cut_probability: The probability of not cutting a segment despite
      a cut possibly existing.

  Returns:
    A tuple consisting of A data, B data, and label.

  """
  data_length = tokens.shape[0]
  if begin_index + total_length >= data_length:
    logging.info("[_create_segments]: begin_index %d + total_length %d >= "
                 "data_length %d", begin_index, total_length, data_length)
    return None

  end_index = begin_index + 1
  cut_indices = []

  # Identify all indices where sentence IDs change from one to the next.
  while end_index < data_length:
    if sentence_ids[end_index] != sentence_ids[end_index - 1]:
      if end_index - begin_index >= total_length:
        break
      cut_indices.append(end_index)
    end_index += 1

  a_begin = begin_index

  if not cut_indices or random.random() < no_cut_probability:
    # Segments A and B are contained within the same sentence.
    label = 0
    if not cut_indices:
      a_end = end_index
    else:
      a_end = random.choice(cut_indices)
    b_length = max(1, total_length - (a_end - a_begin))
    b_begin = random.randint(0, data_length - 1 - b_length)
    b_end = b_begin + b_length

    while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
      b_begin -= 1
    while (b_end < data_length - 1 and
           sentence_ids[b_end - 1] == sentence_ids[b_end]):
      b_end += 1
  else:
    # Segments A and B are different sentences.
    label = 1
    a_end = random.choice(cut_indices)
    b_begin = a_end
    b_end = end_index

  while a_end - a_begin + b_end - b_begin > total_length:
    if a_end - a_begin > b_end - b_begin:
      # Delete only the right side for the LM objective.
      a_end -= 1
    else:
      b_end -= 1

  if a_end >= data_length or b_end >= data_length:
    logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
                 a_end, b_end, data_length)
    return None

  a_data = tokens[a_begin: a_end]
  b_data = tokens[b_begin: b_end]
  return a_data, b_data, label
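# --- Illustrative sketch, not part of the original file: sampling A/B
# segments from a toy stream of 20 tokens split into four "sentences".
# The function can return None when the requested span would run past the end
# of the data, so callers (see _convert_tokens_to_instances below) must check.
_sketch_seg_tokens = np.arange(20, dtype=np.int64)
_sketch_seg_sentence_ids = np.array(
    [True] * 5 + [False] * 5 + [True] * 5 + [False] * 5)
_sketch_seg_result = _create_a_and_b_segments(
    _sketch_seg_tokens, _sketch_seg_sentence_ids, begin_index=0, total_length=10)
if _sketch_seg_result is not None:
  # label 0: A and B from the same context; label 1: different sentences.
  _sketch_a, _sketch_b, _sketch_label = _sketch_seg_result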
def _is_functional_piece(piece: str) -> bool:
  return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")


def _is_start_piece(piece: str) -> bool:
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
  if (piece.startswith("▁") or piece in special_pieces):
    return True
  else:
    return False
def _get_boundary_indices(
    data: np.array,
    tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
  """Gets the boundary indices of whole words."""
  seq_length = len(data)
  boundary_indices = []
  for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
    if _is_start_piece(piece) and not _is_functional_piece(piece):
      boundary_indices.append(index)
  boundary_indices.append(seq_length)
  return boundary_indices
def _convert_tokens_to_instances(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    tokenizer: tokenization.FullSentencePieceTokenizer,
    num_cores_per_host: int = 0,
    logging_frequency: int = 500) -> List[TrainingInstance]:
  """Converts tokens and sentence IDs into individual training instances.

  The format of data in the XLNet pretraining task is very similar to the
  BERT pretraining task. Two segments A and B are randomly sampled, and the
  concatenation of A and B into a single sequence is used to perform
  language modeling.

  To create an XLNet Pretraining instance from a single long sequence, S:
  - Create a segment of length `reuse_length`. This first segment represents
    past tokens. During modeling, this segment is used to cache obtained
    content representations for the segment recurrence mechanism.
  - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
    composed of A and B segments.
    For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".

  Args:
    tokens: All tokens concatenated into a single list.
    sentence_ids: All sentence IDs concatenated into a single list.
    per_host_batch_size: The target batch size per host.
    seq_length: The max sequence length.
    reuse_length: The number of tokens to use from the previous segment.
    bi_data: Whether or not to use bidirectional data.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    num_cores_per_host: The number of cores per host. This is required if
      `bi_data` = `True`.
    logging_frequency: The frequency at which to log status updates.

  Returns:
    A list of `TrainingInstance` objects.

  """
  instances = []

  per_core_batch_size = (per_host_batch_size // num_cores_per_host
                         if bi_data else None)

  if bi_data:
    logging.info("Bi-directional data enabled.")
    assert per_host_batch_size % (2 * num_cores_per_host) == 0
    forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size // 2)
    forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)

    forward_tokens = forward_tokens.reshape(forward_data_shape)
    forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)

    backwards_tokens = forward_tokens[:, :, :, ::-1]
    backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]

    tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
        per_host_batch_size, -1)
    sentence_ids = np.concatenate(
        [forward_sentence_ids, backwards_sentence_ids]).reshape(
            per_host_batch_size, -1)
  else:
    logging.info("Bi-directional data disabled.")
    tokens, sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size)

  logging.info("Tokens shape: %s", tokens.shape)

  data_length = tokens.shape[1]
  sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
  cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
  # 2 sep, 1 cls
  num_special_tokens = 3

  data_index = 0
  batch_number = 0
  step_size = reuse_length if reuse_length else seq_length
  num_batches = math.ceil(data_length / step_size)

  while data_index + seq_length <= data_length:
    if batch_number % logging_frequency == 0:
      logging.info("Processing batch %d of %d", batch_number, num_batches)

    for batch_index in range(per_host_batch_size):
      previous_segment_tokens = tokens[
          batch_index, data_index: data_index + reuse_length]

      results = _create_a_and_b_segments(
          tokens=tokens[batch_index],
          sentence_ids=sentence_ids[batch_index],
          begin_index=data_index + reuse_length,
          total_length=seq_length - reuse_length - num_special_tokens)

      if results is None:
        logging.info("Stopping at data index: %d", data_index)
        break
      a_data, b_data, label = results

      data = np.concatenate(
          [previous_segment_tokens, a_data, sep, b_data, sep, cls])
      a_length = a_data.shape[0]
      b_length = b_data.shape[0]
      segment_ids = ([0] * (reuse_length + a_length) + [0]
                     + [1] * b_length + [1] + [2])
      boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
                                               data=data)
      assert len(data) == seq_length
      assert len(segment_ids) == seq_length
      assert len(boundary_indices) > 0  # pylint: disable=g-explicit-length-test

      instances.append(TrainingInstance(
          data=data,
          segment_ids=segment_ids,
          boundary_indices=boundary_indices,
          label=label))
    batch_number += 1
    data_index += step_size
  return instances
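# --- Illustrative sketch, not part of the original file: the length budget
# behind the asserts above. Each instance packs
#   [reuse_length cached tokens][A][<sep>][B][<sep>][<cls>]
# so with the default seq_length=512, reuse_length=256 and 3 special tokens,
# A and B together get at most 512 - 256 - 3 = 253 tokens, and the segment ids
# are zeros for the cached tokens, A and the first <sep>, ones for B and the
# second <sep>, and a single 2 for <cls>.
_sketch_seq_length, _sketch_reuse_length, _sketch_num_special = 512, 256, 3
assert _sketch_seq_length - _sketch_reuse_length - _sketch_num_special == 253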
def write_instances_to_tfrecord(instances: Iterable[TrainingInstance],
                                save_path: str):
  """Writes instances to TFRecord."""
  record_writer = tf.io.TFRecordWriter(save_path)
  logging.info("Start writing to %s.", save_path)

  for i, instance in enumerate(instances):
    if i < 5:
      logging.info("Instance %d: %s", i, str(instance))
    record_writer.write(instance.to_example().SerializeToString())

  record_writer.close()
  logging.info("Done writing %s.", save_path)
def shuffle_and_combine_preprocessed_data(
    all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
  """Shuffles and combines preprocessed token/sentence IDs from documents."""
  document_permutation = np.random.permutation(len(all_data))

  previous_sentence_id = None

  all_tokens, all_sentence_ids = [], []
  for document_index in document_permutation:
    tokens, sentence_ids = all_data[document_index]
    # pylint: disable=g-explicit-length-test
    if len(tokens) == 0:
      continue
    if (previous_sentence_id is not None and
        sentence_ids[0] == previous_sentence_id):
      sentence_ids = np.logical_not(sentence_ids)

    all_tokens.append(tokens)
    all_sentence_ids.append(sentence_ids)

    previous_sentence_id = sentence_ids[-1]

  return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
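# --- Illustrative sketch, not part of the original file: why the polarity
# flip above matters. If a document ends with sentence id True and the next
# (post-shuffle) document also starts with True, the second document's ids are
# inverted so a True/False change still marks every sentence boundary in the
# combined stream. Token values here are made up.
_sketch_doc_a = (np.array([1, 2], dtype=np.int64), np.array([True, True]))
_sketch_doc_b = (np.array([3, 4], dtype=np.int64), np.array([True, False]))
_sketch_comb_tokens, _sketch_comb_ids = shuffle_and_combine_preprocessed_data(
    [_sketch_doc_a, _sketch_doc_b])
assert len(_sketch_comb_tokens) == 4 and len(_sketch_comb_ids) == 4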
def get_tfrecord_name(
    per_host_batch_size: int,
    num_cores_per_host: int,
    seq_length: int,
    bi_data: bool,
    reuse_length: int,
    do_lower_case: bool,
    use_eod_token: bool,
    prefix: str = "",
    suffix: str = "",
    pass_id: int = 0,
    num_passes: int = 1,
    task_id: int = None,
    num_tasks: int = None) -> str:
  """Formats the resulting TFRecord name based on provided inputs."""
  components = []
  if prefix:
    components.append(prefix)
  components.append("seqlen-{}".format(seq_length))
  if reuse_length == 0:
    components.append("memless")
  else:
    components.append("reuse-{}".format(reuse_length))
  components.append("bs-{}".format(per_host_batch_size))
  components.append("cores-{}".format(num_cores_per_host))

  if do_lower_case:
    components.append("uncased")
  else:
    components.append("cased")

  if use_eod_token:
    components.append("eod")
  if bi_data:
    components.append("bi")
  else:
    components.append("uni")

  if suffix:
    components.append(suffix)

  s = "_".join(components) + ".tfrecord"
  if num_passes == 1 and task_id is None:
    return s

  if task_id is None:
    num_tasks = 1
    task_id = 0

  current_shard = task_id * num_passes + pass_id
  total_shards = num_tasks * num_passes
  return s + "-{}-of-{}".format(current_shard, total_shards)
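# --- Illustrative sketch, not part of the original file: the filename the
# default flag values above would produce (single pass, no task sharding).
assert get_tfrecord_name(
    per_host_batch_size=32, num_cores_per_host=16, seq_length=512,
    bi_data=True, reuse_length=256, do_lower_case=True,
    use_eod_token=True) == "seqlen-512_reuse-256_bs-32_cores-16_uncased_eod_bi.tfrecord"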
def create_tfrecords(
    tokenizer: tokenization.FullSentencePieceTokenizer,
    input_file_or_files: str,
    use_eod_token: bool,
    do_lower_case: bool,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    num_cores_per_host: int,
    save_dir: str,
    prefix: str = "",
    suffix: str = "",
    num_tasks: Optional[int] = None,
    task_id: Optional[int] = None,
    num_passes: int = 1):
  """Runs the end-to-end preprocessing pipeline."""

  logging.info("Input configuration:")
  logging.info("input file(s): %s", input_file_or_files)
  logging.info("use_eod_token: %s", use_eod_token)
  logging.info("do_lower_case: %s", do_lower_case)
  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("seq_length: %d", seq_length)
  logging.info("reuse_length: %d", reuse_length)
  logging.info("bi_data: %s", bi_data)
  logging.info("num_cores_per_host: %d", num_cores_per_host)
  logging.info("save_dir: %s", save_dir)
  if task_id is not None and num_tasks is not None:
    logging.info("task_id: %d", task_id)
    logging.info("num_tasks: %d", num_tasks)

  input_files = []
  for input_pattern in input_file_or_files.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Reading from input files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  logging.info("Shuffling the files with a fixed random seed.")
  np.random.shuffle(input_files)
  if num_tasks is not None:
    assert task_id is not None
    logging.info("Total number of input files: %d", len(input_files))
    logging.info("Splitting into %d shards of %d files each.",
                 num_tasks, len(input_files) // num_tasks)
    input_files = input_files[task_id::num_tasks]

  all_data = preprocess_and_tokenize_input_files(
      input_files=input_files,
      tokenizer=tokenizer,
      use_eod=use_eod_token,
      do_lower_case=do_lower_case)
  for pass_id in range(num_passes):
    logging.info("Beginning pass %d of %d", pass_id, num_passes)
    tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)

    assert len(tokens) == len(sentence_ids)

    filename = get_tfrecord_name(
        per_host_batch_size=per_host_batch_size,
        num_cores_per_host=num_cores_per_host,
        seq_length=seq_length,
        bi_data=bi_data,
        use_eod_token=use_eod_token,
        reuse_length=reuse_length,
        do_lower_case=do_lower_case,
        prefix=prefix,
        suffix=suffix,
        pass_id=pass_id,
        num_passes=num_passes,
        num_tasks=num_tasks,
        task_id=task_id)
    save_path = os.path.join(save_dir, filename)
    if os.path.exists(save_path):
      # If the path already exists, then we were probably preempted but
      # previously wrote this file.
      logging.info("%s already exists, skipping this batch.", save_path)
    else:
      instances = _convert_tokens_to_instances(
          tokenizer=tokenizer,
          tokens=tokens,
          sentence_ids=sentence_ids,
          per_host_batch_size=per_host_batch_size,
          seq_length=seq_length,
          reuse_length=reuse_length,
          bi_data=bi_data,
          num_cores_per_host=num_cores_per_host)
      write_instances_to_tfrecord(instances=instances, save_path=save_path)

  if task_id is None or task_id == 0:
    corpus_info = {
        "vocab_size": 32000,
        "per_host_batch_size": per_host_batch_size,
        "num_cores_per_host": num_cores_per_host,
        "seq_length": seq_length,
        "reuse_length": reuse_length,
        "do_lower_case": do_lower_case,
        "bi_data": bi_data,
        "use_eod_token": use_eod_token,
    }
    corpus_fname = os.path.basename(filename) + ".json"
    corpus_destination = os.path.join(save_dir, corpus_fname)
    logging.info("Saving corpus info to %s", corpus_destination)

    with tf.io.gfile.GFile(corpus_destination, "w") as fp:
      json.dump(corpus_info, fp)
def main(_):
  tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)

  create_tfrecords(
      tokenizer=tokenizer,
      input_file_or_files=FLAGS.input_file,
      use_eod_token=FLAGS.use_eod_token,
      do_lower_case=FLAGS.do_lower_case,
      per_host_batch_size=FLAGS.per_host_batch_size,
      seq_length=FLAGS.seq_length,
      reuse_length=FLAGS.reuse_length,
      bi_data=FLAGS.bi_data,
      num_cores_per_host=FLAGS.num_cores_per_host,
      save_dir=FLAGS.save_dir,
      prefix=FLAGS.prefix,
      suffix=FLAGS.suffix,
      num_tasks=FLAGS.num_tasks,
      task_id=FLAGS.task_id,
      num_passes=FLAGS.num_passes)
if __name__ == "__main__":
  np.random.seed(0)
  logging.set_verbosity(logging.INFO)
  app.run(main)