ModelZoo / ResNet50_tensorflow, commit 6d6a78a2

Authored Nov 13, 2020 by Allen Wang; committed by A. Unique TensorFlower, Nov 13, 2020

Create XLNet pretrain data loader.

PiperOrigin-RevId: 342283301
Parent: 42f8e96e
Changes: 2 changed files, with 615 additions and 12 deletions

  official/nlp/data/pretrain_dataloader.py        +504  -0
  official/nlp/data/pretrain_dataloader_test.py   +111  -12
official/nlp/data/pretrain_dataloader.py (view file @ 6d6a78a2)
@@ -16,7 +16,10 @@
"""Loads dataset for the BERT pretraining task."""
from typing import Mapping, Optional

from absl import logging
import dataclasses
import numpy as np
import tensorflow as tf

from official.core import config_definitions as cfg
from official.core import input_reader
@@ -125,3 +128,504 @@ class BertPretrainDataLoader(data_loader.DataLoader):
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)

@dataclasses.dataclass
class XLNetPretrainDataConfig(cfg.DataConfig):
  """Data config for XLNet pretraining task.

  Attributes:
    input_path: See base class.
    global_batch_size: See base class.
    is_training: See base class.
    seq_length: The length of each sequence.
    max_predictions_per_seq: The number of predictions per sequence.
    reuse_length: The number of tokens in a previous segment to reuse. This
      should be the same value used during pretrain data creation.
    sample_strategy: The strategy used to sample factorization permutations.
      Possible values: 'fixed', 'single_token', 'whole_word', 'token_span',
      'word_span'.
    min_num_tokens: The minimum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    max_num_tokens: The maximum number of tokens to sample in a span. This is
      used when `sample_strategy` is 'token_span'.
    min_num_words: The minimum number of words to sample in a span. This is
      used when `sample_strategy` is 'word_span'.
    max_num_words: The maximum number of words to sample in a span. This is
      used when `sample_strategy` is 'word_span'.
    permutation_size: The length of the longest permutation. This can be set
      to `reuse_length`. This should NOT be greater than `reuse_length`,
      otherwise this may introduce data leaks.
    leak_ratio: The percentage of masked tokens that are leaked.
    segment_sep_id: The ID of the SEP token used when preprocessing the
      dataset.
    segment_cls_id: The ID of the CLS token used when preprocessing the
      dataset.
  """
  input_path: str = ''
  global_batch_size: int = 512
  is_training: bool = True
  seq_length: int = 512
  max_predictions_per_seq: int = 76
  reuse_length: int = 256
  sample_strategy: str = 'word_span'
  min_num_tokens: int = 1
  max_num_tokens: int = 5
  min_num_words: int = 1
  max_num_words: int = 5
  permutation_size: int = 256
  leak_ratio: float = 0.1
  segment_sep_id: int = 4
  segment_cls_id: int = 3

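The `permutation_size` note above interacts with `reuse_length` and `seq_length`: `_parse` below requires both halves of the sequence to be divisible by `permutation_size` when the memory mechanism is on. A standalone sketch of that constraint with the default values (illustration only, not part of this file):

seq_length, reuse_length, permutation_size = 512, 256, 256
non_reuse_length = seq_length - reuse_length
# Both checks must hold when reuse_length > 0, otherwise _parse raises ValueError.
assert reuse_length % permutation_size == 0
assert non_reuse_length % permutation_size == 0
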
@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
class XLNetPretrainDataLoader(data_loader.DataLoader):
  """A class to load dataset for xlnet pretraining task."""

  def __init__(self, params: XLNetPretrainDataConfig):
    """Inits `XLNetPretrainDataLoader` class.

    Args:
      params: A `XLNetPretrainDataConfig` object.
    """
    self._params = params
    self._seq_length = params.seq_length
    self._max_predictions_per_seq = params.max_predictions_per_seq
    self._reuse_length = params.reuse_length
    self._num_replicas_in_sync = None
    self._permutation_size = params.permutation_size
    self._sep_id = params.segment_sep_id
    self._cls_id = params.segment_cls_id
    self._sample_strategy = params.sample_strategy
    self._leak_ratio = params.leak_ratio

  def _decode(self, record: tf.Tensor):
    """Decodes a serialized tf.Example."""
    name_to_features = {
        'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'target': tf.io.FixedLenFeature([self._seq_length], tf.int64),
        'boundary_indices': tf.io.VarLenFeature(tf.int64),
        'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
    }
    example = tf.io.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.cast(t, tf.int32)
      example[name] = t

    return example

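For illustration, a standalone sketch (not from this file) of a record that `_decode` can parse; the feature values are made up and mirror the fake data the new test file writes:

import numpy as np
import tensorflow as tf

seq_length = 512

def _int_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))

input_ids = np.random.randint(100, size=(seq_length,))
serialized = tf.train.Example(
    features=tf.train.Features(
        feature={
            'input_word_ids': _int_feature(input_ids),
            'input_type_ids': _int_feature(np.ones_like(input_ids)),
            'target': _int_feature(input_ids),
            'boundary_indices': _int_feature(
                sorted(np.random.randint(seq_length, size=(10,)))),
            'input_mask': _int_feature(np.ones_like(input_ids)),
        })).SerializeToString()
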
  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {}

    inputs = record['input_word_ids']
    x['input_type_ids'] = record['input_type_ids']

    if self._sample_strategy == 'fixed':
      input_mask = record['input_mask']
    else:
      input_mask = None

    if self._sample_strategy in ['whole_word', 'word_span']:
      boundary = tf.sparse.to_dense(record['boundary_indices'])
    else:
      boundary = None

    input_mask = self._online_sample_mask(
        inputs=inputs, input_mask=input_mask, boundary=boundary)

    if self._reuse_length > 0:
      if self._permutation_size > self._reuse_length:
        logging.warning(
            '`permutation_size` is greater than `reuse_length` (%d > %d). '
            'This may introduce data leakage.', self._permutation_size,
            self._reuse_length)

      # Enable the memory mechanism.
      # Permute the reuse and non-reuse segments separately.
      non_reuse_len = self._seq_length - self._reuse_length
      if not (self._reuse_length % self._permutation_size == 0 and
              non_reuse_len % self._permutation_size == 0):
        raise ValueError('`reuse_length` and `seq_length` should both be '
                         'a multiple of `permutation_size`.')

      # Creates permutation mask and target mask for the first reuse_len
      # tokens. The tokens in this part are reused from the last sequence.
      perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
          inputs=inputs[:self._reuse_length],
          input_mask=input_mask[:self._reuse_length])

      # Creates permutation mask and target mask for the rest of tokens in
      # current example, which are a concatenation of two new segments.
      perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
          inputs[self._reuse_length:], input_mask[self._reuse_length:])

      perm_mask_0 = tf.concat(
          [perm_mask_0,
           tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)],
          axis=1)
      perm_mask_1 = tf.concat(
          [tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
           perm_mask_1],
          axis=1)
      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
      tokens = tf.concat([tokens_0, tokens_1], axis=0)
      masked_tokens = tf.concat([masked_0, masked_1], axis=0)
    else:
      # Disable the memory mechanism.
      if self._seq_length % self._permutation_size != 0:
        raise ValueError('`seq_length` should be a multiple of '
                         '`permutation_size`.')
      # Permute the entire sequence together.
      perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
          inputs=inputs, input_mask=input_mask)

    x['permutation_mask'] = tf.reshape(perm_mask,
                                       [self._seq_length, self._seq_length])
    x['input_word_ids'] = tokens
    x['masked_tokens'] = masked_tokens

    target = tokens
    if self._max_predictions_per_seq is not None:
      indices = tf.range(self._seq_length, dtype=tf.int32)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      # Account for extra padding due to CLS/SEP.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = self._max_predictions_per_seq - actual_num_predict

      target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
      paddings = tf.zeros([pad_len, self._seq_length],
                          dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      x['target_mapping'] = tf.reshape(
          target_mapping, [self._max_predictions_per_seq, self._seq_length])

      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      x['target'] = tf.reshape(target, [self._max_predictions_per_seq])

      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.int32),
          tf.zeros([pad_len], dtype=tf.int32)
      ], axis=0)
      x['target_mask'] = tf.reshape(target_mask,
                                    [self._max_predictions_per_seq])
    else:
      x['target'] = tf.reshape(target, [self._seq_length])
      x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
    return x

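A standalone, shape-only sketch (not from this file) of how the two factorizations are stitched when `reuse_length > 0`, using made-up sizes seq_length=8 and reuse_length=4; ones stand for "may attend" and zeros for "may not", matching the convention documented in `_get_factorization` below:

import tensorflow as tf

reuse_len, non_reuse_len = 4, 4
perm_mask_0 = tf.ones([reuse_len, reuse_len], tf.int32)          # stand-in for the reused half
perm_mask_1 = tf.ones([non_reuse_len, non_reuse_len], tf.int32)  # stand-in for the new half

# Reused tokens never attend to the new segment: pad their rows with zeros.
perm_mask_0 = tf.concat(
    [perm_mask_0, tf.zeros([reuse_len, non_reuse_len], tf.int32)], axis=1)
# New-segment tokens always attend to the reused tokens: prepend ones.
perm_mask_1 = tf.concat(
    [tf.ones([non_reuse_len, reuse_len], tf.int32), perm_mask_1], axis=1)
perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
print(perm_mask.shape)  # (8, 8), i.e. [seq_length, seq_length]
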
  def _index_pair_to_mask(self, begin_indices: tf.Tensor,
                          end_indices: tf.Tensor,
                          inputs: tf.Tensor) -> tf.Tensor:
    """Converts beginning and end indices into an actual mask."""
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    all_indices = tf.where(
        non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
        tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
    candidate_matrix = tf.cast(
        tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
                       all_indices[None, :] < end_indices[:, None]),
        tf.float32)
    cumsum_matrix = tf.reshape(
        tf.cumsum(tf.reshape(candidate_matrix, [-1])),
        [-1, self._seq_length])
    masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
                            tf.float32)
    target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
    return tf.cast(target_mask, tf.bool)

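To see what `_index_pair_to_mask` computes, here is a standalone eager sketch (not from this file) with made-up spans, omitting the SEP/CLS exclusion for brevity; the cumulative sum is what caps the number of selected positions at `max_predictions_per_seq`:

import tensorflow as tf

seq_length, max_predictions = 8, 3
begin = tf.constant([1, 5], dtype=tf.int32)  # hypothetical span starts
end = tf.constant([3, 8], dtype=tf.int32)    # hypothetical span ends (exclusive)

positions = tf.range(seq_length, dtype=tf.int32)
# Row i is 1.0 where begin[i] <= position < end[i].
candidate = tf.cast(
    tf.logical_and(positions[None, :] >= begin[:, None],
                   positions[None, :] < end[:, None]), tf.float32)
# Running count over all candidate positions; keeping only counts <= the
# budget truncates whatever does not fit.
cumsum = tf.reshape(tf.cumsum(tf.reshape(candidate, [-1])), [-1, seq_length])
kept = tf.cast(cumsum <= max_predictions, tf.float32)
target_mask = tf.reduce_sum(candidate * kept, axis=0)
print(target_mask.numpy())  # [0. 1. 1. 0. 0. 1. 0. 0.] -> exactly 3 positions kept
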
  def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples individual tokens as prediction targets."""
    all_indices = tf.range(self._seq_length, dtype=tf.int32)
    non_func_mask = tf.logical_and(
        tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
    non_func_indices = tf.boolean_mask(all_indices, non_func_mask)

    masked_pos = tf.random.shuffle(non_func_indices)
    masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])

    sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    sparse_indices = tf.sparse.SparseTensor(
        sparse_indices,
        values=tf.ones_like(masked_pos),
        dense_shape=(1, self._seq_length))

    target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)

    return tf.squeeze(tf.cast(target_mask, tf.bool))

  def _whole_word_mask(self, inputs: tf.Tensor,
                       boundary: tf.Tensor) -> tf.Tensor:
    """Samples whole words as prediction targets."""
    pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
    cand_pair_indices = tf.random.shuffle(
        pair_indices)[:self._max_predictions_per_seq]
    begin_indices = cand_pair_indices[:, 0]
    end_indices = cand_pair_indices[:, 1]

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

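The boundary pairing above can be pictured with a standalone sketch (made-up boundary indices, not from this file): consecutive boundary entries delimit one whole word, so each adjacent pair becomes a candidate (begin, end) span.

import tensorflow as tf

boundary = tf.constant([0, 3, 4, 7, 9], dtype=tf.int32)  # hypothetical word boundaries
pairs = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
print(pairs.numpy())  # [[0 3] [3 4] [4 7] [7 9]] -> one row per whole word
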
  def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
    """Samples token spans as prediction targets."""
    min_num_tokens = self._params.min_num_tokens
    max_num_tokens = self._params.max_num_tokens

    mask_alpha = self._seq_length / self._max_predictions_per_seq
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution.
    span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_tokens

    # Sample the ratio [0.0, 1.0) of left context lengths.
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
    left_ctx_len = round_to_int(left_ctx_len)

    # Compute the offset from left start to the right end.
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    # Get the actual begin and end indices.
    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices.
    valid_idx_mask = end_indices < self._seq_length
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    # Shuffle valid indices.
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

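The "zipf" weighting used above simply makes shorter spans more likely; a standalone sketch (not from this file) with the default token limits shows the resulting probabilities:

import numpy as np

min_num_tokens, max_num_tokens = 1, 5
span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
for length, p in zip(span_len_seq, probs):
  print(int(length), round(float(p), 3))
# 1 0.345, 2 0.23, 3 0.172, 4 0.138, 5 0.115

The `mask_alpha` spacing then spreads the sampled spans across the sequence with context between them; spans that run past the end are dropped by the range check, and `_index_pair_to_mask` finally caps the total at `max_predictions_per_seq`.
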
  def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
    """Samples whole word spans as prediction targets."""
    min_num_words = self._params.min_num_words
    max_num_words = self._params.max_num_words

    # Note: 1.2 is the token-to-word ratio.
    mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
    round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)

    # Sample span lengths from a zipf distribution.
    span_len_seq = np.arange(min_num_words, max_num_words + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    logits = tf.constant(np.log(probs), dtype=tf.float32)

    # Sample `num_predict` words here: note that this is over sampling.
    span_lens = tf.random.categorical(
        logits=logits[None],
        num_samples=self._max_predictions_per_seq,
        dtype=tf.int32,
    )[0] + min_num_words

    # Sample the ratio [0.0, 1.0) of left context lengths.
    span_lens_float = tf.cast(span_lens, tf.float32)
    left_ratio = tf.random.uniform(
        shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
    left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)

    left_ctx_len = round_to_int(left_ctx_len)
    right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len

    begin_indices = (
        tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
    end_indices = begin_indices + span_lens

    # Remove out of range indices.
    max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
    valid_idx_mask = end_indices < max_boundary_index
    begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
    end_indices = tf.boolean_mask(end_indices, valid_idx_mask)

    begin_indices = tf.gather(boundary, begin_indices)
    end_indices = tf.gather(boundary, end_indices)

    # Shuffle valid indices.
    num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
    order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
    begin_indices = tf.gather(begin_indices, order)
    end_indices = tf.gather(end_indices, order)

    return self._index_pair_to_mask(
        begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)

  def _online_sample_mask(self, inputs: tf.Tensor, input_mask: tf.Tensor,
                          boundary: tf.Tensor) -> tf.Tensor:
    """Samples target positions for predictions.

    Descriptions of each strategy:
      - 'fixed': Returns the input mask that was computed during pretrain data
        creation. The value for `max_predictions_per_seq` must match the value
        used during dataset creation.
      - 'single_token': Samples individual tokens as prediction targets.
      - 'token_span': Samples spans of tokens as prediction targets.
      - 'whole_word': Samples individual words as prediction targets.
      - 'word_span': Samples spans of words as prediction targets.

    Args:
      inputs: The input tokens.
      input_mask: The `bool` Tensor of the same shape as `inputs`. This is the
        input mask calculated when creating the pretraining dataset. If
        `sample_strategy` is not 'fixed', this is not used.
      boundary: The `int` Tensor of indices indicating whole word boundaries.
        This is used in the 'whole_word' and 'word_span' strategies.

    Returns:
      The sampled `bool` input mask.

    Raises:
      `ValueError`: if `max_predictions_per_seq` is not set and the sample
        strategy is not 'fixed', or if `boundary` is not provided for the
        'whole_word' and 'word_span' sample strategies.
    """
    if (self._sample_strategy != 'fixed' and
        self._max_predictions_per_seq is None):
      raise ValueError('`max_predictions_per_seq` must be set if using '
                       'sample strategy {}.'.format(self._sample_strategy))
    if boundary is None and 'word' in self._sample_strategy:
      raise ValueError(
          '`boundary` must be provided for {} strategy'.format(
              self._sample_strategy))

    if self._sample_strategy == 'fixed':
      # Uses the computed input masks from preprocessing.
      # Note: This should have `max_predictions_per_seq` number of tokens set
      # to 1.
      return tf.cast(input_mask, tf.bool)
    elif self._sample_strategy == 'single_token':
      return self._single_token_mask(inputs)
    elif self._sample_strategy == 'token_span':
      return self._token_span_mask(inputs)
    elif self._sample_strategy == 'whole_word':
      return self._whole_word_mask(inputs, boundary)
    elif self._sample_strategy == 'word_span':
      return self._word_span_mask(inputs, boundary)
    else:
      raise NotImplementedError('Invalid sample strategy.')

  def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
    """Samples a permutation of the factorization order.

    Args:
      inputs: the input tokens.
      input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
        then this means select for partial prediction.

    Returns:
      perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
        of 0s and 1s. If perm_mask[i][j] == 0, then the i-th token (in original
        order) cannot attend to the j-th token.
      target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and
        1s. If target_mask[i] == 1, then the i-th token needs to be predicted
        and the mask will be used as input. This token will be included in the
        loss. If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be
        used as input. This token will not be included in the loss.
      tokens: int32 Tensor of shape [seq_length].
      masked_tokens: int32 Tensor of shape [seq_length].
    """
    factorization_length = tf.shape(inputs)[0]
    # Generate permutation indices.
    index = tf.range(factorization_length, dtype=tf.int32)
    index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
    index = tf.random.shuffle(index)
    index = tf.reshape(tf.transpose(index), [-1])

    input_mask = tf.cast(input_mask, tf.bool)

    # Non-functional tokens.
    non_func_tokens = tf.logical_not(
        tf.logical_or(
            tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
    masked_tokens = tf.logical_and(input_mask, non_func_tokens)
    non_masked_or_func_tokens = tf.logical_not(masked_tokens)

    smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)

    # Similar to BERT, randomly leak some masked tokens.
    if self._leak_ratio > 0:
      leak_tokens = tf.logical_and(
          masked_tokens,
          tf.random.uniform([factorization_length], maxval=1.0) <
          self._leak_ratio)
      can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
    else:
      can_attend_self = non_masked_or_func_tokens

    to_index = tf.where(can_attend_self, smallest_index, index)
    from_index = tf.where(can_attend_self, to_index + 1, to_index)

    # For masked tokens, can attend if i > j.
    # For context tokens, always can attend each other.
    can_attend = from_index[:, None] > to_index[None, :]

    perm_mask = tf.cast(can_attend, tf.int32)

    # Only masked tokens are included in the loss.
    target_mask = tf.cast(masked_tokens, tf.int32)

    return perm_mask, target_mask, inputs, masked_tokens

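A standalone eager sketch (not from this file) of the attention rule at the end of `_get_factorization`, with made-up flags and no shuffling: context tokens get the sentinel index -2 so every token may attend to them, while a masked token may only attend to masked tokens that come earlier in the factorization order.

import tensorflow as tf

# Hypothetical 4-token example: positions 1 and 3 are masked targets,
# positions 0 and 2 are context (can attend to themselves).
can_attend_self = tf.constant([True, False, True, False])
index = tf.constant([0, 1, 2, 3], dtype=tf.int32)  # factorization order, unshuffled
smallest_index = -2 * tf.ones([4], dtype=tf.int32)

to_index = tf.where(can_attend_self, smallest_index, index)     # [-2  1 -2  3]
from_index = tf.where(can_attend_self, to_index + 1, to_index)  # [-1  1 -1  3]
can_attend = from_index[:, None] > to_index[None, :]
print(tf.cast(can_attend, tf.int32).numpy())
# [[1 0 1 0]
#  [1 0 1 0]
#  [1 0 1 0]
#  [1 1 1 0]]
# Row i lists which positions token i may attend to: everyone sees the two
# context tokens, and masked token 3 additionally sees masked token 1.
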
  def load(self, input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a tf.dataset.Dataset."""
    if input_context:
      self._num_replicas_in_sync = input_context.num_replicas_in_sync
    reader = input_reader.InputReader(
        params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
    return reader.read(input_context)

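End to end, the loader is driven the same way the new test below drives it. A minimal sketch, assuming a TFRecord file with the features listed in `_decode` already exists at the placeholder path:

from official.nlp.data import pretrain_dataloader

config = pretrain_dataloader.XLNetPretrainDataConfig(
    input_path='/tmp/xlnet_pretrain.tf_record',  # placeholder path
    global_batch_size=4,
    seq_length=128,
    max_predictions_per_seq=20,
    reuse_length=64,
    sample_strategy='token_span',
    permutation_size=64)
dataset = pretrain_dataloader.XLNetPretrainDataLoader(config).load()
features = next(iter(dataset))
print(features['input_word_ids'].shape)    # (4, 128)
print(features['permutation_mask'].shape)  # (4, 128, 128)
print(features['target_mapping'].shape)    # (4, 20, 128)
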
official/nlp/data/pretrain_dataloader_test.py (view file @ 6d6a78a2)
@@ -24,19 +24,21 @@ import tensorflow as tf
 from official.nlp.data import pretrain_dataloader


-def _create_fake_dataset(output_path,
-                         seq_length,
-                         max_predictions_per_seq,
-                         use_position_id,
-                         use_next_sentence_label,
-                         use_v2_feature_names=False):
+def create_int_feature(values):
+  f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+  return f
+
+
+def _create_fake_bert_dataset(output_path,
+                              seq_length,
+                              max_predictions_per_seq,
+                              use_position_id,
+                              use_next_sentence_label,
+                              use_v2_feature_names=False):
   """Creates a fake dataset."""
   writer = tf.io.TFRecordWriter(output_path)

   def create_int_feature(values):
     f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
     return f

   def create_float_feature(values):
     f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
     return f
@@ -70,6 +72,34 @@ def _create_fake_dataset(output_path,
  writer.close()


def _create_fake_xlnet_dataset(output_path, seq_length,
                               max_predictions_per_seq):
  """Creates a fake dataset."""
  writer = tf.io.TFRecordWriter(output_path)
  for _ in range(100):
    features = {}
    input_ids = np.random.randint(100, size=(seq_length))
    num_boundary_indices = np.random.randint(1, seq_length)

    if max_predictions_per_seq is not None:
      input_mask = np.zeros_like(input_ids)
      input_mask[:max_predictions_per_seq] = 1
      np.random.shuffle(input_mask)
    else:
      input_mask = np.ones_like(input_ids)

    features["input_mask"] = create_int_feature(input_mask)
    features["input_word_ids"] = create_int_feature(input_ids)
    features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
    features["boundary_indices"] = create_int_feature(
        sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
    features["target"] = create_int_feature(input_ids + 1)
    features["label"] = create_int_feature([1])

    tf_example = tf.train.Example(
        features=tf.train.Features(feature=features))
    writer.write(tf_example.SerializeToString())
  writer.close()


class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(itertools.product(
@@ -80,7 +110,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
     train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
     seq_length = 128
     max_predictions_per_seq = 20
-    _create_fake_dataset(
+    _create_fake_bert_dataset(
         train_data_path,
         seq_length,
         max_predictions_per_seq,
@@ -114,7 +144,7 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
     train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
     seq_length = 128
     max_predictions_per_seq = 20
-    _create_fake_dataset(
+    _create_fake_bert_dataset(
         train_data_path,
         seq_length,
         max_predictions_per_seq,
@@ -141,5 +171,74 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
    self.assertIn("masked_lm_weights", features)


class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(itertools.product(
      ("fixed", "single_token", "whole_word", "token_span"),
      (0, 64),
      (20, None),
  ))
  def test_load_data(self, sample_strategy, reuse_length,
                     max_predictions_per_seq):
    train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
    seq_length = 128
    batch_size = 5

    _create_fake_xlnet_dataset(
        train_data_path, seq_length, max_predictions_per_seq)

    data_config = pretrain_dataloader.XLNetPretrainDataConfig(
        input_path=train_data_path,
        max_predictions_per_seq=max_predictions_per_seq,
        seq_length=seq_length,
        global_batch_size=batch_size,
        is_training=True,
        reuse_length=reuse_length,
        sample_strategy=sample_strategy,
        min_num_tokens=1,
        max_num_tokens=2,
        permutation_size=seq_length // 2,
        leak_ratio=0.1)

    if (max_predictions_per_seq is None and sample_strategy != "fixed"):
      with self.assertRaisesWithRegexpMatch(
          ValueError, "`max_predictions_per_seq` must be set"):
        dataset = pretrain_dataloader.XLNetPretrainDataLoader(
            data_config).load()
        features = next(iter(dataset))
    else:
      dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
      features = next(iter(dataset))

      self.assertIn("input_word_ids", features)
      self.assertIn("input_type_ids", features)
      self.assertIn("permutation_mask", features)
      self.assertIn("masked_tokens", features)
      self.assertIn("target", features)
      self.assertIn("target_mask", features)

      self.assertAllClose(features["input_word_ids"].shape,
                          (batch_size, seq_length))
      self.assertAllClose(features["input_type_ids"].shape,
                          (batch_size, seq_length))
      self.assertAllClose(features["permutation_mask"].shape,
                          (batch_size, seq_length, seq_length))
      self.assertAllClose(features["masked_tokens"].shape,
                          (batch_size, seq_length,))
      if max_predictions_per_seq is not None:
        self.assertIn("target_mapping", features)
        self.assertAllClose(features["target_mapping"].shape,
                            (batch_size, max_predictions_per_seq, seq_length))
        self.assertAllClose(features["target_mask"].shape,
                            (batch_size, max_predictions_per_seq))
        self.assertAllClose(features["target"].shape,
                            (batch_size, max_predictions_per_seq))
      else:
        self.assertAllClose(features["target_mask"].shape,
                            (batch_size, seq_length))
        self.assertAllClose(features["target"].shape,
                            (batch_size, seq_length))


if __name__ == "__main__":
  tf.test.main()