ModelZoo / ResNet50_tensorflow · Commits

Commit e8b6955e
Authored Nov 18, 2020 by Allen Wang
Committed by A. Unique TensorFlower, Nov 18, 2020
Refactor the XLNet pretrain data generation script(s).

PiperOrigin-RevId: 343108621
Parent: 58edfb5c
Showing 4 changed files with 1083 additions and 40 deletions (+1083, -40):
official/nlp/data/create_xlnet_pretraining_data.py       +717  -0
official/nlp/data/create_xlnet_pretraining_data_test.py  +356  -0
official/nlp/data/pretrain_dataloader.py                 +7    -36
official/nlp/data/pretrain_dataloader_test.py            +3    -4

official/nlp/data/create_xlnet_pretraining_data.py (new file, mode 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create LM TF examples for XLNet."""
import json
import math
import os
import random
from typing import Iterable, Mapping, List, Optional, Tuple
import unicodedata

# Import libraries

from absl import app
from absl import flags
from absl import logging
import dataclasses
import numpy as np
import tensorflow as tf

from official.nlp.bert import tokenization

special_symbols = {
    "<unk>": 0,
    "<s>": 1,
    "</s>": 2,
    "<cls>": 3,
    "<sep>": 4,
    "<pad>": 5,
    "<mask>": 6,
    "<eod>": 7,
    "<eop>": 8,
}

FLAGS = flags.FLAGS

flags.DEFINE_integer("seq_length", 512,
                     help="Sequence length.")
flags.DEFINE_integer("reuse_length", 256,
                     help="Number of token that can be reused as memory. "
                     "Could be half of `seq_len`.")
flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")
flags.DEFINE_string("save_dir", None,
                    "Directory for saving processed data.")
flags.DEFINE_string("sp_model_file", "",
                    "The path to the model used by sentence piece tokenizer.")
flags.DEFINE_bool("use_eod_token", True,
                  "Whether or not to include EOD tokens.")
flags.DEFINE_bool("bi_data", True,
                  "Whether or not to use bi-directional data.")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
flags.DEFINE_integer("num_cores_per_host", 16,
                     "The number of (TPU) cores per host.")
flags.DEFINE_string("prefix", "", "Filename prefix.")
flags.DEFINE_string("suffix", "", "Filename suffix.")

flags.DEFINE_integer("task_id", None,
                     "The id of the current task.")
flags.DEFINE_integer("num_tasks", None,
                     "The total number of tasks.")
flags.DEFINE_integer("num_passes", 1,
                     "The number of times to run the script.")


@dataclasses.dataclass
class TrainingInstance:
  """Representation of a single XLNet Pretraining instance."""
  data: Iterable[int]
  segment_ids: Iterable[int]
  boundary_indices: Iterable[int]
  label: int

  def to_feature(self) -> Mapping[str, tf.train.Feature]:
    feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
    return dict(
        input_word_ids=feat(self.data),
        input_type_ids=feat(self.segment_ids),
        boundary_indices=feat(self.boundary_indices),
        label=feat([self.label]))

  def to_example(self) -> tf.train.Example:
    return tf.train.Example(
        features=tf.train.Features(feature=self.to_feature()))

  def __str__(self):
    def seq_to_str(seq):
      return " ".join([str(x) for x in seq])

    s = ""
    s += "tokens: %s\n" % seq_to_str(self.data)
    s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
    s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
    s += "label: %s\n" % self.label
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()


def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
  """Preprocesses an individual raw text line.

  This function will:
    - Remove extraneous spaces.
    - Replace `` with ", and '' with ".
    - Replaces accents.
    - Applies lower casing.

  Args:
    line: The input line to preprocess.
    do_lower_case: Whether or not to lower case the text.

  Returns:
    The preprocessed line.
  """
  line = " ".join(line.split())
  line = line.replace("``", "\"").replace("''", "\"")

  # Replace accents.
  line = unicodedata.normalize("NFKD", line)
  line = "".join([c for c in line if not unicodedata.combining(c)])

  if do_lower_case:
    line = line.lower()
  return line


def preprocess_and_tokenize_input_files(
    input_files: Iterable[str],
    tokenizer: tokenization.FullSentencePieceTokenizer,
    use_eod: bool = True,
    do_lower_case: bool = False,
    log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
  """Preprocesses and encodes raw text from input files.

  This function preprocesses raw text and encodes them into tokens using a
  `SentencePieceModel` tokenization method. This also provides the sentence
  indicator for each token.

  Args:
    input_files: The list of input file names.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
      not included.
    do_lower_case: Whether or not to apply lower casing during raw text
      preprocessing.
    log_example_freq: The optional field for how many lines to process before
      emitting an info log.

  Returns:
    The preprocessed list. Each entry in the list is a tuple consisting of
    the token IDs and the sentence IDs.
  """
  all_data = []
  eod_symbol = special_symbols["<eod>"]

  total_number_of_lines = 0

  # Input file format:
  # (1) One sentence per line. These should ideally be actual sentences, not
  # entire paragraphs or arbitrary spans of text. (Because we use the
  # sentence boundaries for the "next sentence prediction" task).
  # (2) Blank lines between documents. Document boundaries are needed so
  # that the "next sentence prediction" task doesn't span between documents.
  for input_file in input_files:
    line_count = 0
    logging.info("Preprocessing %s", input_file)

    all_tokens = []
    all_sentence_ids = []

    sentence_id = True

    with tf.io.gfile.GFile(input_file, "rb") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break

        line_count += 1
        if line_count % log_example_freq == 0:
          logging.info("Loading line %d", line_count)

        line = line.strip()

        if not line:
          if use_eod:
            token_ids = [eod_symbol]
            sentence_id = not sentence_id
          else:
            continue
        else:
          preprocessed_line = _preprocess_line(
              line=line, do_lower_case=do_lower_case)
          token_ids = tokenization.encode_ids(
              sp_model=tokenizer.sp_model, text=preprocessed_line)

        all_tokens.extend(token_ids)
        all_sentence_ids.extend([sentence_id] * len(token_ids))
        sentence_id = not sentence_id

    logging.info("Finished processing %s. Number of lines: %d",
                 input_file, line_count)
    if line_count == 0:
      continue

    total_number_of_lines += line_count

    all_tokens = np.array(all_tokens, dtype=np.int64)
    all_sentence_ids = np.array(all_sentence_ids, dtype=np.bool)

    all_data.append((all_tokens, all_sentence_ids))

  logging.info("Completed text preprocessing. Total number of lines: %d",
               total_number_of_lines)
  return all_data


def _reshape_to_batch_dimensions(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int) -> Tuple[np.array, np.array]:
  """Truncates and reshapes input data with a batch major dimension.

  Args:
    tokens: The input token ids. This should have the same shape as
      `sentence_ids`.
    sentence_ids: The input sentence ids. This should have the same shape as
      `token_ids`.
    per_host_batch_size: The target per-host batch size.

  Returns:
    The tuple of reshaped tokens and sentence_ids.
  """
  num_steps = len(tokens) // per_host_batch_size
  truncated_data_length = num_steps * per_host_batch_size

  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("num_steps: %d", num_steps)

  def truncate_and_reshape(a):
    return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))

  return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))


def _create_a_and_b_segments(
    tokens: np.array,
    sentence_ids: np.array,
    begin_index: int,
    total_length: int,
    no_cut_probability: float = 0.5):
  """Splits segments A and B from a single instance of tokens and sentence ids.

  Args:
    tokens: The 1D input token ids. This represents an individual entry within a
      batch.
    sentence_ids: The 1D input sentence ids. This represents an indivdual entry
      within a batch. This should be the same length as `tokens`.
    begin_index: The reference beginning index to split data.
    total_length: The target combined length of segments A and B.
    no_cut_probability: The probability of not cutting a segment despite
      a cut possibly existing.

  Returns:
    A tuple consisting of A data, B data, and label.
  """
  data_length = tokens.shape[0]
  if begin_index + total_length >= data_length:
    logging.info("[_create_segments]: begin_index %d + total_length %d >= "
                 "data_length %d", begin_index, total_length, data_length)
    return None

  end_index = begin_index + 1
  cut_indices = []

  # Identify all indices where sentence IDs change from one to the next.
  while end_index < data_length:
    if sentence_ids[end_index] != sentence_ids[end_index - 1]:
      if end_index - begin_index >= total_length:
        break
      cut_indices.append(end_index)
    end_index += 1

  a_begin = begin_index

  if not cut_indices or random.random() < no_cut_probability:
    # Segments A and B are contained within the same sentence.
    label = 0
    if not cut_indices:
      a_end = end_index
    else:
      a_end = random.choice(cut_indices)
    b_length = max(1, total_length - (a_end - a_begin))
    b_begin = random.randint(0, data_length - 1 - b_length)
    b_end = b_begin + b_length

    while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
      b_begin -= 1
    while (b_end < data_length - 1 and
           sentence_ids[b_end - 1] == sentence_ids[b_end]):
      b_end += 1
  else:
    # Segments A and B are different sentences.
    label = 1
    a_end = random.choice(cut_indices)
    b_begin = a_end
    b_end = end_index

  while a_end - a_begin + b_end - b_begin > total_length:
    if a_end - a_begin > b_end - b_begin:
      # Delete only the right side for the LM objective.
      a_end -= 1
    else:
      b_end -= 1

  if a_end >= data_length or b_end >= data_length:
    logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
                 a_end, b_end, data_length)
    return None

  a_data = tokens[a_begin:a_end]
  b_data = tokens[b_begin:b_end]
  return a_data, b_data, label


def _is_functional_piece(piece: str) -> bool:
  return piece != "<unk>" and piece.startswith("<") and piece.endswith(">")


def _is_start_piece(piece: str) -> bool:
  special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
  if (piece.startswith("▁") or piece in special_pieces):
    return True
  else:
    return False


def _get_boundary_indices(
    data: np.array,
    tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
  """Gets the boundary indices of whole words."""
  seq_length = len(data)
  boundary_indices = []
  for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
    if _is_start_piece(piece) and not _is_functional_piece(piece):
      boundary_indices.append(index)
  boundary_indices.append(seq_length)
  return boundary_indices


def _convert_tokens_to_instances(
    tokens: np.array,
    sentence_ids: np.array,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    tokenizer: tokenization.FullSentencePieceTokenizer,
    num_cores_per_host: int = 0,
    logging_frequency: int = 500) -> List[TrainingInstance]:
  """Converts tokens and sentence IDs into individual training instances.

  The format of data in the XLNet pretraining task is very similar to the
  BERT pretraining task. Two segments A and B are randomly sampled, and the
  contatenation of A and B into a single sequence is used to perform
  language modeling.

  To create an XLNet Pretraining instance from a single long sequence, S:
  - Create a segment of length `reuse_length`. This first segment represents
    past tokens. During modeling, this segment is used to cache obtained
    content representations for the segment recurrence mechanism.
  - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
    composed of A and B segments.
    For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".

  Args:
    tokens: All tokens concatenated into a single list.
    sentence_ids: All sentence IDs concatenated into a single list.
    per_host_batch_size: The target batch size per host.
    seq_length: The max sequence length.
    reuse_length: The number of tokens to use from the previous segment.
    bi_data: Whether or not to use bidirectional data.
    tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
    num_cores_per_host: The number of cores per host. This is required if
      `bi_data` = `True`.
    logging_frequency: The frequency at which to log status updates.

  Returns:
    A list of `TrainingInstance` objects.
  """
  instances = []

  per_core_batch_size = (per_host_batch_size // num_cores_per_host
                         if bi_data else None)

  if bi_data:
    logging.info("Bi-directional data enabled.")
    assert per_host_batch_size % (2 * num_cores_per_host) == 0
    forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size // 2)
    forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)

    forward_tokens = forward_tokens.reshape(forward_data_shape)
    forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)

    backwards_tokens = forward_tokens[:, :, :, ::-1]
    backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]

    tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
        per_host_batch_size, -1)
    sentence_ids = np.concatenate(
        [forward_sentence_ids, backwards_sentence_ids]).reshape(
            per_host_batch_size, -1)
  else:
    logging.info("Bi-directional data disabled.")
    tokens, sentence_ids = _reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size)

  logging.info("Tokens shape: %s", tokens.shape)

  data_length = tokens.shape[1]
  sep = np.array([special_symbols["<sep>"]], dtype=np.int64)
  cls = np.array([special_symbols["<cls>"]], dtype=np.int64)
  # 2 sep, 1 cls
  num_special_tokens = 3

  data_index = 0
  batch_number = 0
  step_size = reuse_length if reuse_length else seq_length
  num_batches = math.ceil(data_length / step_size)

  while data_index + seq_length <= data_length:
    if batch_number % logging_frequency == 0:
      logging.info("Processing batch %d of %d", batch_number, num_batches)

    for batch_index in range(per_host_batch_size):
      previous_segment_tokens = tokens[
          batch_index, data_index: data_index + reuse_length]

      results = _create_a_and_b_segments(
          tokens=tokens[batch_index],
          sentence_ids=sentence_ids[batch_index],
          begin_index=data_index + reuse_length,
          total_length=seq_length - reuse_length - num_special_tokens)

      if results is None:
        logging.info("Stopping at data index: %d", data_index)
        break
      a_data, b_data, label = results

      data = np.concatenate(
          [previous_segment_tokens, a_data, sep, b_data, sep, cls])
      a_length = a_data.shape[0]
      b_length = b_data.shape[0]
      segment_ids = ([0] * (reuse_length + a_length) + [0] +
                     [1] * b_length + [1] + [2])
      boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
                                               data=data)
      instances.append(TrainingInstance(
          data=data,
          segment_ids=segment_ids,
          boundary_indices=boundary_indices,
          label=label))
    batch_number += 1
    data_index += step_size
  return instances


def write_instances_to_tfrecord(instances: Iterable[TrainingInstance],
                                save_path: str):
  """Writes instances to TFRecord."""
  record_writer = tf.io.TFRecordWriter(save_path)
  logging.info("Start writing to %s.", save_path)

  for i, instance in enumerate(instances):
    if i < 5:
      logging.info("Instance %d: %s", i, str(instance))
    record_writer.write(instance.to_example().SerializeToString())

  record_writer.close()
  logging.info("Done writing %s.", save_path)


def shuffle_and_combine_preprocessed_data(
    all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
  """Shuffles and combines preprocessed token/sentence IDs from documents."""
  document_permutation = np.random.permutation(len(all_data))

  previous_sentence_id = None

  all_tokens, all_sentence_ids = [], []
  for document_index in document_permutation:
    tokens, sentence_ids = all_data[document_index]
    # pylint: disable=g-explicit-length-test
    if len(tokens) == 0:
      continue
    if (previous_sentence_id is not None and
        sentence_ids[0] == previous_sentence_id):
      sentence_ids = np.logical_not(sentence_ids)

    all_tokens.append(tokens)
    all_sentence_ids.append(sentence_ids)

    previous_sentence_id = sentence_ids[-1]

  return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)


def get_tfrecord_name(
    per_host_batch_size: int,
    num_cores_per_host: int,
    seq_length: int,
    bi_data: bool,
    reuse_length: int,
    do_lower_case: bool,
    use_eod_token: bool,
    prefix: str = "",
    suffix: str = "",
    pass_id: int = 0,
    num_passes: int = 1,
    task_id: int = None,
    num_tasks: int = None) -> str:
  """Formats the resulting TFRecord name based on provided inputs."""
  components = []
  if prefix:
    components.append(prefix)
  components.append("seqlen-{}".format(seq_length))
  if reuse_length == 0:
    components.append("memless")
  else:
    components.append("reuse-{}".format(reuse_length))
  components.append("bs-{}".format(per_host_batch_size))
  components.append("cores-{}".format(num_cores_per_host))

  if do_lower_case:
    components.append("uncased")
  else:
    components.append("cased")
  if use_eod_token:
    components.append("eod")
  if bi_data:
    components.append("bi")
  else:
    components.append("uni")

  if suffix:
    components.append(suffix)

  s = "_".join(components) + ".tfrecord"
  if num_passes == 1 and task_id is None:
    return s

  if task_id is None:
    num_tasks = 1
    task_id = 0

  current_shard = task_id * num_passes + pass_id
  total_shards = num_tasks * num_passes
  return s + "-{}-of-{}".format(current_shard, total_shards)


def create_tfrecords(
    tokenizer: tokenization.FullSentencePieceTokenizer,
    input_file_or_files: str,
    use_eod_token: bool,
    do_lower_case: bool,
    per_host_batch_size: int,
    seq_length: int,
    reuse_length: int,
    bi_data: bool,
    num_cores_per_host: int,
    save_dir: str,
    prefix: str = "",
    suffix: str = "",
    num_tasks: Optional[int] = None,
    task_id: Optional[int] = None,
    num_passes: int = 1):
  """Runs the end-to-end preprocessing pipeline."""

  logging.info("Input configuration:")
  logging.info("input file(s): %s", input_file_or_files)
  logging.info("use_eod_token: %s", use_eod_token)
  logging.info("do_lower_case: %s", do_lower_case)
  logging.info("per_host_batch_size: %d", per_host_batch_size)
  logging.info("seq_length: %d", seq_length)
  logging.info("reuse_length: %d", reuse_length)
  logging.info("bi_data: %s", bi_data)
  logging.info("num_cores_per_host: %d", num_cores_per_host)
  logging.info("save_dir: %s", save_dir)
  if task_id is not None and num_tasks is not None:
    logging.info("task_id: %d", task_id)
    logging.info("num_tasks: %d", num_tasks)

  input_files = []
  for input_pattern in input_file_or_files.split(","):
    input_files.extend(tf.io.gfile.glob(input_pattern))

  logging.info("*** Reading from input files ***")
  for input_file in input_files:
    logging.info("  %s", input_file)

  logging.info("Shuffling the files with a fixed random seed.")
  np.random.shuffle(input_files)
  if num_tasks is not None:
    assert task_id is not None
    logging.info("Total number of input files: %d", len(input_files))
    logging.info("Splitting into %d shards of %d files each.",
                 num_tasks, len(input_files) // num_tasks)
    input_files = input_files[task_id::num_tasks]

  all_data = preprocess_and_tokenize_input_files(
      input_files=input_files,
      tokenizer=tokenizer,
      use_eod=use_eod_token,
      do_lower_case=do_lower_case)
  for pass_id in range(num_passes):
    logging.info("Beginning pass %d of %d", pass_id, num_passes)
    tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)

    assert len(tokens) == len(sentence_ids)

    filename = get_tfrecord_name(
        per_host_batch_size=per_host_batch_size,
        num_cores_per_host=num_cores_per_host,
        seq_length=seq_length,
        bi_data=bi_data,
        use_eod_token=use_eod_token,
        reuse_length=reuse_length,
        do_lower_case=do_lower_case,
        prefix=prefix,
        suffix=suffix,
        pass_id=pass_id,
        num_passes=num_passes,
        num_tasks=num_tasks,
        task_id=task_id)
    save_path = os.path.join(save_dir, filename)
    if os.path.exists(save_path):
      # If the path already exists, then we were probably preempted but
      # previously wrote this file.
      logging.info("%s already exists, skipping this batch.", save_path)
    else:
      instances = _convert_tokens_to_instances(
          tokenizer=tokenizer,
          tokens=tokens,
          sentence_ids=sentence_ids,
          per_host_batch_size=per_host_batch_size,
          seq_length=seq_length,
          reuse_length=reuse_length,
          bi_data=bi_data,
          num_cores_per_host=num_cores_per_host)
      write_instances_to_tfrecord(instances=instances, save_path=save_path)

  if task_id is None or task_id == 0:
    corpus_info = {
        "vocab_size": 32000,
        "per_host_batch_size": per_host_batch_size,
        "num_cores_per_host": num_cores_per_host,
        "seq_length": seq_length,
        "reuse_length": reuse_length,
        "do_lower_case": do_lower_case,
        "bi_data": bi_data,
        "use_eod_token": use_eod_token,
    }
    corpus_fname = os.path.basename(filename) + ".json"
    corpus_destination = os.path.join(save_dir, corpus_fname)
    logging.info("Saving corpus info to %s", corpus_destination)

    with tf.io.gfile.GFile(corpus_destination, "w") as fp:
      json.dump(corpus_info, fp)


def main(_):
  tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
  create_tfrecords(
      tokenizer=tokenizer,
      input_file_or_files=FLAGS.input_file,
      use_eod_token=FLAGS.use_eod_token,
      do_lower_case=FLAGS.do_lower_case,
      per_host_batch_size=FLAGS.per_host_batch_size,
      seq_length=FLAGS.seq_length,
      reuse_length=FLAGS.reuse_length,
      bi_data=FLAGS.bi_data,
      num_cores_per_host=FLAGS.num_cores_per_host,
      save_dir=FLAGS.save_dir,
      prefix=FLAGS.prefix,
      suffix=FLAGS.suffix,
      num_tasks=FLAGS.num_tasks,
      task_id=FLAGS.task_id,
      num_passes=FLAGS.num_passes)


if __name__ == "__main__":
  np.random.seed(0)
  logging.set_verbosity(logging.INFO)
  app.run(main)
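
For orientation, the snippet below sketches how the new script could be driven programmatically, mirroring what `main()` does with the flags defined above. It is illustrative only: the SentencePiece model path, input glob, and output directory are placeholders, not values from this commit.

# Illustrative sketch only. Equivalent to running the script via absl flags,
# e.g. --input_file, --sp_model_file and --save_dir; all paths are hypothetical.
from official.nlp.bert import tokenization
from official.nlp.data import create_xlnet_pretraining_data as cpd

tokenizer = tokenization.FullSentencePieceTokenizer("/path/to/spiece.model")
cpd.create_tfrecords(
    tokenizer=tokenizer,
    input_file_or_files="/path/to/corpus-*.txt",  # comma-separated globs allowed
    use_eod_token=True,
    do_lower_case=True,
    per_host_batch_size=32,  # with bi_data, must be divisible by 2 * num_cores_per_host
    seq_length=512,
    reuse_length=256,
    bi_data=True,
    num_cores_per_host=16,
    save_dir="/path/to/output")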

official/nlp/data/create_xlnet_pretraining_data_test.py (new file, mode 100644)

# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for official.nlp.data.create_xlnet_pretraining_data."""
import os
import tempfile
from typing import List

from absl import logging
from absl.testing import parameterized

import numpy as np
import tensorflow as tf

from official.nlp.data import create_xlnet_pretraining_data as cpd

_VOCAB_WORDS = ["vocab_1", "vocab_2"]


# pylint: disable=invalid-name
def _create_files(
    temp_dir: str, file_contents: List[List[str]]) -> List[str]:
  """Writes arbitrary documents into files."""
  root_dir = tempfile.mkdtemp(dir=temp_dir)
  files = []

  for i, file_content in enumerate(file_contents):
    destination = os.path.join(root_dir, "%d.txt" % i)
    with open(destination, "wb") as f:
      for line in file_content:
        f.write(line.encode("utf-8"))
    files.append(destination)
  return files


def _get_mock_tokenizer():
  """Creates a mock tokenizer."""

  class MockSpieceModel:
    """Mock Spiece model for testing."""

    def __init__(self):
      self._special_piece_to_id = {
          "<unk>": 0,
      }
      for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
        self._special_piece_to_id[piece] = 1

    def EncodeAsPieces(self, inputs: str) -> List[str]:
      return inputs

    def SampleEncodeAsPieces(self,
                             inputs: str,
                             nbest_size: int,
                             theta: float) -> List[str]:
      del nbest_size, theta
      return inputs

    def PieceToId(self, piece: str) -> int:
      return ord(piece[0])

    def IdToPiece(self, id_: int) -> str:
      return chr(id_) * 3

  class Tokenizer:
    """Mock Tokenizer for testing."""

    def __init__(self):
      self.sp_model = MockSpieceModel()

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
      return [self.sp_model.IdToPiece(id_) for id_ in ids]

  return Tokenizer()


class PreprocessDataTest(tf.test.TestCase):

  def test_remove_extraneous_space(self):
    line = " abc "
    output = cpd._preprocess_line(line)
    self.assertEqual(output, "abc")

  def test_symbol_replacements(self):
    self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
    self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")

  def test_accent_replacements(self):
    self.assertEqual(cpd._preprocess_line("åbc"), "abc")

  def test_lower_case(self):
    self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")

  def test_end_to_end(self):
    self.assertEqual(
        cpd._preprocess_line("HelLo ``wórLd``", do_lower_case=True),
        "hello \"world\"")


class PreprocessAndTokenizeFilesTest(tf.test.TestCase):

  def test_basic_end_to_end(self):
    documents = [
        [
            "This is sentence 1.\n",
            "This is sentence 2.\n",
            "Sentence 3 is what this is.\n",
        ],
        [
            "This is the second document.\n",
            "This is the second line of the second document.\n",
        ],
    ]
    input_files = _create_files(temp_dir=self.get_temp_dir(),
                                file_contents=documents)
    all_data = cpd.preprocess_and_tokenize_input_files(
        input_files=input_files,
        tokenizer=_get_mock_tokenizer(),
        log_example_freq=1)

    self.assertEqual(len(all_data), len(documents))
    for token_ids, sentence_ids in all_data:
      self.assertEqual(len(token_ids), len(sentence_ids))

  def test_basic_correctness(self):
    documents = [["a\n", "b\n", "c\n"]]
    input_files = _create_files(temp_dir=self.get_temp_dir(),
                                file_contents=documents)
    all_data = cpd.preprocess_and_tokenize_input_files(
        input_files=input_files,
        tokenizer=_get_mock_tokenizer(),
        log_example_freq=1)

    token_ids, sentence_ids = all_data[0]

    self.assertAllClose(token_ids, [97, 98, 99])
    self.assertAllClose(sentence_ids, [True, False, True])

  def test_correctness_with_spaces_and_accents(self):
    documents = [[
        " å\n",
        "b\n",
        " c\n",
    ]]
    input_files = _create_files(temp_dir=self.get_temp_dir(),
                                file_contents=documents)
    all_data = cpd.preprocess_and_tokenize_input_files(
        input_files=input_files,
        tokenizer=_get_mock_tokenizer(),
        log_example_freq=1)

    token_ids, sentence_ids = all_data[0]

    self.assertAllClose(token_ids, [97, 98, 99])
    self.assertAllClose(sentence_ids, [True, False, True])


class BatchReshapeTests(tf.test.TestCase):

  def test_basic_functionality(self):
    per_host_batch_size = 3
    mock_shape = (20,)

    # Should truncate and reshape.
    expected_result_shape = (3, 6)

    tokens = np.zeros(mock_shape)
    sentence_ids = np.zeros(mock_shape)

    reshaped_data = cpd._reshape_to_batch_dimensions(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=per_host_batch_size)
    for values in reshaped_data:
      self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
      self.assertAllClose(values.shape, expected_result_shape)


class CreateSegmentsTest(tf.test.TestCase):

  def test_basic_functionality(self):
    data_length = 10
    tokens = np.arange(data_length)
    sentence_ids = np.concatenate([np.zeros(data_length // 2),
                                   np.ones(data_length // 2)])
    begin_index = 0
    total_length = 8
    a_data, b_data, label = cpd._create_a_and_b_segments(
        tokens=tokens,
        sentence_ids=sentence_ids,
        begin_index=begin_index,
        total_length=total_length,
        no_cut_probability=0.)
    self.assertAllClose(a_data, [0, 1, 2, 3])
    self.assertAllClose(b_data, [5, 6, 7, 8])
    self.assertEqual(label, 1)

  def test_no_cut(self):
    data_length = 10
    tokens = np.arange(data_length)
    sentence_ids = np.zeros(data_length)

    begin_index = 0
    total_length = 8
    a_data, b_data, label = cpd._create_a_and_b_segments(
        tokens=tokens,
        sentence_ids=sentence_ids,
        begin_index=begin_index,
        total_length=total_length,
        no_cut_probability=0.)
    self.assertGreater(len(a_data), 0)
    self.assertGreater(len(b_data), 0)
    self.assertEqual(label, 0)

  def test_no_cut_with_probability(self):
    data_length = 10
    tokens = np.arange(data_length)
    sentence_ids = np.concatenate([np.zeros(data_length // 2),
                                   np.ones(data_length // 2)])
    begin_index = 0
    total_length = 8
    a_data, b_data, label = cpd._create_a_and_b_segments(
        tokens=tokens,
        sentence_ids=sentence_ids,
        begin_index=begin_index,
        total_length=total_length,
        no_cut_probability=1.)
    self.assertGreater(len(a_data), 0)
    self.assertGreater(len(b_data), 0)
    self.assertEqual(label, 0)


class CreateInstancesTest(tf.test.TestCase):
  """Tests conversions of Token/Sentence IDs to training instances."""

  def test_basic(self):
    data_length = 12
    tokens = np.arange(data_length)
    sentence_ids = np.zeros(data_length)
    seq_length = 8
    instances = cpd._convert_tokens_to_instances(
        tokens=tokens,
        sentence_ids=sentence_ids,
        per_host_batch_size=2,
        seq_length=seq_length,
        reuse_length=4,
        tokenizer=_get_mock_tokenizer(),
        bi_data=False,
        num_cores_per_host=1,
        logging_frequency=1)
    for instance in instances:
      self.assertEqual(len(instance.data), seq_length)
      self.assertEqual(len(instance.segment_ids), seq_length)
      self.assertIsInstance(instance.label, int)
      self.assertIsInstance(instance.boundary_indices, list)


class TFRecordPathTests(tf.test.TestCase):

  def test_basic(self):
    base_kwargs = dict(
        per_host_batch_size=1,
        num_cores_per_host=1,
        seq_length=2,
        reuse_length=1)

    config1 = dict(
        prefix="test",
        suffix="",
        bi_data=True,
        use_eod_token=False,
        do_lower_case=True)
    config1.update(base_kwargs)
    expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
    self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)

    config2 = dict(
        prefix="",
        suffix="test",
        bi_data=False,
        use_eod_token=False,
        do_lower_case=False)
    config2.update(base_kwargs)
    expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
    self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)

    config3 = dict(
        prefix="",
        suffix="",
        use_eod_token=True,
        bi_data=False,
        do_lower_case=True)
    config3.update(base_kwargs)
    expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
    self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)


class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):

  @parameterized.named_parameters(
      ("bi_data_only", True, False, False),
      ("eod_token_only", False, True, True),
      ("lower_case_only", False, False, True),
      ("all_enabled", True, True, True),
  )
  def test_end_to_end(self,
                      bi_data: bool,
                      use_eod_token: bool,
                      do_lower_case: bool):
    tokenizer = _get_mock_tokenizer()

    num_documents = 5
    sentences_per_document = 10
    document_length = 50

    documents = [
        ["a " * document_length for _ in range(sentences_per_document)]
        for _ in range(num_documents)]

    save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    files = _create_files(temp_dir=self.get_temp_dir(),
                          file_contents=documents)

    cpd.create_tfrecords(
        tokenizer=tokenizer,
        input_file_or_files=",".join(files),
        use_eod_token=use_eod_token,
        do_lower_case=do_lower_case,
        per_host_batch_size=8,
        seq_length=8,
        reuse_length=4,
        bi_data=bi_data,
        num_cores_per_host=2,
        save_dir=save_dir)

    self.assertTrue(any(filter(lambda x: x.endswith(".json"),
                               os.listdir(save_dir))))
    self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
                               os.listdir(save_dir))))


if __name__ == "__main__":
  np.random.seed(0)
  logging.set_verbosity(logging.INFO)
  tf.test.main()
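
As a reading aid, the sketch below shows one way the records written by `write_instances_to_tfrecord` could be parsed back, using the feature keys defined in `TrainingInstance.to_feature` (`input_word_ids`, `input_type_ids`, `boundary_indices`, `label`). It is a sketch only: the record path and the `seq_length` value are placeholders, not part of this commit.

# Illustrative sketch only; the path and seq_length below are placeholders.
import tensorflow as tf

seq_length = 512
name_to_features = {
    "input_word_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    "input_type_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
    "boundary_indices": tf.io.VarLenFeature(tf.int64),
    "label": tf.io.FixedLenFeature([1], tf.int64),
}

dataset = tf.data.TFRecordDataset("/path/to/save_dir/example.tfrecord")
for raw_record in dataset.take(1):
  example = tf.io.parse_single_example(raw_record, name_to_features)
  # boundary_indices has variable length, so it parses as a SparseTensor.
  boundaries = tf.sparse.to_dense(example["boundary_indices"])
  print(example["input_word_ids"].shape, boundaries.numpy(),
        example["label"].numpy())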

official/nlp/data/pretrain_dataloader.py (modified)

@@ -143,8 +143,7 @@ class XLNetPretrainDataConfig(cfg.DataConfig):
     reuse_length: The number of tokens in a previous segment to reuse. This
       should be the same value used during pretrain data creation.
     sample_strategy: The strategy used to sample factorization permutations.
-      Possible values: 'fixed', 'single_token', 'whole_word', 'token_span',
-      'word_span'.
+      Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
     min_num_tokens: The minimum number of tokens to sample in a span.
       This is used when `sample_strategy` is 'token_span'.
     max_num_tokens: The maximum number of tokens to sample in a span.
@@ -208,12 +207,8 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
             tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'input_type_ids':
             tf.io.FixedLenFeature([self._seq_length], tf.int64),
-        'target':
-            tf.io.FixedLenFeature([self._seq_length], tf.int64),
         'boundary_indices':
             tf.io.VarLenFeature(tf.int64),
-        'input_mask':
-            tf.io.FixedLenFeature([self._seq_length], tf.int64),
     }
     example = tf.io.parse_single_example(record, name_to_features)
@@ -234,20 +229,12 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
     inputs = record['input_word_ids']
     x['input_type_ids'] = record['input_type_ids']
-    if self._sample_strategy == 'fixed':
-      input_mask = record['input_mask']
-    else:
-      input_mask = None
     if self._sample_strategy in ['whole_word', 'word_span']:
       boundary = tf.sparse.to_dense(record['boundary_indices'])
     else:
       boundary = None
     input_mask = self._online_sample_mask(
-        inputs=inputs, input_mask=input_mask, boundary=boundary)
+        inputs=inputs, boundary=boundary)
     if self._reuse_length > 0:
       if self._permutation_size > self._reuse_length:
@@ -503,14 +490,10 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
   def _online_sample_mask(self,
                           inputs: tf.Tensor,
-                          input_mask: tf.Tensor,
                           boundary: tf.Tensor) -> tf.Tensor:
     """Samples target positions for predictions.
     Descriptions of each strategy:
-      - 'fixed': Returns the input mask that was computed during pretrain data
-        creation. The value for `max_predictions_per_seq` must match the value
-        used during dataset creation.
       - 'single_token': Samples individual tokens as prediction targets.
       - 'token_span': Samples spans of tokens as prediction targets.
       - 'whole_word': Samples individual words as prediction targets.
@@ -518,9 +501,6 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
     Args:
       inputs: The input tokens.
-      input_mask: The `bool` Tensor of the same shape as `inputs`. This is the
-        input mask calculated when creating pretraining the pretraining dataset.
-        If `sample_strategy` is not 'fixed', this is not used.
       boundary: The `int` Tensor of indices indicating whole word boundaries.
         This is used in 'whole_word' and 'word_span'
@@ -528,26 +508,17 @@ class XLNetPretrainDataLoader(data_loader.DataLoader):
       The sampled `bool` input mask.
     Raises:
-      `ValueError`: if `max_predictions_per_seq` is not set
-        and the sample strategy is not 'fixed', or if boundary is not provided
-        for 'whole_word' and 'word_span' sample strategies.
+      `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
+        not provided for 'whole_word' and 'word_span' sample strategies.
     """
-    if (self._sample_strategy != 'fixed' and
-        self._max_predictions_per_seq is None):
-      raise ValueError('`max_predictions_per_seq` must be set if using '
-                       'sample strategy {}.'.format(self._sample_strategy))
+    if self._max_predictions_per_seq is None:
+      raise ValueError('`max_predictions_per_seq` must be set.')
     if boundary is None and 'word' in self._sample_strategy:
       raise ValueError('`boundary` must be provided for {} strategy'.format(
           self._sample_strategy))
-    if self._sample_strategy == 'fixed':
-      # Uses the computed input masks from preprocessing.
-      # Note: This should have `max_predictions_per_seq` number of tokens set
-      # to 1.
-      return tf.cast(input_mask, tf.bool)
-    elif self._sample_strategy == 'single_token':
+    if self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
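
Net effect of this diff: the 'fixed' strategy, which read a precomputed `input_mask`/`target` from the record, is removed, so `_online_sample_mask` always samples prediction positions online and `max_predictions_per_seq` must be set for every strategy. A minimal sketch of the new usage follows; it only uses field names visible in this diff and in the tests below, and the specific values, as well as any other required `DataConfig` fields (which are omitted here), are assumptions rather than part of the commit.

# Illustrative sketch only; values are assumptions, and other required
# DataConfig fields (e.g. the input path and batch size) are omitted.
from official.nlp.data import pretrain_dataloader

data_config = pretrain_dataloader.XLNetPretrainDataConfig(
    seq_length=512,
    reuse_length=256,
    sample_strategy="token_span",  # 'fixed' is no longer accepted.
    max_predictions_per_seq=85,    # Now required for every strategy.
    permutation_size=256,
    leak_ratio=0.1)
dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()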

official/nlp/data/pretrain_dataloader_test.py (modified)

@@ -174,7 +174,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
 class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
   @parameterized.parameters(itertools.product(
-      ("fixed", "single_token", "whole_word", "token_span"),
+      ("single_token", "whole_word", "token_span"),
       (0, 64),
       (20, None),
   ))
@@ -200,9 +200,8 @@ class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
         permutation_size=seq_length // 2,
         leak_ratio=0.1)
-    if (max_predictions_per_seq is None and sample_strategy != "fixed"):
-      with self.assertRaisesWithRegexpMatch(
-          ValueError, "`max_predictions_per_seq` must be set"):
+    if max_predictions_per_seq is None:
+      with self.assertRaises(ValueError):
         dataset = pretrain_dataloader.XLNetPretrainDataLoader(
             data_config).load()
         features = next(iter(dataset))