ModelZoo / ResNet50_tensorflow

Commit 6cd426d9

Support online masking for XLNet

PiperOrigin-RevId: 275408074

Authored Oct 17, 2019 by Jing Li; committed by A. Unique TensorFlower on Oct 17, 2019.
Parent: b0581d0a

Changes: 2 changed files, with 359 additions and 126 deletions.

  official/nlp/xlnet/data_utils.py    +328  -119
  official/nlp/xlnet/run_pretrain.py   +31    -7
official/nlp/xlnet/data_utils.py
@@ -19,12 +19,15 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function
 
+import collections
 import json
 import os
 
 from absl import logging
+import numpy as np
 import tensorflow as tf
 
 special_symbols = {
     "<unk>": 0,
     "<s>": 1,
@@ -49,6 +52,11 @@ SEG_ID_CLS = 2
 SEG_ID_PAD = 3
 
+OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
+    "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
+    "min_num_words"
+])
+
 
 def file_based_input_fn_builder(input_file, name_to_features, batch_size,
                                 is_training):
   """Creates an `input_fn` closure."""
@@ -249,11 +257,191 @@ def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
   return _dataset_fn if use_dataset_fn else _dataset_fn()
 
 
+def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
+  """Turn beg and end indices into actual mask."""
+  non_func_mask = tf.logical_and(
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+  all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64),
+                         tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
+  candidate_matrix = tf.cast(
+      tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
+                     all_indices[None, :] < end_indices[:, None]), tf.float32)
+  cumsum_matrix = tf.reshape(
+      tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
+  masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
+  target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
+  is_masked = tf.cast(target_mask, tf.bool)
+
+  return is_masked, target_mask
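To see how `_idx_pair_to_mask` caps the total number of predicted positions, here is a minimal NumPy sketch of its cumsum trick; the spans and `num_predict` value are made up:

# Two candidate spans over a length-5 sequence; only num_predict=3 positions
# may survive. The flattened cumsum counts candidates in span order, and the
# <= num_predict test keeps earlier spans intact while truncating later ones.
import numpy as np

candidate_matrix = np.array([[1, 1, 0, 0, 0],    # span 0 covers positions 0-1
                             [0, 0, 0, 1, 1]],   # span 1 covers positions 3-4
                            dtype=np.float32)
num_predict = 3
cumsum_matrix = np.cumsum(candidate_matrix.reshape(-1)).reshape(-1, 5)
masked_matrix = (cumsum_matrix <= num_predict).astype(np.float32)
target_mask = (candidate_matrix * masked_matrix).sum(axis=0)
print(target_mask)  # [1. 1. 0. 1. 0.] -- exactly 3 positions kept in total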
+
+
+def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
+                    boundary):
+  """Sample whole word spans as prediction targets."""
+  # Note: 1.2 is the token-to-word ratio
+  mask_alpha = tgt_len / num_predict / 1.2
+  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+  # Sample span lengths from a zipf distribution
+  span_len_seq = np.arange(min_num_words, max_num_words + 1)
+  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+  probs /= np.sum(probs)
+  logits = tf.constant(np.log(probs), dtype=tf.float32)
+
+  # Sample `num_predict` words here: note that this is over sampling
+  span_lens = tf.random.categorical(
+      logits=logits[None],
+      num_samples=num_predict,
+      dtype=tf.int64,
+  )[0] + min_num_words
+
+  # Sample the ratio [0.0, 1.0) of left context lengths
+  span_lens_float = tf.cast(span_lens, tf.float32)
+  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+  left_ctx_len = round_to_int(left_ctx_len)
+  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+  beg_indices = (
+      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+  end_indices = beg_indices + span_lens
+
+  # Remove out of range indices
+  max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
+  valid_idx_mask = end_indices < max_boundary_index
+  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+  beg_indices = tf.gather(boundary, beg_indices)
+  end_indices = tf.gather(boundary, end_indices)
+
+  # Shuffle valid indices
+  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+  beg_indices = tf.gather(beg_indices, order)
+  end_indices = tf.gather(end_indices, order)
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
+
+
+def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
+                     max_num_tokens):
+  """Sample token spans as prediction targets."""
+  mask_alpha = tgt_len / num_predict
+  round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+  # Sample span lengths from a zipf distribution
+  span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
+  probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+  probs /= np.sum(probs)
+  logits = tf.constant(np.log(probs), dtype=tf.float32)
+  span_lens = tf.random.categorical(
+      logits=logits[None],
+      num_samples=num_predict,
+      dtype=tf.int64,
+  )[0] + min_num_tokens
+
+  # Sample the ratio [0.0, 1.0) of left context lengths
+  span_lens_float = tf.cast(span_lens, tf.float32)
+  left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+  left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+  left_ctx_len = round_to_int(left_ctx_len)
+
+  # Compute the offset from left start to the right end
+  right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+  # Get the actual begin and end indices
+  beg_indices = (
+      tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+  end_indices = beg_indices + span_lens
+
+  # Remove out of range indices
+  valid_idx_mask = end_indices < tgt_len
+  beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+  end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+  # Shuffle valid indices
+  num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+  order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+  beg_indices = tf.gather(beg_indices, order)
+  end_indices = tf.gather(end_indices, order)
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
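For intuition, the span-length distribution sampled above puts most mass on short spans; these probabilities are exact for `min_num_tokens=1`, `max_num_tokens=5`:

# Span-length probabilities implied by the 1/(k+1) zipf-like weighting above.
import numpy as np

span_len_seq = np.arange(1, 5 + 1)
probs = np.array([1.0 / (i + 1) for i in span_len_seq])
probs /= np.sum(probs)
print(dict(zip(span_len_seq, probs.round(3))))
# {1: 0.345, 2: 0.23, 3: 0.172, 4: 0.138, 5: 0.115}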
+
+
+def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
+  """Sample whole words as prediction targets."""
+  pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
+  cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
+  beg_indices = cand_pair_indices[:, 0]
+  end_indices = cand_pair_indices[:, 1]
+
+  return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+                           num_predict)
+
+
+def _single_token_mask(inputs, tgt_len, num_predict):
+  """Sample individual tokens as prediction targets."""
+  all_indices = tf.range(tgt_len, dtype=tf.int64)
+  non_func_mask = tf.logical_and(
+      tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+  non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
+
+  masked_pos = tf.random.shuffle(non_func_indices)
+  masked_pos = tf.contrib.framework.sort(masked_pos[:num_predict])
+  target_mask = tf.sparse_to_dense(
+      sparse_indices=masked_pos,
+      output_shape=[tgt_len],
+      sparse_values=1.0,
+      default_value=0.0)
+
+  is_masked = tf.cast(target_mask, tf.bool)
+
+  return is_masked, target_mask
+
+
+def _online_sample_masks(inputs,
+                         tgt_len,
+                         num_predict,
+                         online_masking_config,
+                         boundary=None):
+  """Sample target positions to predict."""
+  logging.info("Online sample with strategy: `%s`.",
+               online_masking_config.sample_strategy)
+  if online_masking_config.sample_strategy == "single_token":
+    return _single_token_mask(inputs, tgt_len, num_predict)
+  elif online_masking_config.sample_strategy == "whole_word":
+    assert boundary is not None, "whole word sampling requires `boundary`"
+    return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
+  elif online_masking_config.sample_strategy == "token_span":
+    return _token_span_mask(inputs, tgt_len, num_predict,
+                            online_masking_config.min_num_tokens,
+                            online_masking_config.max_num_tokens)
+  elif online_masking_config.sample_strategy == "word_span":
+    assert boundary is not None, "word span sampling requires `boundary`"
+    return _word_span_mask(inputs, tgt_len, num_predict,
+                           online_masking_config.min_num_words,
+                           online_masking_config.max_num_words, boundary)
+  else:
+    raise NotImplementedError
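A minimal usage sketch of the dispatcher above, assuming TF 2.x eager execution and this module's constants (SEP_ID, CLS_ID) in scope; the token ids and sizes are illustrative:

# Illustrative call using the token_span strategy (the only strategy-specific
# fields consumed are min/max_num_tokens; the word fields are placeholders).
import tensorflow as tf

config = OnlineMaskingConfig(
    sample_strategy="token_span",
    max_num_tokens=3, min_num_tokens=1,
    max_num_words=5, min_num_words=1)
inputs = tf.constant(list(range(10, 26)), dtype=tf.int64)  # 16 ordinary tokens
is_masked, target_mask = _online_sample_masks(
    inputs, tgt_len=16, num_predict=4, online_masking_config=config)
# is_masked: bool [16]; at most 4 positions are selected, grouped in spans.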
+
+
 def create_pretrain_dataset(file_names,
                             bsz_per_core,
                             seq_len,
                             reuse_len,
                             perm_size,
+                            leak_ratio,
+                            online_masking_config,
                             num_predict=None,
                             input_pipeline_context=None):
   """Creates pretrain dataset."""
@@ -263,46 +451,67 @@ def create_pretrain_dataset(file_names,
     record_spec = {
         "input": tf.io.FixedLenFeature([seq_len], tf.int64),
-        "target": tf.io.FixedLenFeature([seq_len], tf.int64),
         "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
         "label": tf.io.FixedLenFeature([1], tf.int64),
-        "is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
     }
+    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+      logging.info("Add `boundary` spec for %s",
+                   online_masking_config.sample_strategy)
+      record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
 
     # retrieve serialized example
     example = tf.io.parse_single_example(
         serialized=record, features=record_spec)
 
     inputs = example.pop("input")
-    target = example.pop("target")
-    is_masked = tf.cast(example.pop("is_masked"), tf.bool)
+    if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+      boundary = tf.sparse.to_dense(example.pop("boundary"))
+    else:
+      boundary = None
+    is_masked, _ = _online_sample_masks(
+        inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
 
-    non_reuse_len = seq_len - reuse_len
-    # perm_size should not be larger than reuse_len or non_reuse_len otherwise
-    # there will be data leaks.
-    assert perm_size <= reuse_len and perm_size <= non_reuse_len
-
-    # Creates permutation mask and target mask for the first reuse_len tokens.
-    # The tokens in this part are reused from the last sequence.
-    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
-        inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
-        perm_size, reuse_len)
-
-    # Creates permutation mask and target mask for the rest of tokens in
-    # current example, which are concatenation of two new segments.
-    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
-        inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
-        perm_size, non_reuse_len)
-
-    perm_mask_0 = tf.concat(
-        [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
-    perm_mask_1 = tf.concat(
-        [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
-    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
-    target = tf.concat([target_0, target_1], axis=0)
-    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
-    input_k = tf.concat([input_k_0, input_k_1], axis=0)
-    input_q = tf.concat([input_q_0, input_q_1], axis=0)
+    if reuse_len > 0:
+      ##### Use memory
+      # permutate the reuse and non-reuse parts separately
+      non_reuse_len = seq_len - reuse_len
+      assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
+
+      # Creates permutation mask and target mask for the first reuse_len tokens.
+      # The tokens in this part are reused from the last sequence.
+      perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+          inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
+          leak_ratio)
+
+      # Creates permutation mask and target mask for the rest of tokens in
+      # current example, which are concatenation of two new segments.
+      perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+          inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
+          leak_ratio)
+
+      perm_mask_0 = tf.concat(
+          [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
+      perm_mask_1 = tf.concat(
+          [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
+      perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+      target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+      input_k = tf.concat([input_k_0, input_k_1], axis=0)
+      input_q = tf.concat([input_q_0, input_q_1], axis=0)
+    else:
+      ##### Do not use memory
+      assert seq_len % perm_size == 0
+      # permutate the entire sequence together
+      perm_mask, target_mask, input_k, input_q = _local_perm(
+          inputs, is_masked, perm_size, seq_len, leak_ratio)
+
+    # reshape back to fixed shape
+    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+    example["input_k"] = tf.reshape(input_k, [seq_len])
+    example["input_q"] = tf.reshape(input_q, [seq_len])
+
+    # Directly use raw inputs as the target
+    target = inputs
 
     if num_predict is not None:
       indices = tf.range(seq_len, dtype=tf.int64)
@@ -327,21 +536,15 @@ def create_pretrain_dataset(file_names,
       example["target"] = tf.reshape(target, [num_predict])
 
       ##### target mask
-      target_mask = tf.concat([
-          tf.ones([actual_num_predict], dtype=tf.float32),
-          tf.zeros([pad_len], dtype=tf.float32)
-      ],
-                              axis=0)
+      target_mask = tf.concat(
+          [tf.ones([actual_num_predict], dtype=tf.float32),
+           tf.zeros([pad_len], dtype=tf.float32)],
+          axis=0)
       example["target_mask"] = tf.reshape(target_mask, [num_predict])
     else:
       example["target"] = tf.reshape(target, [seq_len])
       example["target_mask"] = tf.reshape(target_mask, [seq_len])
 
-    # reshape back to fixed shape
-    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
-    example["input_k"] = tf.reshape(input_k, [seq_len])
-    example["input_q"] = tf.reshape(input_q, [seq_len])
-
     for key in list(example.keys()):
       val = example[key]
       if tf.keras.backend.is_sparse(val):
@@ -360,42 +563,29 @@ def create_pretrain_dataset(file_names,
       parser=parser,
       file_paths=file_names,
       bsz_per_core=bsz_per_core,
+      sequential=reuse_len > 0,
       input_pipeline_context=input_pipeline_context)
 
   return dataset
 
 
 def format_filename(prefix,
+                    suffix,
                     bsz_per_host,
                     seq_len,
-                    bi_data,
-                    suffix,
-                    mask_alpha=5,
-                    mask_beta=1,
                     reuse_len=None,
-                    uncased=False,
-                    fixed_num_predict=None):
+                    uncased=False):
   """Generates input file name pattern."""
-  if reuse_len is None:
-    reuse_len_str = ""
+  if reuse_len is not None and reuse_len > 0:
+    reuse_str = "reuse-{}.".format(reuse_len)
+    bsz_str = "hostbsz-{}.".format(bsz_per_host)
   else:
-    reuse_len_str = "reuse-{}.".format(reuse_len)
+    reuse_str = ""
+    bsz_str = ""
   if not uncased:
-    uncased_str = ""
+    case_str = ""
   else:
-    uncased_str = "uncased."
-  if bi_data:
-    bi_data_str = "bi"
-  else:
-    bi_data_str = "uni"
-  if fixed_num_predict is not None:
-    fnp_str = "fnp-{}.".format(fixed_num_predict)
-  else:
-    fnp_str = ""
+    case_str = "uncased."
 
-  file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
-      prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
-      mask_alpha, mask_beta, fnp_str, suffix)
+  file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
+                                          case_str, suffix)
 
   return file_name
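A worked example of the new naming scheme (the argument values are illustrative):

# With the new signature:
format_filename("meta.train.pass-*", "json*", bsz_per_host=8, seq_len=512,
                reuse_len=256, uncased=False)
# -> "meta.train.pass-*.seq-512.reuse-256.hostbsz-8.json*"
# With reuse_len=None (or 0), the reuse and hostbsz segments are both dropped:
# -> "meta.train.pass-*.seq-512.json*"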
@@ -406,11 +596,10 @@ def get_pretrain_input_data(batch_size,
                             file_path,
                             reuse_len,
                             perm_size,
-                            mask_alpha,
-                            mask_beta,
+                            leak_ratio,
                             num_predict,
-                            bi_data,
                             uncased,
+                            online_masking_config,
                             num_hosts=1):
   """Returns input dataset from input file string."""
@@ -419,17 +608,22 @@ def get_pretrain_input_data(batch_size,
   # than passing dataset instance itself.
   use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
 
   split = "train"
+  bsz_per_host = int(batch_size / num_hosts)
   record_glob_base = format_filename(
-      prefix="record_info-{}-*".format(split),
-      bsz_per_host=int(batch_size / num_hosts),
+      prefix="meta.{}.pass-*".format(split),
+      suffix="json*",
+      bsz_per_host=bsz_per_host,
       seq_len=seq_len,
-      bi_data=bi_data,
-      suffix="json",
-      mask_alpha=mask_alpha,
-      mask_beta=mask_beta,
       reuse_len=reuse_len,
-      uncased=uncased,
-      fixed_num_predict=num_predict)
+      uncased=uncased)
+
+  def _get_num_batch(info):
+    if "num_batch" in info:
+      return info["num_batch"]
+    elif "num_example" in info:
+      return info["num_example"] / bsz_per_host
+    else:
+      raise ValueError("Do not have sample info.")
 
   if use_dataset_fn:
     if batch_size % strategy.num_replicas_in_sync != 0:
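For reference, a sketch of the two metadata shapes `_get_num_batch` accepts; the field names come from the function above, while the concrete numbers are illustrative:

# Illustrative only: either metadata form yields a batch count.
info_a = {"num_batch": 120, "filenames": ["..."]}    # precomputed batch count
info_b = {"num_example": 960, "filenames": ["..."]}  # raw example count
# With bsz_per_host = 8:
#   _get_num_batch(info_a) -> 120
#   _get_num_batch(info_b) -> 960 / 8 = 120.0 (the caller wraps it in int())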
@@ -460,7 +654,7 @@ def get_pretrain_input_data(batch_size,
   for record_info_path in record_paths:
     with tf.io.gfile.GFile(record_info_path, "r") as fp:
       info = json.load(fp)
-      cur_record_info["num_batch"] += info["num_batch"]
+      cur_record_info["num_batch"] += int(_get_num_batch(info))
       cur_record_info["filenames"] += info["filenames"]
 
   # overwrite directory for `cur_record_info`
@@ -494,6 +688,8 @@ def get_pretrain_input_data(batch_size,
         seq_len=seq_len,
         reuse_len=reuse_len,
         perm_size=perm_size,
+        leak_ratio=leak_ratio,
+        online_masking_config=online_masking_config,
         num_predict=num_predict,
         input_pipeline_context=ctx)
 
     return train_dataset
@@ -504,6 +700,7 @@ def get_pretrain_input_data(batch_size,
 def parse_files_to_dataset(parser,
                            file_paths,
                            bsz_per_core,
+                           sequential,
                            input_pipeline_context=None):
   """Creates the dataset given file paths."""
@@ -519,7 +716,26 @@ def parse_files_to_dataset(parser,
   if len(file_paths) > 1:
     dataset = dataset.shuffle(len(file_paths))
 
-  dataset = tf.data.TFRecordDataset(dataset)
+  if sequential:
+    # Note: cannot perform sample-level shuffle here because this will violate
+    # the consecutive requirement of data stream.
+    dataset = tf.data.TFRecordDataset(dataset)
+  else:
+    # `cycle_length` is the number of parallel files that get read.
+    cycle_length = min(8, len(file_paths))
+    logging.info("Interleave %d files", cycle_length)
+
+    # `sloppy` mode means that the interleaving is not exact. This adds
+    # even more randomness to the training pipeline.
+    dataset = dataset.apply(
+        tf.data.experimental.parallel_interleave(
+            tf.data.TFRecordDataset,
+            sloppy=True,
+            cycle_length=cycle_length))
+
+    buffer_size = 2048
+    logging.info("Perform sample-level shuffle with size %d", buffer_size)
+    dataset = dataset.shuffle(buffer_size=buffer_size)
 
   # (zihang): since we are doing online preprocessing, the parsed result of
   # the same input at each time will be different. Thus, caching processed data
   # is not helpful. It will use a lot of memory and lead to container OOM.
@@ -531,19 +747,19 @@ def parse_files_to_dataset(parser,
   return dataset
 
 
-def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
+def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
   """Samples a permutation of the factorization order.
 
   Creates perm_mask and target_mask accordingly.
 
   Args:
     inputs: int64 Tensor in shape [seq_len], input ids.
-    targets: int64 Tensor in shape [seq_len], target ids.
     is_masked: bool Tensor in shape [seq_len]. True means being selected for
       partial prediction.
     perm_size: the length of longest permutation. Could be set to be reuse_len.
       Should not be larger than reuse_len or there will be data leaks.
     seq_len: int, sequence length.
+    leak_ratio: float, percent of masked tokens that are leaked.
 
   Returns:
     perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
@@ -555,9 +771,6 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
       means the ith token (in original order) can attend to the jth token
       (in original order). Note that non-masked tokens can be attended by all
       other tokens, which is different from the description in original paper.
-    new_targets: int64 Tensor in shape [seq_len], target token ids to be
-      predicted in XLNet.
-      In XLNet, target doesn't need to be shifted one position.
     target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
       target_mask[i] == 1,
       the ith token needs to be predicted and mask will be used as input. This
@@ -575,44 +788,40 @@ def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
   index = tf.random.shuffle(index)
   index = tf.reshape(tf.transpose(index), [-1])
 
-  # `perm_mask` and `target_mask`
   # non-functional tokens
   non_func_tokens = tf.logical_not(
       tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
+  masked_tokens = tf.logical_and(is_masked, non_func_tokens)
+  non_masked_or_func_tokens = tf.logical_not(masked_tokens)
 
-  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
-  masked_or_func_tokens = tf.logical_not(non_mask_tokens)
-
-  # Set the permutation indices of non-masked (& non-functional) tokens to the
-  # smallest index (-1):
-  # (1) they can be seen by all other positions
-  # (2) they cannot see masked positions, so there won't be information leak
-  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
-  rev_index = tf.where(non_mask_tokens, smallest_index, index)
-
-  # Create `target_mask`: non-functional and masked tokens
-  # 1: use mask as input and have loss
-  # 0: use token (or [SEP], [CLS]) as input and do not have loss
-  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
-  target_mask = tf.cast(target_tokens, tf.float32)
-
-  # Create `perm_mask`
-  # `target_tokens` cannot see themselves
-  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
-
-  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
-  # 0: can attend if i > j or j is non-masked
-  perm_mask = tf.logical_and(
-      self_rev_index[:, None] <= rev_index[None, :],
-      masked_or_func_tokens)
-  perm_mask = tf.cast(perm_mask, tf.float32)
-
-  # new target: [next token] for LM and [curr token] (self) for PLM
-  new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)
+  smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
+
+  # Similar to BERT, randomly leak some masked tokens
+  if leak_ratio > 0:
+    leak_tokens = tf.logical_and(
+        masked_tokens,
+        tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
+    can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
+  else:
+    can_attend_self = non_masked_or_func_tokens
+  to_index = tf.where(can_attend_self, smallest_index, index)
+  from_index = tf.where(can_attend_self, to_index + 1, to_index)
+
+  # For masked tokens, can attend if i > j
+  # For context tokens, always can attend each other
+  can_attend = from_index[:, None] > to_index[None, :]
+
+  # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
+  perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
+
+  # Only masked tokens are included in the loss
+  target_mask = tf.cast(masked_tokens, tf.float32)
 
   # construct inputs_k
   inputs_k = inputs
 
   # construct inputs_q
-  inputs_q = target_mask
+  inputs_q = masked_tokens
 
-  return perm_mask, new_targets, target_mask, inputs_k, inputs_q
+  return perm_mask, target_mask, inputs_k, inputs_q
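To make the new attention rule concrete, here is a toy NumPy check of the masked branch; the factorization order is fixed (not shuffled) and leak_ratio is 0, so the values below follow directly from the logic above:

# Length-4 sequence, positions 1 and 3 masked, no functional tokens.
import numpy as np

index = np.array([0, 1, 2, 3])             # factorization order (fixed here)
masked = np.array([False, True, False, True])
can_attend_self = ~masked                  # leak_ratio = 0: no leaked tokens
to_index = np.where(can_attend_self, -2, index)        # [-2, 1, -2, 3]
from_index = np.where(can_attend_self, to_index + 1, to_index)  # [-1, 1, -1, 3]
perm_mask = 1.0 - (from_index[:, None] > to_index[None, :])
print(perm_mask)
# [[0. 1. 0. 1.]
#  [0. 1. 0. 1.]
#  [0. 1. 0. 1.]
#  [0. 0. 0. 1.]]
# Columns 0 and 2 (context tokens) are visible to every position; masked
# token 1 is visible only to masked token 3, which comes later in the order,
# and no masked token can see itself.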
official/nlp/xlnet/run_pretrain.py
@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
 from official.nlp.xlnet import training_utils
 from official.utils.misc import tpu_lib
 
-flags.DEFINE_integer("mask_alpha", default=6,
-                     help="How many tokens to form a group.")
-flags.DEFINE_integer("mask_beta", default=1,
-                     help="How many tokens to mask within each group.")
 flags.DEFINE_integer(
     "num_predict",
     default=None,
    help="Number of tokens to predict in partial prediction.")
+
+# FLAGS for pretrain input preprocessing
 flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
+flags.DEFINE_float("leak_ratio", default=0.1,
+                   help="Percent of masked tokens that are leaked.")
+flags.DEFINE_enum("sample_strategy", default="token_span",
+                  enum_values=["single_token", "whole_word", "token_span",
+                               "word_span"],
+                  help="Strategy used to sample prediction targets.")
+flags.DEFINE_integer("max_num_tokens", default=5,
+                     help="Maximum number of tokens to sample in a span."
+                     "Effective when token_span strategy is used.")
+flags.DEFINE_integer("min_num_tokens", default=1,
+                     help="Minimum number of tokens to sample in a span."
+                     "Effective when token_span strategy is used.")
+flags.DEFINE_integer("max_num_words", default=5,
+                     help="Maximum number of whole words to sample in a span."
+                     "Effective when word_span strategy is used.")
+flags.DEFINE_integer("min_num_words", default=1,
+                     help="Minimum number of whole words to sample in a span."
+                     "Effective when word_span strategy is used.")
 
 FLAGS = flags.FLAGS
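For reference, a hypothetical invocation exercising only the flags added above; the remaining flags that run_pretrain.py requires (model, data, and training settings) are omitted here:

python run_pretrain.py --sample_strategy=word_span --min_num_words=1 \
    --max_num_words=5 --leak_ratio=0.1 --perm_size=256 --num_predict=85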
@@ -74,11 +91,18 @@ def main(unused_argv):
   logging.info("***** Number of cores used : %d",
                strategy.num_replicas_in_sync)
   logging.info("***** Number of hosts used : %d", num_hosts)
+  online_masking_config = data_utils.OnlineMaskingConfig(
+      sample_strategy=FLAGS.sample_strategy,
+      max_num_tokens=FLAGS.max_num_tokens,
+      min_num_tokens=FLAGS.min_num_tokens,
+      max_num_words=FLAGS.max_num_words,
+      min_num_words=FLAGS.min_num_words)
   train_input_fn = functools.partial(
       data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
       strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
-      FLAGS.mask_alpha, FLAGS.mask_beta, FLAGS.num_predict, FLAGS.bi_data,
-      FLAGS.uncased, num_hosts)
+      FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
+      num_hosts)
 
   total_training_steps = FLAGS.train_steps
 
   steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)