ModelZoo / ResNet50_tensorflow / Commits

Commit 6cd426d9, authored Oct 17, 2019 by Jing Li, committed by A. Unique TensorFlower on Oct 17, 2019
Parent: b0581d0a

    Support online masking for XLNet

    PiperOrigin-RevId: 275408074

Showing 2 changed files with 359 additions and 126 deletions:

    official/nlp/xlnet/data_utils.py    +328  -119
    official/nlp/xlnet/run_pretrain.py   +31    -7
--- a/official/nlp/xlnet/data_utils.py
+++ b/official/nlp/xlnet/data_utils.py
@@ -19,12 +19,15 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

+import collections
 import json
 import os

 from absl import logging
 import numpy as np
 import tensorflow as tf

 special_symbols = {
     "<unk>": 0,
     "<s>": 1,
@@ -49,6 +52,11 @@ SEG_ID_CLS = 2
 SEG_ID_PAD = 3

+OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
+    "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
+    "min_num_words"])
+

 def file_based_input_fn_builder(input_file, name_to_features, batch_size,
                                 is_training):
   """Creates an `input_fn` closure."""
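For orientation, a minimal sketch (not part of the commit) of how this config is meant to be filled in; the values mirror the flag defaults added in run_pretrain.py further below:

    import collections

    OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
        "sample_strategy", "max_num_tokens", "min_num_tokens",
        "max_num_words", "min_num_words"])

    # Span-based masking over subword tokens, 1-5 tokens per span.
    config = OnlineMaskingConfig(
        sample_strategy="token_span",
        max_num_tokens=5,
        min_num_tokens=1,
        max_num_words=5,
        min_num_words=1)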
@@ -249,11 +257,191 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
   return _dataset_fn if use_dataset_fn else _dataset_fn()


+def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
+  """Turn beg and end indices into actual mask."""
+  non_func_mask = tf.logical_and(
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+  all_indices = tf.where(non_func_mask,
+                         tf.range(tgt_len, dtype=tf.int64),
+                         tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
+  candidate_matrix = tf.cast(
+      tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
+                     all_indices[None, :] < end_indices[:, None]), tf.float32)
+  cumsum_matrix = tf.reshape(
+      tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
+  masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
+  target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
+  is_masked = tf.cast(target_mask, tf.bool)
+
+  return is_masked, target_mask
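To see what `_idx_pair_to_mask` computes, here is a plain-numpy rendition (a sketch, not from the commit) of the candidate-matrix/cumsum trick on a toy input; the `non_func_mask` filtering of SEP/CLS is omitted for brevity:

    import numpy as np

    # Two candidate spans [1, 3) and [4, 6) over tgt_len=6, num_predict=3.
    tgt_len, num_predict = 6, 3
    beg, end = np.array([1, 4]), np.array([3, 6])
    idx = np.arange(tgt_len)
    candidate = ((idx[None, :] >= beg[:, None]) &
                 (idx[None, :] < end[:, None])).astype(np.float32)
    # Running count of candidate tokens in row-major order; any token past
    # the num_predict-th candidate is dropped from the target mask.
    cumsum = np.cumsum(candidate.reshape(-1)).reshape(-1, tgt_len)
    target_mask = (candidate * (cumsum <= num_predict)).sum(axis=0)
    print(target_mask)  # [0. 1. 1. 0. 1. 0.] -- exactly 3 positions kept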
+def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
+                    boundary):
+  """Sample whole word spans as prediction targets."""
+  # Note: 1.2 is the token-to-word ratio
+  mask_alpha = tgt_len / num_predict / 1.2
+  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+  # Sample span lengths from a zipf distribution
+  span_len_seq = np.arange(min_num_words, max_num_words + 1)
+  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+  probs /= np.sum(probs)
+  logits = tf.constant(np.log(probs), dtype=tf.float32)
+
+  # Sample `num_predict` words here: note that this is over sampling
+  span_lens = tf.random.categorical(
+      logits=logits[None],
+      num_samples=num_predict,
+      dtype=tf.int64,
+  )[0] + min_num_words
+
+  # Sample the ratio [0.0, 1.0) of left context lengths
+  span_lens_float = tf.cast(span_lens, tf.float32)
+  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+  left_ctx_len = round_to_int(left_ctx_len)
+  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+  beg_indices = (tf.cumsum(left_ctx_len) +
+                 tf.cumsum(right_offset, exclusive=True))
+  end_indices = beg_indices + span_lens
+
+  # Remove out of range indices
+  max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
+  valid_idx_mask = end_indices < max_boundary_index
+  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+  beg_indices = tf.gather(boundary, beg_indices)
+  end_indices = tf.gather(boundary, end_indices)
+
+  # Shuffle valid indices
+  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+  beg_indices = tf.gather(beg_indices, order)
+  end_indices = tf.gather(end_indices, order)
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
+def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
+                     max_num_tokens):
+  """Sample token spans as prediction targets."""
+  mask_alpha = tgt_len / num_predict
+  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+  # Sample span lengths from a zipf distribution
+  span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
+  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+  probs /= np.sum(probs)
+  logits = tf.constant(np.log(probs), dtype=tf.float32)
+  span_lens = tf.random.categorical(
+      logits=logits[None],
+      num_samples=num_predict,
+      dtype=tf.int64,
+  )[0] + min_num_tokens
+
+  # Sample the ratio [0.0, 1.0) of left context lengths
+  span_lens_float = tf.cast(span_lens, tf.float32)
+  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+  left_ctx_len = round_to_int(left_ctx_len)
+
+  # Compute the offset from left start to the right end
+  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+  # Get the actual begin and end indices
+  beg_indices = (tf.cumsum(left_ctx_len) +
+                 tf.cumsum(right_offset, exclusive=True))
+  end_indices = beg_indices + span_lens
+
+  # Remove out of range indices
+  valid_idx_mask = end_indices < tgt_len
+  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+  # Shuffle valid indices
+  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+  beg_indices = tf.gather(beg_indices, order)
+  end_indices = tf.gather(end_indices, order)
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
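The span-length prior used by both span strategies is zipf-like, P(len) proportional to 1/(len+1), so short spans dominate. A quick numeric check (a sketch, using the default 1-5 token range):

    import numpy as np

    min_num_tokens, max_num_tokens = 1, 5
    span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
    probs = np.array([1.0 / (i + 1) for i in span_len_seq])
    probs /= np.sum(probs)
    print(dict(zip(span_len_seq.tolist(), probs.round(3).tolist())))
    # {1: 0.345, 2: 0.23, 3: 0.172, 4: 0.138, 5: 0.115}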
+def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
+  """Sample whole words as prediction targets."""
+  pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
+  cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
+  beg_indices = cand_pair_indices[:, 0]
+  end_indices = cand_pair_indices[:, 1]
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
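`boundary` here is assumed to hold word-start token offsets plus a final end offset, so consecutive pairs delimit whole words; a toy sketch of the pairing step:

    import numpy as np

    # boundary = [0, 2, 3, 6] describes three words covering token ranges
    # [0, 2), [2, 3) and [3, 6).
    boundary = np.array([0, 2, 3, 6])
    pairs = np.stack([boundary[:-1], boundary[1:]], axis=1)
    print(pairs.tolist())  # [[0, 2], [2, 3], [3, 6]]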
+def _single_token_mask(inputs, tgt_len, num_predict):
+  """Sample individual tokens as prediction targets."""
+  all_indices = tf.range(tgt_len, dtype=tf.int64)
+  non_func_mask = tf.logical_and(
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+  non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
+
+  masked_pos = tf.random.shuffle(non_func_indices)
+  masked_pos = tf.contrib.framework.sort(masked_pos[:num_predict])
+  target_mask = tf.sparse_to_dense(
+      sparse_indices=masked_pos,
+      output_shape=[tgt_len],
+      sparse_values=1.0,
+      default_value=0.0)
+  is_masked = tf.cast(target_mask, tf.bool)
+
+  return is_masked, target_mask
+def _online_sample_masks(inputs,
+                         tgt_len,
+                         num_predict,
+                         online_masking_config,
+                         boundary=None):
+  """Sample target positions to predict."""
+  logging.info("Online sample with strategy: `%s`.",
+               online_masking_config.sample_strategy)
+  if online_masking_config.sample_strategy == "single_token":
+    return _single_token_mask(inputs, tgt_len, num_predict)
+  elif online_masking_config.sample_strategy == "whole_word":
+    assert boundary is not None, "whole word sampling requires `boundary`"
+    return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
+  elif online_masking_config.sample_strategy == "token_span":
+    return _token_span_mask(inputs, tgt_len, num_predict,
+                            online_masking_config.min_num_tokens,
+                            online_masking_config.max_num_tokens)
+  elif online_masking_config.sample_strategy == "word_span":
+    assert boundary is not None, "word span sampling requires `boundary`"
+    return _word_span_mask(inputs, tgt_len, num_predict,
+                           online_masking_config.min_num_words,
+                           online_masking_config.max_num_words, boundary)
+  else:
+    raise NotImplementedError
+
+
 def create_pretrain_dataset(file_names,
                             bsz_per_core,
                             seq_len,
                             reuse_len,
                             perm_size,
+                            leak_ratio,
+                            online_masking_config,
                             num_predict=None,
                             input_pipeline_context=None):
   """Creates pretrain dataset."""
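A hypothetical end-to-end call of the dispatcher (a sketch, assuming the module is importable and eager execution; the dummy ids are chosen to avoid the SEP/CLS special ids so every position is a candidate):

    import numpy as np
    import tensorflow as tf
    from official.nlp.xlnet import data_utils

    config = data_utils.OnlineMaskingConfig(
        sample_strategy="token_span", max_num_tokens=5, min_num_tokens=1,
        max_num_words=5, min_num_words=1)
    inputs = tf.constant(np.random.randint(10, 100, size=32), dtype=tf.int64)
    is_masked, target_mask = data_utils._online_sample_masks(
        inputs, tgt_len=32, num_predict=6, online_masking_config=config)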
@@ -263,46 +451,67 @@ def create_pretrain_dataset(file_names,
     record_spec = {
         "input": tf.io.FixedLenFeature([seq_len], tf.int64),
-        "target": tf.io.FixedLenFeature([seq_len], tf.int64),
         "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
         "label": tf.io.FixedLenFeature([1], tf.int64),
-        "is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
     }

+    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+      logging.info("Add `boundary` spec for %s",
+                   online_masking_config.sample_strategy)
+      record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
+
     # retrieve serialized example
     example = tf.io.parse_single_example(
         serialized=record, features=record_spec)

     inputs = example.pop("input")
-    target = example.pop("target")
-    is_masked = tf.cast(example.pop("is_masked"), tf.bool)
-
-    non_reuse_len = seq_len - reuse_len
-    # perm_size should not be larger than reuse_len or non_reuse_len otherwise
-    # there will be data leaks.
-    assert perm_size <= reuse_len and perm_size <= non_reuse_len
-
-    # Creates permutation mask and target mask for the first reuse_len tokens.
-    # The tokens in this part are reused from the last sequence.
-    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
-        inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
-        perm_size, reuse_len)
-
-    # Creates permutation mask and target mask for the rest of tokens in
-    # current example, which are concatenation of two new segments.
-    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
-        inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
-        perm_size, non_reuse_len)
-
-    perm_mask_0 = tf.concat(
-        [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
-    perm_mask_1 = tf.concat(
-        [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
-    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
-    target = tf.concat([target_0, target_1], axis=0)
-    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
-    input_k = tf.concat([input_k_0, input_k_1], axis=0)
-    input_q = tf.concat([input_q_0, input_q_1], axis=0)
+    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+      boundary = tf.sparse.to_dense(example.pop("boundary"))
+    else:
+      boundary = None
+    is_masked, _ = _online_sample_masks(
+        inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
+
+    if reuse_len > 0:
+      ##### Use memory
+      # permutate the reuse and non-reuse parts separately
+      non_reuse_len = seq_len - reuse_len
+      assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
+
+      # Creates permutation mask and target mask for the first reuse_len
+      # tokens. The tokens in this part are reused from the last sequence.
+      perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+          inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
+          leak_ratio)
+
+      # Creates permutation mask and target mask for the rest of tokens in
+      # current example, which are concatenation of two new segments.
+      perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+          inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
+          leak_ratio)
+
+      perm_mask_0 = tf.concat(
+          [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
+      perm_mask_1 = tf.concat(
+          [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
+      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+      input_k = tf.concat([input_k_0, input_k_1], axis=0)
+      input_q = tf.concat([input_q_0, input_q_1], axis=0)
+    else:
+      ##### Do not use memory
+      assert seq_len % perm_size == 0
+      # permutate the entire sequence together
+      perm_mask, target_mask, input_k, input_q = _local_perm(
+          inputs, is_masked, perm_size, seq_len, leak_ratio)

-    # reshape back to fixed shape
-    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
-    example["input_k"] = tf.reshape(input_k, [seq_len])
-    example["input_q"] = tf.reshape(input_q, [seq_len])
+    # Directly use raw inputs as the target
+    target = inputs

     if num_predict is not None:
       indices = tf.range(seq_len, dtype=tf.int64)
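For the `reuse_len > 0` branch, a quick shape sketch (plain numpy, not from the commit) of how the two per-block masks are padded into one [seq_len, seq_len] matrix; the `tf.ones` block encodes that reused tokens may never attend to the two new segments (1 means "cannot attend"):

    import numpy as np

    reuse_len, non_reuse_len = 2, 3          # seq_len = 5
    perm_mask_0 = np.zeros([reuse_len, reuse_len])          # stand-ins for
    perm_mask_1 = np.zeros([non_reuse_len, non_reuse_len])  # _local_perm output
    perm_mask_0 = np.concatenate(
        [perm_mask_0, np.ones([reuse_len, non_reuse_len])], axis=1)
    perm_mask_1 = np.concatenate(
        [np.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
    perm_mask = np.concatenate([perm_mask_0, perm_mask_1], axis=0)
    print(perm_mask.shape)  # (5, 5)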
@@ -327,21 +536,15 @@ def create_pretrain_dataset(file_names,
       example["target"] = tf.reshape(target, [num_predict])

-      ##### target mask
-      target_mask = tf.concat([
-          tf.ones([actual_num_predict], dtype=tf.float32),
-          tf.zeros([pad_len], dtype=tf.float32)
-      ],
-                              axis=0)
+      target_mask = tf.concat(
+          [tf.ones([actual_num_predict], dtype=tf.float32),
+           tf.zeros([pad_len], dtype=tf.float32)], axis=0)
       example["target_mask"] = tf.reshape(target_mask, [num_predict])
     else:
       example["target"] = tf.reshape(target, [seq_len])
       example["target_mask"] = tf.reshape(target_mask, [seq_len])

+    # reshape back to fixed shape
+    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+    example["input_k"] = tf.reshape(input_k, [seq_len])
+    example["input_q"] = tf.reshape(input_q, [seq_len])
+
     for key in list(example.keys()):
       val = example[key]
       if tf.keras.backend.is_sparse(val):
@@ -360,42 +563,29 @@ def create_pretrain_dataset(file_names,
       parser=parser,
       file_paths=file_names,
       bsz_per_core=bsz_per_core,
+      sequential=reuse_len > 0,
       input_pipeline_context=input_pipeline_context)

   return dataset


 def format_filename(prefix,
-                    bsz_per_host,
-                    seq_len,
-                    bi_data,
-                    suffix,
-                    mask_alpha=5,
-                    mask_beta=1,
+                    suffix,
+                    bsz_per_host,
+                    seq_len,
                     reuse_len=None,
-                    uncased=False,
-                    fixed_num_predict=None):
+                    uncased=False):
   """Generates input file name pattern."""
-  if reuse_len is None:
-    reuse_len_str = ""
-  else:
-    reuse_len_str = "reuse-{}.".format(reuse_len)
-  if not uncased:
-    uncased_str = ""
-  else:
-    uncased_str = "uncased."
-  if bi_data:
-    bi_data_str = "bi"
-  else:
-    bi_data_str = "uni"
-  if fixed_num_predict is not None:
-    fnp_str = "fnp-{}.".format(fixed_num_predict)
-  else:
-    fnp_str = ""
-
-  file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
-      prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
-      mask_alpha, mask_beta, fnp_str, suffix)
+  if reuse_len is not None and reuse_len > 0:
+    reuse_str = "reuse-{}.".format(reuse_len)
+    bsz_str = "hostbsz-{}.".format(bsz_per_host)
+  else:
+    reuse_str = ""
+    bsz_str = ""
+  if not uncased:
+    case_str = ""
+  else:
+    case_str = "uncased."
+
+  file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
+                                          case_str, suffix)

   return file_name
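The new naming scheme drops the alpha/beta and bi/uni components; for example (a sketch of the new pattern, with hypothetical argument values):

    print(format_filename("meta.train.pass-*", "json*", bsz_per_host=32,
                          seq_len=512, reuse_len=256, uncased=False))
    # meta.train.pass-*.seq-512.reuse-256.hostbsz-32.json*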
@@ -406,11 +596,10 @@ def get_pretrain_input_data(batch_size,
                             file_path,
                             reuse_len,
                             perm_size,
-                            mask_alpha,
-                            mask_beta,
+                            leak_ratio,
                             num_predict,
-                            bi_data,
                             uncased,
+                            online_masking_config,
                             num_hosts=1):
   """Returns input dataset from input file string."""
@@ -419,17 +608,22 @@ def get_pretrain_input_data(batch_size,
   # than passing dataset instance itself.
   use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)

   split = "train"
+  bsz_per_host = int(batch_size / num_hosts)
   record_glob_base = format_filename(
-      prefix="record_info-{}-*".format(split),
-      bsz_per_host=int(batch_size / num_hosts),
+      prefix="meta.{}.pass-*".format(split),
+      suffix="json*",
+      bsz_per_host=bsz_per_host,
       seq_len=seq_len,
-      bi_data=bi_data,
-      suffix="json",
-      mask_alpha=mask_alpha,
-      mask_beta=mask_beta,
       reuse_len=reuse_len,
-      uncased=uncased,
-      fixed_num_predict=num_predict)
+      uncased=uncased)
+
+  def _get_num_batch(info):
+    if "num_batch" in info:
+      return info["num_batch"]
+    elif "num_example" in info:
+      return info["num_example"] / bsz_per_host
+    else:
+      raise ValueError("Do not have sample info.")

   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
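`_get_num_batch` lets the loader accept meta files that record either batch counts or raw example counts. It closes over `bsz_per_host`; inlined, its logic amounts to (hypothetical meta contents):

    bsz_per_host = 32
    info = {"num_example": 12800}
    num_batch = (info["num_batch"] if "num_batch" in info
                 else info["num_example"] / bsz_per_host)
    print(int(num_batch))  # 400, as aggregated by the caller below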
@@ -460,7 +654,7 @@ def get_pretrain_input_data(batch_size,
   for record_info_path in record_paths:
     with tf.io.gfile.GFile(record_info_path, "r") as fp:
       info = json.load(fp)
-      cur_record_info["num_batch"] += info["num_batch"]
+      cur_record_info["num_batch"] += int(_get_num_batch(info))
       cur_record_info["filenames"] += info["filenames"]

   # overwrite directory for `cur_record_info`
@@ -494,6 +688,8 @@ def get_pretrain_input_data(batch_size,
         seq_len=seq_len,
         reuse_len=reuse_len,
         perm_size=perm_size,
+        leak_ratio=leak_ratio,
+        online_masking_config=online_masking_config,
         num_predict=num_predict,
         input_pipeline_context=ctx)
     return train_dataset
@@ -504,6 +700,7 @@ def get_pretrain_input_data(batch_size,
 def parse_files_to_dataset(parser,
                            file_paths,
                            bsz_per_core,
+                           sequential,
                            input_pipeline_context=None):
   """Creates the dataset given file paths."""
@@ -519,7 +716,26 @@ def parse_files_to_dataset(parser,
   if len(file_paths) > 1:
     dataset = dataset.shuffle(len(file_paths))

-  dataset = tf.data.TFRecordDataset(dataset)
+  if sequential:
+    # Note: cannot perform sample-level shuffle here because this will violate
+    # the consecutive requirement of data stream.
+    dataset = tf.data.TFRecordDataset(dataset)
+  else:
+    # `cycle_length` is the number of parallel files that get read.
+    cycle_length = min(8, len(file_paths))
+    logging.info("Interleave %d files", cycle_length)
+
+    # `sloppy` mode means that the interleaving is not exact. This adds
+    # even more randomness to the training pipeline.
+    dataset = dataset.apply(
+        tf.data.experimental.parallel_interleave(
+            tf.data.TFRecordDataset, sloppy=True, cycle_length=cycle_length))
+    buffer_size = 2048
+    logging.info("Perform sample-level shuffle with size %d", buffer_size)
+    dataset = dataset.shuffle(buffer_size=buffer_size)

   # (zihang): since we are doing online preprocessing, the parsed result of
   # the same input at each time will be different. Thus, caching processed data
   # is not helpful. It would use a lot of memory and lead to container OOM.
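A standalone sketch of the new non-sequential branch (TF 1.15/2.0-era tf.data API; the file names are hypothetical):

    import tensorflow as tf

    file_paths = ["shard-0.tfrecord", "shard-1.tfrecord", "shard-2.tfrecord"]
    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    dataset = dataset.shuffle(len(file_paths))
    # Sloppy interleave trades determinism for throughput and extra shuffling.
    dataset = dataset.apply(
        tf.data.experimental.parallel_interleave(
            tf.data.TFRecordDataset, sloppy=True,
            cycle_length=min(8, len(file_paths))))
    dataset = dataset.shuffle(buffer_size=2048)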
@@ -531,19 +747,19 @@ def parse_files_to_dataset(parser,
   return dataset


-def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
+def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
   """Samples a permutation of the factorization order.

   Creates perm_mask and target_mask accordingly.

   Args:
     inputs: int64 Tensor in shape [seq_len], input ids.
-    targets: int64 Tensor in shape [seq_len], target ids.
     is_masked: bool Tensor in shape [seq_len]. True means being selected for
       partial prediction.
     perm_size: the length of longest permutation. Could be set to be reuse_len.
       Should not be larger than reuse_len or there will be data leaks.
     seq_len: int, sequence length.
+    leak_ratio: float, percent of masked tokens that are leaked.

   Returns:
     perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
@@ -555,9 +771,6 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
       means the ith token (in original order) can attend to the jth token
       (in original order). Note that non-masked tokens can be attended by all
       other tokens, which is different from the description in original paper.
-    new_targets: int64 Tensor in shape [seq_len], target token ids to be
-      predicted in XLNet.
-      In XLNet, target doesn't need to be shifted one position.
     target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
       target_mask[i] == 1,
       the ith token needs to be predicted and mask will be used as input. This
@@ -575,44 +788,40 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
   index = tf.random.shuffle(index)
   index = tf.reshape(tf.transpose(index), [-1])

-  # `perm_mask` and `target_mask`
-  # non-functional tokens
   non_func_tokens = tf.logical_not(
       tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
-  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
-  masked_or_func_tokens = tf.logical_not(non_mask_tokens)
-
-  # Set the permutation indices of non-masked (& non-functional) tokens to the
-  # smallest index (-1):
-  # (1) they can be seen by all other positions
-  # (2) they cannot see masked positions, so there won't be information leak
-  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
-  rev_index = tf.where(non_mask_tokens, smallest_index, index)
-
-  # Create `target_mask`: non-functional and masked tokens
-  # 1: use mask as input and have loss
-  # 0: use token (or [SEP], [CLS]) as input and do not have loss
-  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
-  target_mask = tf.cast(target_tokens, tf.float32)
-
-  # Create `perm_mask`
-  # `target_tokens` cannot see themselves
-  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
-
-  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
-  # 0: can attend if i > j or j is non-masked
-  perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
-                             masked_or_func_tokens)
-  perm_mask = tf.cast(perm_mask, tf.float32)
-
-  # new target: [next token] for LM and [curr token] (self) for PLM
-  new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
+  masked_tokens = tf.logical_and(is_masked, non_func_tokens)
+  non_masked_or_func_tokens = tf.logical_not(masked_tokens)
+
+  smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
+
+  # Similar to BERT, randomly leak some masked tokens
+  if leak_ratio > 0:
+    leak_tokens = tf.logical_and(
+        masked_tokens,
+        tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
+    can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
+  else:
+    can_attend_self = non_masked_or_func_tokens
+  to_index = tf.where(can_attend_self, smallest_index, index)
+  from_index = tf.where(can_attend_self, to_index + 1, to_index)
+
+  # For masked tokens, can attend if i > j
+  # For context tokens, always can attend each other
+  can_attend = from_index[:, None] > to_index[None, :]
+
+  # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
+  perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
+
+  # Only masked tokens are included in the loss
+  target_mask = tf.cast(masked_tokens, tf.float32)

   # construct inputs_k
   inputs_k = inputs

   # construct inputs_q
-  inputs_q = target_mask
+  inputs_q = masked_tokens

-  return perm_mask, new_targets, target_mask, inputs_k, inputs_q
+  return perm_mask, target_mask, inputs_k, inputs_q
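A toy numpy rendition (not from the commit) of the new attention rule with leak_ratio == 0: context tokens get to_index = -2 and from_index = -1, so everyone attends them; a masked token attends only masked tokens that come earlier in the sampled factorization order, and never itself:

    import numpy as np

    index = np.array([2, 0, 1, 3], dtype=np.int64)  # sampled factorization order
    masked = np.array([False, True, False, True])
    to_index = np.where(~masked, -2, index)
    from_index = np.where(~masked, to_index + 1, to_index)
    can_attend = from_index[:, None] > to_index[None, :]
    perm_mask = 1.0 - can_attend.astype(np.float32)
    print(perm_mask)
    # [[0. 1. 0. 1.]   rows: queries, cols: keys, 1 = cannot attend;
    #  [0. 1. 0. 1.]   the masked token at position 1 (perm index 0) is
    #  [0. 1. 0. 1.]   visible only to the masked token at position 3,
    #  [0. 0. 0. 1.]]  which comes later in the factorization order.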
--- a/official/nlp/xlnet/run_pretrain.py
+++ b/official/nlp/xlnet/run_pretrain.py
@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
 from official.nlp.xlnet import training_utils
 from official.utils.misc import tpu_lib

-flags.DEFINE_integer(
-    "mask_alpha", default=6, help="How many tokens to form a group.")
-flags.DEFINE_integer(
-    "mask_beta", default=1, help="How many tokens to mask within each group.")
 flags.DEFINE_integer(
     "num_predict",
     default=None,
     help="Number of tokens to predict in partial prediction.")
-flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
+
+# FLAGS for pretrain input preprocessing
+flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
+flags.DEFINE_float(
+    "leak_ratio",
+    default=0.1,
+    help="Percent of masked tokens that are leaked.")
+flags.DEFINE_enum(
+    "sample_strategy",
+    default="token_span",
+    enum_values=["single_token", "whole_word", "token_span", "word_span"],
+    help="Strategy used to sample prediction targets.")
+flags.DEFINE_integer(
+    "max_num_tokens",
+    default=5,
+    help="Maximum number of tokens to sample in a span."
+    "Effective when token_span strategy is used.")
+flags.DEFINE_integer(
+    "min_num_tokens",
+    default=1,
+    help="Minimum number of tokens to sample in a span."
+    "Effective when token_span strategy is used.")
+flags.DEFINE_integer(
+    "max_num_words",
+    default=5,
+    help="Maximum number of whole words to sample in a span."
+    "Effective when word_span strategy is used.")
+flags.DEFINE_integer(
+    "min_num_words",
+    default=1,
+    help="Minimum number of whole words to sample in a span."
+    "Effective when word_span strategy is used.")

 FLAGS = flags.FLAGS
@@ -74,11 +91,18 @@ def main(unused_argv):
   logging.info("***** Number of cores used : %d",
                strategy.num_replicas_in_sync)
   logging.info("***** Number of hosts used : %d", num_hosts)
+  online_masking_config = data_utils.OnlineMaskingConfig(
+      sample_strategy=FLAGS.sample_strategy,
+      max_num_tokens=FLAGS.max_num_tokens,
+      min_num_tokens=FLAGS.min_num_tokens,
+      max_num_words=FLAGS.max_num_words,
+      min_num_words=FLAGS.min_num_words)
+
   train_input_fn = functools.partial(
       data_utils.get_pretrain_input_data, FLAGS.train_batch_size,
       FLAGS.seq_len, strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len,
-      FLAGS.perm_size, FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict,
-      FLAGS.bi_data, FLAGS.uncased, num_hosts)
+      FLAGS.perm_size, FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
+      online_masking_config, num_hosts)

   total_training_steps = FLAGS.train_steps

   steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
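After the partial application, the input function is equivalent to the direct call below (a sketch; argument order follows the new get_pretrain_input_data signature):

    train_dataset = data_utils.get_pretrain_input_data(
        FLAGS.train_batch_size, FLAGS.seq_len, strategy,
        FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
        FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased,
        online_masking_config, num_hosts)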