Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1fd7aaaf
Commit
1fd7aaaf
authored
May 02, 2022
by
Yeqing Li
Committed by
A. Unique TensorFlower
May 02, 2022
Browse files
Makes the label field name **configurable** from tf.SequenceExample.
PiperOrigin-RevId: 446008537
parent
43d232e5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
3 deletions
+20
-3
official/projects/yt8m/configs/yt8m.py
official/projects/yt8m/configs/yt8m.py
+2
-0
official/projects/yt8m/dataloaders/utils.py
official/projects/yt8m/dataloaders/utils.py
+3
-1
official/projects/yt8m/dataloaders/yt8m_input.py
official/projects/yt8m/dataloaders/yt8m_input.py
+5
-1
official/projects/yt8m/dataloaders/yt8m_input_test.py
official/projects/yt8m/dataloaders/yt8m_input_test.py
+10
-1
No files found.
official/projects/yt8m/configs/yt8m.py
View file @
1fd7aaaf
...
...
@@ -45,6 +45,7 @@ class DataConfig(cfg.DataConfig):
feature_sources: if the feature from 'context' or 'features'.
feature_dtypes: dtype of decoded feature.
feature_from_bytes: decode feature from bytes or as dtype list.
label_fields: name of field to read from tf.SequenceExample.
segment_size: Number of frames in each segment.
segment_labels: Use segment level label. Default: False, video level label.
include_video_id: `True` means include video id (string) in the input to
...
...
@@ -70,6 +71,7 @@ class DataConfig(cfg.DataConfig):
feature_sources
:
Tuple
[
str
,
...]
=
(
'feature'
,
'feature'
)
feature_dtypes
:
Tuple
[
str
,
...]
=
(
'uint8'
,
'uint8'
)
feature_from_bytes
:
Tuple
[
bool
,
...]
=
(
True
,
True
)
label_field
:
str
=
'labels'
segment_size
:
int
=
1
segment_labels
:
bool
=
False
include_video_id
:
bool
=
False
...
...
official/projects/yt8m/dataloaders/utils.py
View file @
1fd7aaaf
...
...
@@ -248,7 +248,9 @@ def MakeExampleWithFloatFeatures(
seq_example
=
tf
.
train
.
SequenceExample
()
seq_example
.
context
.
feature
[
"id"
].
bytes_list
.
value
[:]
=
[
b
"id001"
]
seq_example
.
context
.
feature
[
"labels"
].
int64_list
.
value
[:]
=
[
1
,
2
,
3
,
4
]
seq_example
.
context
.
feature
[
"clip/label/index"
].
int64_list
.
value
[:]
=
[
1
,
2
,
3
,
4
]
seq_example
.
context
.
feature
[
"segment_labels"
].
int64_list
.
value
[:]
=
(
[
4
]
*
num_segment
)
seq_example
.
context
.
feature
[
"segment_start_times"
].
int64_list
.
value
[:]
=
[
...
...
official/projects/yt8m/dataloaders/yt8m_input.py
View file @
1fd7aaaf
...
...
@@ -251,6 +251,7 @@ class Decoder(decoder.Decoder):
self
.
_feature_dtypes
=
input_params
.
feature_dtypes
self
.
_feature_from_bytes
=
input_params
.
feature_from_bytes
self
.
_include_video_id
=
input_params
.
include_video_id
self
.
_label_field
=
input_params
.
label_field
assert
len
(
self
.
_feature_names
)
==
len
(
self
.
_feature_sources
),
(
"length of feature_names (={}) != length of feature_sizes (={})"
.
format
(
...
...
@@ -270,7 +271,8 @@ class Decoder(decoder.Decoder):
"segment_scores"
:
tf
.
io
.
VarLenFeature
(
tf
.
float32
)
})
else
:
self
.
_context_features
.
update
({
"labels"
:
tf
.
io
.
VarLenFeature
(
tf
.
int64
)})
self
.
_context_features
.
update
(
{
self
.
_label_field
:
tf
.
io
.
VarLenFeature
(
tf
.
int64
)})
for
i
,
name
in
enumerate
(
self
.
_feature_names
):
if
self
.
_feature_from_bytes
[
i
]:
...
...
@@ -308,6 +310,8 @@ class Decoder(decoder.Decoder):
else
:
if
isinstance
(
decoded_tensor
[
name
],
tf
.
SparseTensor
):
decoded_tensor
[
name
]
=
tf
.
sparse
.
to_dense
(
decoded_tensor
[
name
])
if
not
self
.
_segment_labels
:
decoded_tensor
[
"labels"
]
=
decoded_tensor
[
self
.
_label_field
]
return
decoded_tensor
...
...
official/projects/yt8m/dataloaders/yt8m_input_test.py
View file @
1fd7aaaf
...
...
@@ -16,6 +16,7 @@ import os
from
absl
import
logging
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.core
import
input_reader
...
...
@@ -161,17 +162,25 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
else
:
self
.
assertCountEqual
([
'video_matrix'
,
'labels'
,
'num_frames'
],
example
.
keys
())
batch_size
=
params
.
global_batch_size
# Check tensor values.
expected_context
=
examples
[
0
].
context
.
feature
[
'VIDEO_EMBEDDING/context_feature/floats'
].
float_list
.
value
expected_feature
=
examples
[
0
].
feature_lists
.
feature_list
[
'FEATURE/feature/floats'
].
feature
[
0
].
float_list
.
value
expected_labels
=
examples
[
0
].
context
.
feature
[
params
.
label_field
].
int64_list
.
value
self
.
assertAllEqual
(
expected_feature
,
example
[
'video_matrix'
][
0
,
0
,
params
.
feature_sizes
[
0
]:])
self
.
assertAllEqual
(
expected_context
,
example
[
'video_matrix'
][
0
,
0
,
:
params
.
feature_sizes
[
0
]])
self
.
assertAllEqual
(
np
.
nonzero
(
example
[
'labels'
][
0
,
:].
numpy
())[
0
],
expected_labels
)
# Check tensor shape.
batch_size
=
params
.
global_batch_size
self
.
assertEqual
(
example
[
'video_matrix'
].
shape
.
as_list
(),
[
batch_size
,
params
.
max_frames
,
sum
(
params
.
feature_sizes
)])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment