ModelZoo / ResNet50_tensorflow · Commit 67996f87

Authored Jun 23, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 317898942
Parent: ee35a030

Showing 4 changed files with 268 additions and 7 deletions (+268, -7)
Files changed:

- official/nlp/bert/input_pipeline.py (+36, -0)
- official/nlp/data/classifier_data_lib.py (+21, -5)
- official/nlp/data/create_finetuning_data.py (+43, -2)
- official/nlp/data/sentence_retrieval_lib.py (+168, -0)
official/nlp/bert/input_pipeline.py

```diff
@@ -247,3 +247,39 @@ def create_squad_dataset(file_path,
   dataset = dataset.batch(batch_size, drop_remainder=True)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
+
+
+def create_retrieval_dataset(file_path,
+                             seq_length,
+                             batch_size,
+                             input_pipeline_context=None):
+  """Creates input dataset from (tf)records files for scoring."""
+  name_to_features = {
+      'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+      'int_iden': tf.io.FixedLenFeature([1], tf.int64),
+  }
+  dataset = single_file_dataset(file_path, name_to_features)
+
+  # The dataset is always sharded by number of hosts.
+  # num_input_pipelines is the number of hosts rather than number of cores.
+  if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+    dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+                            input_pipeline_context.input_pipeline_id)
+
+  def _select_data_from_record(record):
+    x = {
+        'input_word_ids': record['input_ids'],
+        'input_mask': record['input_mask'],
+        'input_type_ids': record['segment_ids']
+    }
+    y = record['int_iden']
+    return (x, y)
+
+  dataset = dataset.map(_select_data_from_record,
+                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+  dataset = dataset.batch(batch_size, drop_remainder=False)
+  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+  return dataset
```
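As context for reviewers, here is a minimal sketch of how the new `create_retrieval_dataset` might be consumed at scoring time; the TFRecord path below is a hypothetical placeholder:

```python
# Sketch only: the file path is a hypothetical placeholder, and seq_length
# must match the value used when the records were written.
from official.nlp.bert import input_pipeline

dataset = input_pipeline.create_retrieval_dataset(
    file_path="de-en-de.test.tfrecords",  # hypothetical output of this commit
    seq_length=128,
    batch_size=32)

for features, int_iden in dataset.take(1):
  # `features` carries input_word_ids / input_mask / input_type_ids;
  # `int_iden` is returned as the label so each score can be traced back
  # to its sentence in the corpus.
  print(features["input_word_ids"].shape, int_iden.shape)
```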
official/nlp/data/classifier_data_lib.py

```diff
@@ -33,7 +33,13 @@ from official.nlp.bert import tokenization
 class InputExample(object):
   """A single training/test example for simple sequence classification."""

-  def __init__(self, guid, text_a, text_b=None, label=None, weight=None):
+  def __init__(self,
+               guid,
+               text_a,
+               text_b=None,
+               label=None,
+               weight=None,
+               int_iden=None):
     """Constructs a InputExample.

     Args:
@@ -46,12 +52,15 @@ class InputExample(object):
         specified for train and dev examples, but not for test examples.
       weight: (Optional) float. The weight of the example to be used during
         training.
+      int_iden: (Optional) int. The int identification number of example in the
+        corpus.
     """
     self.guid = guid
     self.text_a = text_a
     self.text_b = text_b
     self.label = label
     self.weight = weight
+    self.int_iden = int_iden


 class InputFeatures(object):
@@ -63,13 +72,15 @@ class InputFeatures(object):
                segment_ids,
                label_id,
                is_real_example=True,
-               weight=None):
+               weight=None,
+               int_iden=None):
     self.input_ids = input_ids
     self.input_mask = input_mask
     self.segment_ids = segment_ids
     self.label_id = label_id
     self.is_real_example = is_real_example
     self.weight = weight
+    self.int_iden = int_iden


 class DataProcessor(object):
@@ -908,8 +919,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
     logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
     logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
     logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
-    logging.info("label: %s (id = %d)", example.label, label_id)
+    logging.info("label: %s (id = %s)", example.label, str(label_id))
     logging.info("weight: %s", example.weight)
+    logging.info("int_iden: %s", str(example.int_iden))

   feature = InputFeatures(
       input_ids=input_ids,
@@ -917,7 +929,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
       segment_ids=segment_ids,
       label_id=label_id,
       is_real_example=True,
-      weight=example.weight)
+      weight=example.weight,
+      int_iden=example.int_iden)

   return feature
@@ -953,12 +967,14 @@ def file_based_convert_examples_to_features(examples,
     features["segment_ids"] = create_int_feature(feature.segment_ids)
     if label_type is not None and label_type == float:
       features["label_ids"] = create_float_feature([feature.label_id])
-    else:
+    elif feature.label_id is not None:
       features["label_ids"] = create_int_feature([feature.label_id])
     features["is_real_example"] = create_int_feature(
         [int(feature.is_real_example)])
     if feature.weight is not None:
       features["weight"] = create_float_feature([feature.weight])
+    if feature.int_iden is not None:
+      features["int_iden"] = create_int_feature([feature.int_iden])

     tf_example = tf.train.Example(features=tf.train.Features(feature=features))
     writer.write(tf_example.SerializeToString())
```
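To make the new `int_iden` plumbing concrete, here is a self-contained sketch of the write/read round trip; the helper body is assumed to mirror the `create_int_feature` used above, and the id value is made up:

```python
import tensorflow as tf

def create_int_feature(values):
  # Assumed to match the create_int_feature helper used in the diff above.
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

# Write side: file_based_convert_examples_to_features now emits `int_iden`.
features = {"int_iden": create_int_feature([42])}  # 42 is a made-up id
example = tf.train.Example(features=tf.train.Features(feature=features))

# Read side: create_retrieval_dataset parses it back with the same schema.
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"int_iden": tf.io.FixedLenFeature([1], tf.int64)})
print(parsed["int_iden"].numpy())  # [42]
```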
official/nlp/data/create_finetuning_data.py

```diff
@@ -27,6 +27,7 @@ from absl import flags
 import tensorflow as tf

 from official.nlp.bert import tokenization
 from official.nlp.data import classifier_data_lib
+from official.nlp.data import sentence_retrieval_lib
 # word-piece tokenizer based squad_lib
 from official.nlp.data import squad_lib as squad_lib_wp
 # sentence-piece tokenizer based squad_lib
@@ -36,7 +37,7 @@ FLAGS = flags.FLAGS
 flags.DEFINE_enum(
     "fine_tuning_task_type", "classification",
-    ["classification", "regression", "squad"],
+    ["classification", "regression", "squad", "retrieval"],
     "The name of the BERT fine tuning task for which data "
     "will be generated..")
@@ -55,6 +56,9 @@ flags.DEFINE_enum("classification_task_name", "MNLI",
                   "only and for XNLI is all languages combined. Same for "
                   "PAWS-X.")

+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+                  "The name of sentence retrieval task for scoring")
+
 # XNLI task specific flag.
 flags.DEFINE_string(
     "xnli_language", "en",
@@ -246,6 +250,39 @@ def generate_squad_dataset():
                                          FLAGS.max_query_length,
                                          FLAGS.doc_stride,
                                          FLAGS.version_2_with_negative)


+def generate_retrieval_dataset():
+  """Generate retrieval test and dev dataset and returns input meta data."""
+  assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
+  if FLAGS.tokenizer_impl == "word_piece":
+    tokenizer = tokenization.FullTokenizer(
+        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+    processor_text_fn = tokenization.convert_to_unicode
+  else:
+    assert FLAGS.tokenizer_impl == "sentence_piece"
+    tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+    processor_text_fn = functools.partial(
+        tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+  processors = {
+      "bucc": sentence_retrieval_lib.BuccProcessor,
+      "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
+  }
+
+  task_name = FLAGS.retrieval_task_name.lower()
+  if task_name not in processors:
+    raise ValueError("Task not found: %s" % task_name)
+
+  processor = processors[task_name](process_text_fn=processor_text_fn)
+
+  return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
+      processor, FLAGS.input_data_dir, tokenizer,
+      FLAGS.eval_data_output_path, FLAGS.test_data_output_path,
+      FLAGS.max_seq_length)


 def main(_):
   if FLAGS.tokenizer_impl == "word_piece":
     if not FLAGS.vocab_file:
@@ -257,10 +294,15 @@ def main(_):
       raise ValueError(
           "FLAG sp_model_file for sentence-piece tokenizer is not specified.")

+  if FLAGS.fine_tuning_task_type != "retrieval":
+    flags.mark_flag_as_required("train_data_output_path")
+
   if FLAGS.fine_tuning_task_type == "classification":
     input_meta_data = generate_classifier_dataset()
   elif FLAGS.fine_tuning_task_type == "regression":
     input_meta_data = generate_regression_dataset()
+  elif FLAGS.fine_tuning_task_type == "retrieval":
+    input_meta_data = generate_retrieval_dataset()
   else:
     input_meta_data = generate_squad_dataset()
@@ -270,6 +312,5 @@ def main(_):
 if __name__ == "__main__":
-  flags.mark_flag_as_required("train_data_output_path")
   flags.mark_flag_as_required("meta_data_file_path")
   app.run(main)
```
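A hedged sketch of exercising the new task type end to end; every path below is a hypothetical placeholder, while the flag names all appear in this script:

```python
# Sketch only: parses flags programmatically via absl's FlagValues.__call__;
# all paths are hypothetical placeholders.
from official.nlp.data import create_finetuning_data

FLAGS = create_finetuning_data.FLAGS
FLAGS([
    "create_finetuning_data",
    "--fine_tuning_task_type=retrieval",
    "--retrieval_task_name=bucc",
    "--tokenizer_impl=word_piece",
    "--input_data_dir=/tmp/bucc",              # hypothetical
    "--vocab_file=/tmp/vocab.txt",             # hypothetical
    "--eval_data_output_path=/tmp/bucc_eval",  # hypothetical
    "--test_data_output_path=/tmp/bucc_test",  # hypothetical
])

input_meta_data = create_finetuning_data.generate_retrieval_dataset()
print(input_meta_data["processor_type"], input_meta_data["max_seq_length"])
```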
official/nlp/data/sentence_retrieval_lib.py (new file, 0 → 100644)

```python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for cross lingual sentence retrieval task."""

import os

from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib


class BuccProcessor(classifier_data_lib.DataProcessor):
  """Processor for Xtreme BUCC data set."""
  supported_languages = ["de", "fr", "ru", "zh"]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(BuccProcessor, self).__init__(process_text_fn)
    self.languages = BuccProcessor.supported_languages

  def get_dev_examples(self, data_dir, file_pattern):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
        "sample")

  def get_test_examples(self, data_dir, file_pattern):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
        "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "BUCC"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      int_iden = int(line[0].split("-")[1])
      text_a = self.process_text_fn(line[1])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, int_iden=int_iden))
    return examples


class TatoebaProcessor(classifier_data_lib.DataProcessor):
  """Processor for Xtreme Tatoeba data set."""
  supported_languages = [
      "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
      "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
      "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
  ]

  def __init__(self, process_text_fn=tokenization.convert_to_unicode):
    super(TatoebaProcessor, self).__init__(process_text_fn)
    self.languages = TatoebaProcessor.supported_languages

  def get_test_examples(self, data_dir, file_path):
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, file_path)), "test")

  @staticmethod
  def get_processor_name():
    """See base class."""
    return "TATOEBA"

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      guid = "%s-%s" % (set_type, i)
      text_a = self.process_text_fn(line[0])
      examples.append(
          classifier_data_lib.InputExample(
              guid=guid, text_a=text_a, int_iden=i))
    return examples


def generate_sentence_retrevial_tf_record(processor,
                                          data_dir,
                                          tokenizer,
                                          eval_data_output_path=None,
                                          test_data_output_path=None,
                                          max_seq_length=128):
  """Generates the tf records for retrieval tasks.

  Args:
    processor: Input processor object to be used for generating data. Subclass
      of `DataProcessor`.
    data_dir: Directory that contains train/eval data to process. Data files
      should be in from.
    tokenizer: The tokenizer to be applied on the data.
    eval_data_output_path: Output to which processed tf record for evaluation
      will be saved.
    test_data_output_path: Output to which processed tf record for testing
      will be saved. Must be a pattern template with {} if processor has
      language specific test data.
    max_seq_length: Maximum sequence length of the to be generated
      training/eval data.

  Returns:
    A dictionary containing input meta data.
  """
  assert eval_data_output_path or test_data_output_path

  if processor.get_processor_name() == "BUCC":
    path_pattern = "{}-en.{{}}.{}"

  if processor.get_processor_name() == "TATOEBA":
    path_pattern = "{}-en.{}"

  meta_data = {
      "processor_type": processor.get_processor_name(),
      "max_seq_length": max_seq_length,
      "number_eval_data": {},
      "number_test_data": {},
  }
  logging.info("Start to process %s task data", processor.get_processor_name())

  for lang_a in processor.languages:
    for lang_b in [lang_a, "en"]:
      if eval_data_output_path:
        eval_input_data_examples = processor.get_dev_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_eval_data = len(eval_input_data_examples)
        logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            eval_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
        classifier_data_lib.file_based_convert_examples_to_features(
            eval_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data

      if test_data_output_path:
        test_input_data_examples = processor.get_test_examples(
            data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))

        num_test_data = len(test_input_data_examples)
        logging.info("Processing %d test examples of %s-en.%s", num_test_data,
                     lang_a, lang_b)
        output_file = os.path.join(
            test_data_output_path,
            "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
        classifier_data_lib.file_based_convert_examples_to_features(
            test_input_data_examples, None, max_seq_length, tokenizer,
            output_file, None)
        meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data

  return meta_data
```
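For intuition about the inputs, `BuccProcessor._create_examples` implies a two-column TSV whose first column is an id of the form `<lang>-<number>`; the rows below are made up for illustration:

```python
# Made-up rows illustrating the format BuccProcessor._create_examples expects;
# real BUCC files are distributed with the XTREME benchmark.
lines = [
    ["de-000001", "Das ist ein Beispielsatz."],
    ["de-000002", "Noch ein Satz."],
]

for i, line in enumerate(lines):
  int_iden = int(line[0].split("-")[1])  # -> 1, then 2
  text_a = line[1]
  print("guid=sample-%s int_iden=%d text_a=%r" % (i, int_iden, text_a))
```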