Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6cd426d9
Commit
6cd426d9
authored
Oct 17, 2019
by
Jing Li
Committed by
A. Unique TensorFlower
Oct 17, 2019
Browse files
Support online masking for XLNet
PiperOrigin-RevId: 275408074
parent
b0581d0a
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
359 additions
and
126 deletions
+359
-126
official/nlp/xlnet/data_utils.py
official/nlp/xlnet/data_utils.py
+328
-119
official/nlp/xlnet/run_pretrain.py
official/nlp/xlnet/run_pretrain.py
+31
-7
No files found.
official/nlp/xlnet/data_utils.py
View file @
6cd426d9
This diff is collapsed.
Click to expand it.
official/nlp/xlnet/run_pretrain.py
View file @
6cd426d9
...
...
@@ -35,16 +35,33 @@ from official.nlp.xlnet import optimization
from
official.nlp.xlnet
import
training_utils
from
official.utils.misc
import
tpu_lib
flags
.
DEFINE_integer
(
"mask_alpha"
,
default
=
6
,
help
=
"How many tokens to form a group."
)
flags
.
DEFINE_integer
(
"mask_beta"
,
default
=
1
,
help
=
"How many tokens to mask within each group."
)
flags
.
DEFINE_integer
(
"num_predict"
,
default
=
None
,
help
=
"Number of tokens to predict in partial prediction."
)
flags
.
DEFINE_integer
(
"perm_size"
,
0
,
help
=
"Window size of permutation."
)
# FLAGS for pretrain input preprocessing
flags
.
DEFINE_integer
(
"perm_size"
,
0
,
help
=
"Window size of permutation."
)
flags
.
DEFINE_float
(
"leak_ratio"
,
default
=
0.1
,
help
=
"Percent of masked tokens that are leaked."
)
flags
.
DEFINE_enum
(
"sample_strategy"
,
default
=
"token_span"
,
enum_values
=
[
"single_token"
,
"whole_word"
,
"token_span"
,
"word_span"
],
help
=
"Stragey used to sample prediction targets."
)
flags
.
DEFINE_integer
(
"max_num_tokens"
,
default
=
5
,
help
=
"Maximum number of tokens to sample in a span."
"Effective when token_span strategy is used."
)
flags
.
DEFINE_integer
(
"min_num_tokens"
,
default
=
1
,
help
=
"Minimum number of tokens to sample in a span."
"Effective when token_span strategy is used."
)
flags
.
DEFINE_integer
(
"max_num_words"
,
default
=
5
,
help
=
"Maximum number of whole words to sample in a span."
"Effective when word_span strategy is used."
)
flags
.
DEFINE_integer
(
"min_num_words"
,
default
=
1
,
help
=
"Minimum number of whole words to sample in a span."
"Effective when word_span strategy is used."
)
FLAGS
=
flags
.
FLAGS
...
...
@@ -74,11 +91,18 @@ def main(unused_argv):
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
logging
.
info
(
"***** Number of hosts used : %d"
,
num_hosts
)
online_masking_config
=
data_utils
.
OnlineMaskingConfig
(
sample_strategy
=
FLAGS
.
sample_strategy
,
max_num_tokens
=
FLAGS
.
max_num_tokens
,
min_num_tokens
=
FLAGS
.
min_num_tokens
,
max_num_words
=
FLAGS
.
max_num_words
,
min_num_words
=
FLAGS
.
min_num_words
)
train_input_fn
=
functools
.
partial
(
data_utils
.
get_pretrain_input_data
,
FLAGS
.
train_batch_size
,
FLAGS
.
seq_len
,
strategy
,
FLAGS
.
train_tfrecord_path
,
FLAGS
.
reuse_len
,
FLAGS
.
perm_size
,
FLAGS
.
mask_alpha
,
FLAGS
.
mask_beta
,
FLAGS
.
num_predict
,
FLAGS
.
bi_data
,
FLAGS
.
uncased
,
num_hosts
)
FLAGS
.
leak_ratio
,
FLAGS
.
num_predict
,
FLAGS
.
uncased
,
online_masking_config
,
num_hosts
)
total_training_steps
=
FLAGS
.
train_steps
steps_per_epoch
=
int
(
FLAGS
.
train_data_size
/
FLAGS
.
train_batch_size
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment