Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e10b29f4
"vscode:/vscode.git/clone" did not exist on "8433780efb9c78ab3136bb8de8ed104284664438"
Commit
e10b29f4
authored
Sep 08, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 08, 2020
Browse files
Internal change
PiperOrigin-RevId: 330621236
parent
636ca66f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
12 deletions
+75
-12
official/nlp/data/create_pretraining_data.py
official/nlp/data/create_pretraining_data.py
+11
-2
official/nlp/data/pretrain_dataloader.py
official/nlp/data/pretrain_dataloader.py
+22
-6
official/nlp/data/pretrain_dataloader_test.py
official/nlp/data/pretrain_dataloader_test.py
+42
-4
No files found.
official/nlp/data/create_pretraining_data.py
View file @
e10b29f4
...
@@ -60,6 +60,10 @@ flags.DEFINE_bool(
...
@@ -60,6 +60,10 @@ flags.DEFINE_bool(
"gzip_compress"
,
False
,
"gzip_compress"
,
False
,
"Whether to use `GZIP` compress option to get compressed TFRecord files."
)
"Whether to use `GZIP` compress option to get compressed TFRecord files."
)
flags
.
DEFINE_bool
(
"use_v2_feature_names"
,
False
,
"Whether to use the feature names consistent with the models."
)
flags
.
DEFINE_integer
(
"max_seq_length"
,
128
,
"Maximum sequence length."
)
flags
.
DEFINE_integer
(
"max_seq_length"
,
128
,
"Maximum sequence length."
)
flags
.
DEFINE_integer
(
"max_predictions_per_seq"
,
20
,
flags
.
DEFINE_integer
(
"max_predictions_per_seq"
,
20
,
...
@@ -147,9 +151,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
...
@@ -147,9 +151,14 @@ def write_instance_to_example_files(instances, tokenizer, max_seq_length,
next_sentence_label
=
1
if
instance
.
is_random_next
else
0
next_sentence_label
=
1
if
instance
.
is_random_next
else
0
features
=
collections
.
OrderedDict
()
features
=
collections
.
OrderedDict
()
if
FLAGS
.
use_v2_feature_names
:
features
[
"input_word_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"input_type_ids"
]
=
create_int_feature
(
segment_ids
)
else
:
features
[
"input_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"input_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"input_mask"
]
=
create_int_feature
(
input_mask
)
features
[
"segment_ids"
]
=
create_int_feature
(
segment_ids
)
features
[
"segment_ids"
]
=
create_int_feature
(
segment_ids
)
features
[
"input_mask"
]
=
create_int_feature
(
input_mask
)
features
[
"masked_lm_positions"
]
=
create_int_feature
(
masked_lm_positions
)
features
[
"masked_lm_positions"
]
=
create_int_feature
(
masked_lm_positions
)
features
[
"masked_lm_ids"
]
=
create_int_feature
(
masked_lm_ids
)
features
[
"masked_lm_ids"
]
=
create_int_feature
(
masked_lm_ids
)
features
[
"masked_lm_weights"
]
=
create_float_feature
(
masked_lm_weights
)
features
[
"masked_lm_weights"
]
=
create_float_feature
(
masked_lm_weights
)
...
...
official/nlp/data/pretrain_dataloader.py
View file @
e10b29f4
...
@@ -35,6 +35,12 @@ class BertPretrainDataConfig(cfg.DataConfig):
...
@@ -35,6 +35,12 @@ class BertPretrainDataConfig(cfg.DataConfig):
max_predictions_per_seq
:
int
=
76
max_predictions_per_seq
:
int
=
76
use_next_sentence_label
:
bool
=
True
use_next_sentence_label
:
bool
=
True
use_position_id
:
bool
=
False
use_position_id
:
bool
=
False
# Historically, BERT implementations take `input_ids` and `segment_ids` as
# feature names. Inside the TF Model Garden implementation, the Keras model
# inputs are set as `input_word_ids` and `input_type_ids`. When
# v2_feature_names is True, the data loader assumes the tf.Examples use
# `input_word_ids` and `input_type_ids` as keys.
use_v2_feature_names
:
bool
=
False
@
data_loader_factory
.
register_data_loader_cls
(
BertPretrainDataConfig
)
@
data_loader_factory
.
register_data_loader_cls
(
BertPretrainDataConfig
)
...
@@ -56,12 +62,8 @@ class BertPretrainDataLoader(data_loader.DataLoader):
...
@@ -56,12 +62,8 @@ class BertPretrainDataLoader(data_loader.DataLoader):
def
_decode
(
self
,
record
:
tf
.
Tensor
):
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
"""Decodes a serialized tf.Example."""
name_to_features
=
{
name_to_features
=
{
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'masked_lm_positions'
:
'masked_lm_positions'
:
tf
.
io
.
FixedLenFeature
([
self
.
_max_predictions_per_seq
],
tf
.
int64
),
tf
.
io
.
FixedLenFeature
([
self
.
_max_predictions_per_seq
],
tf
.
int64
),
'masked_lm_ids'
:
'masked_lm_ids'
:
...
@@ -69,6 +71,16 @@ class BertPretrainDataLoader(data_loader.DataLoader):
...
@@ -69,6 +71,16 @@ class BertPretrainDataLoader(data_loader.DataLoader):
'masked_lm_weights'
:
'masked_lm_weights'
:
tf
.
io
.
FixedLenFeature
([
self
.
_max_predictions_per_seq
],
tf
.
float32
),
tf
.
io
.
FixedLenFeature
([
self
.
_max_predictions_per_seq
],
tf
.
float32
),
}
}
if
self
.
_params
.
use_v2_feature_names
:
name_to_features
.
update
({
'input_word_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_type_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
})
else
:
name_to_features
.
update
({
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
})
if
self
.
_use_next_sentence_label
:
if
self
.
_use_next_sentence_label
:
name_to_features
[
'next_sentence_labels'
]
=
tf
.
io
.
FixedLenFeature
([
1
],
name_to_features
[
'next_sentence_labels'
]
=
tf
.
io
.
FixedLenFeature
([
1
],
tf
.
int64
)
tf
.
int64
)
...
@@ -91,13 +103,17 @@ class BertPretrainDataLoader(data_loader.DataLoader):
...
@@ -91,13 +103,17 @@ class BertPretrainDataLoader(data_loader.DataLoader):
def
_parse
(
self
,
record
:
Mapping
[
str
,
tf
.
Tensor
]):
def
_parse
(
self
,
record
:
Mapping
[
str
,
tf
.
Tensor
]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x
=
{
x
=
{
'input_word_ids'
:
record
[
'input_ids'
],
'input_mask'
:
record
[
'input_mask'
],
'input_mask'
:
record
[
'input_mask'
],
'input_type_ids'
:
record
[
'segment_ids'
],
'masked_lm_positions'
:
record
[
'masked_lm_positions'
],
'masked_lm_positions'
:
record
[
'masked_lm_positions'
],
'masked_lm_ids'
:
record
[
'masked_lm_ids'
],
'masked_lm_ids'
:
record
[
'masked_lm_ids'
],
'masked_lm_weights'
:
record
[
'masked_lm_weights'
],
'masked_lm_weights'
:
record
[
'masked_lm_weights'
],
}
}
if
self
.
_params
.
use_v2_feature_names
:
x
[
'input_word_ids'
]
=
record
[
'input_word_ids'
]
x
[
'input_type_ids'
]
=
record
[
'input_type_ids'
]
else
:
x
[
'input_word_ids'
]
=
record
[
'input_ids'
]
x
[
'input_type_ids'
]
=
record
[
'segment_ids'
]
if
self
.
_use_next_sentence_label
:
if
self
.
_use_next_sentence_label
:
x
[
'next_sentence_labels'
]
=
record
[
'next_sentence_labels'
]
x
[
'next_sentence_labels'
]
=
record
[
'next_sentence_labels'
]
if
self
.
_use_position_id
:
if
self
.
_use_position_id
:
...
...
official/nlp/data/pretrain_dataloader_test.py
View file @
e10b29f4
...
@@ -24,8 +24,12 @@ import tensorflow as tf
...
@@ -24,8 +24,12 @@ import tensorflow as tf
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.data
import
pretrain_dataloader
def
_create_fake_dataset
(
output_path
,
seq_length
,
max_predictions_per_seq
,
def
_create_fake_dataset
(
output_path
,
use_position_id
,
use_next_sentence_label
):
seq_length
,
max_predictions_per_seq
,
use_position_id
,
use_next_sentence_label
,
use_v2_feature_names
=
False
):
"""Creates a fake dataset."""
"""Creates a fake dataset."""
writer
=
tf
.
io
.
TFRecordWriter
(
output_path
)
writer
=
tf
.
io
.
TFRecordWriter
(
output_path
)
...
@@ -40,8 +44,12 @@ def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq,
...
@@ -40,8 +44,12 @@ def _create_fake_dataset(output_path, seq_length, max_predictions_per_seq,
for
_
in
range
(
100
):
for
_
in
range
(
100
):
features
=
{}
features
=
{}
input_ids
=
np
.
random
.
randint
(
100
,
size
=
(
seq_length
))
input_ids
=
np
.
random
.
randint
(
100
,
size
=
(
seq_length
))
features
[
"input_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"input_mask"
]
=
create_int_feature
(
np
.
ones_like
(
input_ids
))
features
[
"input_mask"
]
=
create_int_feature
(
np
.
ones_like
(
input_ids
))
if
use_v2_feature_names
:
features
[
"input_word_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"input_type_ids"
]
=
create_int_feature
(
np
.
ones_like
(
input_ids
))
else
:
features
[
"input_ids"
]
=
create_int_feature
(
input_ids
)
features
[
"segment_ids"
]
=
create_int_feature
(
np
.
ones_like
(
input_ids
))
features
[
"segment_ids"
]
=
create_int_feature
(
np
.
ones_like
(
input_ids
))
features
[
"masked_lm_positions"
]
=
create_int_feature
(
features
[
"masked_lm_positions"
]
=
create_int_feature
(
...
@@ -102,6 +110,36 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -102,6 +110,36 @@ class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
use_next_sentence_label
)
use_next_sentence_label
)
self
.
assertEqual
(
"position_ids"
in
features
,
use_position_id
)
self
.
assertEqual
(
"position_ids"
in
features
,
use_position_id
)
def
test_v2_feature_names
(
self
):
train_data_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"train.tf_record"
)
seq_length
=
128
max_predictions_per_seq
=
20
_create_fake_dataset
(
train_data_path
,
seq_length
,
max_predictions_per_seq
,
use_next_sentence_label
=
True
,
use_position_id
=
False
,
use_v2_feature_names
=
True
)
data_config
=
pretrain_dataloader
.
BertPretrainDataConfig
(
input_path
=
train_data_path
,
max_predictions_per_seq
=
max_predictions_per_seq
,
seq_length
=
seq_length
,
global_batch_size
=
10
,
is_training
=
True
,
use_next_sentence_label
=
True
,
use_position_id
=
False
,
use_v2_feature_names
=
True
)
dataset
=
pretrain_dataloader
.
BertPretrainDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
self
.
assertIn
(
"input_word_ids"
,
features
)
self
.
assertIn
(
"input_mask"
,
features
)
self
.
assertIn
(
"input_type_ids"
,
features
)
self
.
assertIn
(
"masked_lm_positions"
,
features
)
self
.
assertIn
(
"masked_lm_ids"
,
features
)
self
.
assertIn
(
"masked_lm_weights"
,
features
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment