Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
13100073
Commit
13100073
authored
Jul 01, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 382655843
parent
00024735
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
13 deletions
+20
-13
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+20
-13
No files found.
official/nlp/data/sentence_prediction_dataloader.py
View file @
13100073
...
...
@@ -60,8 +60,8 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
else
:
self
.
_label_name_mapping
=
dict
()
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""De
codes a serialized tf.Example
."""
def
name_to_features_spec
(
self
):
"""De
fines features to decode. Subclass may override to append features
."""
label_type
=
LABEL_TYPES_MAP
[
self
.
_params
.
label_type
]
name_to_features
=
{
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
...
...
@@ -72,7 +72,11 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if
self
.
_include_example_id
:
name_to_features
[
'example_id'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
return
name_to_features
def
_decode
(
self
,
record
:
tf
.
Tensor
):
"""Decodes a serialized tf.Example."""
example
=
tf
.
io
.
parse_single_example
(
record
,
self
.
name_to_features_spec
())
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
...
...
@@ -86,20 +90,23 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
def
_parse
(
self
,
record
:
Mapping
[
str
,
tf
.
Tensor
]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x
=
{
'input_
word_
ids'
:
record
[
'input_ids'
]
,
'input_mask'
:
record
[
'input_mask'
]
,
'
input_type_ids'
:
record
[
'segment
_ids'
]
key_mapping
=
{
'input_ids'
:
'input_
word_
ids'
,
'input_mask'
:
'input_mask'
,
'
segment_ids'
:
'input_type
_ids'
}
if
self
.
_include_example_id
:
x
[
'example_id'
]
=
record
[
'example_id'
]
x
[
self
.
_label_field
]
=
record
[
self
.
_label_field
]
ret
=
{}
for
record_key
in
record
:
if
record_key
in
key_mapping
:
ret
[
key_mapping
[
record_key
]]
=
record
[
record_key
]
else
:
ret
[
record_key
]
=
record
[
record_key
]
if
self
.
_label_field
in
self
.
_label_name_mapping
:
x
[
self
.
_label_name_mapping
[
self
.
_label_field
]]
=
record
[
self
.
_label_field
]
ret
[
self
.
_label_name_mapping
[
self
.
_label_field
]]
=
record
[
self
.
_label_field
]
return
x
return
ret
def
load
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
"""Returns a tf.dataset.Dataset."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment