Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0e74158f
"vscode:/vscode.git/clone" did not exist on "5319098e2a2357dca2b144dfc005df234cb7ca79"
Commit
0e74158f
authored
Jul 10, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 384018258
parent
5f23689e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
15 deletions
+26
-15
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+8
-6
official/nlp/data/sentence_prediction_dataloader_test.py
official/nlp/data/sentence_prediction_dataloader_test.py
+18
-9
No files found.
official/nlp/data/sentence_prediction_dataloader.py
View file @
0e74158f
...
...
@@ -222,13 +222,12 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
"""Berts preprocess."""
segments
=
[
record
[
x
]
for
x
in
self
.
_text_fields
]
model_inputs
=
self
.
_text_processor
(
segments
)
if
self
.
_include_example_i
d
:
model_inputs
[
'example_id'
]
=
record
[
'example_id'
]
model_inputs
[
self
.
_label_field
]
=
record
[
self
.
_label_field
]
for
key
in
recor
d
:
if
key
not
in
self
.
_text_fields
:
model_inputs
[
key
]
=
record
[
key
]
return
model_inputs
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
def
name_to_features_spec
(
self
):
name_to_features
=
{}
for
text_field
in
self
.
_text_fields
:
name_to_features
[
text_field
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
string
)
...
...
@@ -237,8 +236,11 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
name_to_features
[
self
.
_label_field
]
=
tf
.
io
.
FixedLenFeature
([],
label_type
)
if
self
.
_include_example_id
:
name_to_features
[
'example_id'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
return
name_to_features
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
example
=
tf
.
io
.
parse_single_example
(
record
,
self
.
name_to_features_spec
())
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for
name
in
example
:
...
...
official/nlp/data/sentence_prediction_dataloader_test.py
View file @
0e74158f
...
...
@@ -198,9 +198,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
],
features
.
keys
())
expected_keys
=
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
]
if
use_tfds
:
expected_keys
+=
[
'idx'
]
self
.
assertCountEqual
(
expected_keys
,
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
...
...
@@ -233,9 +236,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
],
features
.
keys
())
expected_keys
=
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
]
if
use_tfds
:
expected_keys
+=
[
'idx'
]
self
.
assertCountEqual
(
expected_keys
,
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
...
...
@@ -268,9 +274,12 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
],
features
.
keys
())
expected_keys
=
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_field
]
if
use_tfds
:
expected_keys
+=
[
'idx'
]
self
.
assertCountEqual
(
expected_keys
,
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment