Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
28f6bdc0
Commit
28f6bdc0
authored
Jun 08, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Jun 08, 2021
Browse files
Add outputs_as_dict to SentencePredictionDataLoader.
PiperOrigin-RevId: 378190887
parent
73c04752
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
0 deletions
+27
-0
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+5
-0
official/nlp/data/sentence_prediction_dataloader_test.py
official/nlp/data/sentence_prediction_dataloader_test.py
+22
-0
No files found.
official/nlp/data/sentence_prediction_dataloader.py
View file @
28f6bdc0
...
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
...
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
label_type
:
str
=
'int'
label_type
:
str
=
'int'
# Whether to include the example id number.
# Whether to include the example id number.
include_example_id
:
bool
=
False
include_example_id
:
bool
=
False
outputs_as_dict
:
bool
=
False
@
data_loader_factory
.
register_data_loader_cls
(
SentencePredictionDataConfig
)
@
data_loader_factory
.
register_data_loader_cls
(
SentencePredictionDataConfig
)
...
@@ -85,6 +86,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
...
@@ -85,6 +86,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if
self
.
_include_example_id
:
if
self
.
_include_example_id
:
x
[
'example_id'
]
=
record
[
'example_id'
]
x
[
'example_id'
]
=
record
[
'example_id'
]
if
self
.
_params
.
outputs_as_dict
:
x
[
'next_sentence_labels'
]
=
record
[
'label_ids'
]
return
x
y
=
record
[
'label_ids'
]
y
=
record
[
'label_ids'
]
return
(
x
,
y
)
return
(
x
,
y
)
...
...
official/nlp/data/sentence_prediction_dataloader_test.py
View file @
28f6bdc0
...
@@ -141,6 +141,28 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -141,6 +141,28 @@ class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertEqual
(
labels
.
shape
,
(
batch_size
,))
self
.
assertEqual
(
labels
.
shape
,
(
batch_size
,))
self
.
assertEqual
(
labels
.
dtype
,
expected_label_type
)
self
.
assertEqual
(
labels
.
dtype
,
expected_label_type
)
def test_load_dataset_as_dict(self):
  """Checks the loader output when `outputs_as_dict=True`.

  Writes a fake preprocessed record file with integer labels, loads it
  through `SentencePredictionDataLoader`, and verifies that the first batch
  is a single feature dict (not a `(features, labels)` tuple) whose labels
  are stored under the `next_sentence_labels` key.
  """
  seq_length = 128
  batch_size = 10
  input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
  _create_fake_preprocessed_dataset(input_path, seq_length, 'int')

  data_config = loader.SentencePredictionDataConfig(
      input_path=input_path,
      seq_length=seq_length,
      global_batch_size=batch_size,
      label_type='int',
      outputs_as_dict=True)
  data_loader = loader.SentencePredictionDataLoader(data_config)
  dataset = data_loader.load()
  features = next(iter(dataset))

  # The dict must hold exactly the three model inputs plus the labels.
  expected_keys = [
      'input_word_ids', 'input_mask', 'input_type_ids', 'next_sentence_labels'
  ]
  self.assertCountEqual(expected_keys, features.keys())

  # Every sequence feature is batched to (batch_size, seq_length).
  for feature_name in ('input_word_ids', 'input_mask', 'input_type_ids'):
    self.assertEqual(features[feature_name].shape, (batch_size, seq_length))

  # Labels are one scalar per example and are expected to be int32.
  labels = features['next_sentence_labels']
  self.assertEqual(labels.shape, (batch_size,))
  self.assertEqual(labels.dtype, tf.int32)
class
SentencePredictionTfdsDataLoaderTest
(
tf
.
test
.
TestCase
,
class
SentencePredictionTfdsDataLoaderTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
parameterized
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment