Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b152ed9c
Commit
b152ed9c
authored
Aug 15, 2020
by
Kaushik Shivakumar
Browse files
fix
parent
2bd53cf1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
187 additions
and
17 deletions
+187
-17
research/object_detection/dataset_tools/create_ava_actions_tf_record.py
...t_detection/dataset_tools/create_ava_actions_tf_record.py
+186
-16
research/object_detection/utils/label_map_util.py
research/object_detection/utils/label_map_util.py
+1
-1
No files found.
research/object_detection/dataset_tools/create_ava_actions_tf_record.py
View file @
b152ed9c
...
...
@@ -60,6 +60,7 @@ import sys
import
zipfile
import
collections
import
glob
import
hashlib
from
absl
import
app
from
absl
import
flags
...
...
@@ -69,8 +70,10 @@ from six.moves import urllib
import
tensorflow.compat.v1
as
tf
import
cv2
from
object_detection.utils
import
dataset_util
from
object_detection.dataset_tools
import
seq_example_util
from
object_detection.protos
import
string_int_label_map_pb2
from
object_detection.utils
import
dataset_util
from
object_detection.utils
import
label_map_util
POSSIBLE_TIMESTAMPS
=
range
(
902
,
1798
)
ANNOTATION_URL
=
"https://research.google.com/ava/download/ava_v2.2.zip"
...
...
@@ -116,7 +119,8 @@ class Ava(object):
splits_to_process
=
"train,val,test"
,
video_path_format_string
=
None
,
seconds_per_sequence
=
10
,
hop_between_sequences
=
10
):
hop_between_sequences
=
10
,
examples_for_context
=
False
):
"""Downloads data and generates sharded TFRecords.
Downloads the data files, generates metadata, and processes the metadata
...
...
@@ -133,11 +137,15 @@ class Ava(object):
hop_between_sequences: The gap between the centers of
successive sequences.
"""
example_function
=
self
.
_generate_sequence_examples
if
examples_for_context
:
example_function
=
self
.
_generate_examples
logging
.
info
(
"Downloading data."
)
download_output
=
self
.
_download_data
()
for
key
in
splits_to_process
.
split
(
","
):
logging
.
info
(
"Generating examples for split: %s"
,
key
)
all_metadata
=
list
(
self
.
_generate_examples
(
all_metadata
=
list
(
example_function
(
download_output
[
0
][
key
][
0
],
download_output
[
0
][
key
][
1
],
download_output
[
1
],
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
))
...
...
@@ -155,7 +163,7 @@ class Ava(object):
writers
[
i
%
len
(
writers
)].
write
(
seq_ex
.
SerializeToString
())
logging
.
info
(
"Data extraction complete."
)
def
_generate_examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
def
_generate_
sequence_
examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
):
"""For each row in the annotation CSV, generates the corresponding examples.
...
...
@@ -275,6 +283,154 @@ class Ava(object):
cur_vid
.
release
()
def
_generate_examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
):
"""For each row in the annotation CSV, generates the corresponding
examples. When iterating through frames for a single example, skips
over excluded frames. Generates equal-length sequence examples, each with
length seconds_per_sequence (1 fps) and gaps of hop_between_sequences
frames (and seconds) between them, possible greater due to excluded frames.
Args:
annotation_file: path to the file of AVA CSV annotations.
excluded_path: path to a CSV file of excluded timestamps for each video.
label_map: an {int: string} label map.
seconds_per_sequence: The number of seconds per example in each example.
hop_between_sequences: The hop between sequences. If less than
seconds_per_sequence, will overlap.
Yields:
Each prepared tf.Example of metadata also containing video frames
"""
fieldnames
=
[
"id"
,
"timestamp_seconds"
,
"xmin"
,
"ymin"
,
"xmax"
,
"ymax"
,
"action_label"
]
frame_excluded
=
{}
# create a sparse, nested map of videos and frame indices.
with
open
(
excluded_file
,
"r"
)
as
excluded
:
reader
=
csv
.
reader
(
excluded
)
for
row
in
reader
:
frame_excluded
[(
row
[
0
],
int
(
float
(
row
[
1
])))]
=
True
with
open
(
annotation_file
,
"r"
)
as
annotations
:
reader
=
csv
.
DictReader
(
annotations
,
fieldnames
)
frame_annotations
=
collections
.
defaultdict
(
list
)
ids
=
set
()
# aggreggate by video and timestamp:
for
row
in
reader
:
ids
.
add
(
row
[
"id"
])
key
=
(
row
[
"id"
],
int
(
float
(
row
[
"timestamp_seconds"
])))
frame_annotations
[
key
].
append
(
row
)
# for each video, find aggreggates near each sampled frame.:
logging
.
info
(
"Generating metadata..."
)
media_num
=
1
for
media_id
in
ids
:
logging
.
info
(
"%d/%d, ignore warnings.
\n
"
%
(
media_num
,
len
(
ids
)))
media_num
+=
1
filepath
=
glob
.
glob
(
video_path_format_string
.
format
(
media_id
)
+
"*"
)[
0
]
filename
=
filepath
.
split
(
"/"
)[
-
1
]
cur_vid
=
cv2
.
VideoCapture
(
filepath
)
width
=
cur_vid
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
)
height
=
cur_vid
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
)
middle_frame_time
=
POSSIBLE_TIMESTAMPS
[
0
]
total_non_excluded
=
0
;
while
middle_frame_time
<
POSSIBLE_TIMESTAMPS
[
-
1
]:
if
(
media_id
,
middle_frame_time
)
not
in
frame_excluded
:
total_non_excluded
+=
1
middle_frame_time
+=
1
middle_frame_time
=
POSSIBLE_TIMESTAMPS
[
0
]
cur_frame_num
=
0
while
middle_frame_time
<
POSSIBLE_TIMESTAMPS
[
-
1
]:
cur_vid
.
set
(
cv2
.
CAP_PROP_POS_MSEC
,
(
middle_frame_time
)
*
SECONDS_TO_MILLI
)
success
,
image
=
cur_vid
.
read
()
success
,
buffer
=
cv2
.
imencode
(
'.jpg'
,
image
)
bufstring
=
buffer
.
tostring
()
if
(
media_id
,
middle_frame_time
)
in
frame_excluded
:
middle_frame_time
+=
1
logging
.
info
(
"Ignoring and skipping excluded frame."
)
continue
cur_frame_num
+=
1
source_id
=
str
(
middle_frame_time
)
+
"_"
+
media_id
xmins
=
[]
xmaxs
=
[]
ymins
=
[]
ymaxs
=
[]
areas
=
[]
labels
=
[]
label_strings
=
[]
confidences
=
[]
for
row
in
frame_annotations
[(
media_id
,
middle_frame_time
)]:
if
len
(
row
)
>
2
and
int
(
row
[
"action_label"
])
in
label_map
:
xmins
.
append
(
float
(
row
[
"xmin"
]))
xmaxs
.
append
(
float
(
row
[
"xmax"
]))
ymins
.
append
(
float
(
row
[
"ymin"
]))
ymaxs
.
append
(
float
(
row
[
"ymax"
]))
areas
.
append
(
float
((
xmaxs
[
-
1
]
-
xmins
[
-
1
])
*
(
ymaxs
[
-
1
]
-
ymins
[
-
1
]))
/
2
)
labels
.
append
(
int
(
row
[
"action_label"
]))
label_strings
.
append
(
label_map
[
int
(
row
[
"action_label"
])])
confidences
.
append
(
1
)
else
:
logging
.
warning
(
"Unknown label: %s"
,
row
[
"action_label"
])
middle_frame_time
+=
1
/
3
if
abs
(
middle_frame_time
-
round
(
middle_frame_time
)
<
0.0001
):
middle_frame_time
=
round
(
middle_frame_time
)
key
=
hashlib
.
sha256
(
bufstring
).
hexdigest
()
date_captured_feature
=
(
"2020-06-17 00:%02d:%02d"
%
((
middle_frame_time
-
900
)
*
3
//
60
,
(
middle_frame_time
-
900
)
*
3
%
60
))
context_feature_dict
=
{
'image/height'
:
dataset_util
.
int64_feature
(
int
(
height
)),
'image/width'
:
dataset_util
.
int64_feature
(
int
(
width
)),
'image/format'
:
dataset_util
.
bytes_feature
(
'jpeg'
.
encode
(
'utf8'
)),
'image/source_id'
:
dataset_util
.
bytes_feature
(
source_id
.
encode
(
"utf8"
)),
'image/filename'
:
dataset_util
.
bytes_feature
(
source_id
.
encode
(
"utf8"
)),
'image/encoded'
:
dataset_util
.
bytes_feature
(
bufstring
),
'image/key/sha256'
:
dataset_util
.
bytes_feature
(
key
.
encode
(
'utf8'
)),
'image/object/bbox/xmin'
:
dataset_util
.
float_list_feature
(
xmins
),
'image/object/bbox/xmax'
:
dataset_util
.
float_list_feature
(
xmaxs
),
'image/object/bbox/ymin'
:
dataset_util
.
float_list_feature
(
ymins
),
'image/object/bbox/ymax'
:
dataset_util
.
float_list_feature
(
ymaxs
),
'image/object/area'
:
dataset_util
.
float_list_feature
(
areas
),
'image/object/class/label'
:
dataset_util
.
int64_list_feature
(
labels
),
'image/object/class/text'
:
dataset_util
.
bytes_list_feature
(
label_strings
),
'image/location'
:
dataset_util
.
bytes_feature
(
media_id
.
encode
(
'utf8'
)),
'image/date_captured'
:
dataset_util
.
bytes_feature
(
date_captured_feature
.
encode
(
'utf8'
)),
'image/seq_num_frames'
:
dataset_util
.
int64_feature
(
total_non_excluded
),
'image/seq_frame_num'
:
dataset_util
.
int64_feature
(
cur_frame_num
),
'image/seq_id'
:
dataset_util
.
bytes_feature
(
media_id
.
encode
(
'utf8'
)),
}
yield
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
context_feature_dict
))
cur_vid
.
release
()
def
_download_data
(
self
):
"""Downloads and extracts data if not already available."""
if
sys
.
version_info
>=
(
3
,
0
):
...
...
@@ -300,14 +456,27 @@ class Ava(object):
SPLITS
[
split
][
"excluded-csv"
]
=
excluded_csv_path
paths
[
split
]
=
(
csv_path
,
excluded_csv_path
)
label_map
=
self
.
get_label_map
(
os
.
path
.
join
(
self
.
path_to_data_download
,
"ava_action_list_v2.2.pbtxt"
))
#label_map = self.get_label_map(os.path.join(self.path_to_data_download, "ava_action_list_v2.2.pbtxt"))
#
label_map
=
self
.
get_label_map
(
"object_detection/data/mscoco_label_map.pbtxt"
)
return
paths
,
label_map
def
get_label_map
(
self
,
path
):
"""Parses
s
a label map into {integer:string} format."""
"""Parses a label map into {integer:string} format."""
label_map
=
{}
with
open
(
path
,
"r"
)
as
f
:
label_map
=
label_map_util
.
load_labelmap
(
path
)
print
(
label_map
)
label_map_dict
=
{}
for
item
in
label_map
.
item
:
label_map_dict
[
item
.
name
]
=
item
.
label_id
with
open
(
path
,
"rb"
)
as
f
:
#label_map_util.load_labelmap()
#label_map_str = f.read()
#print(str(label_map_str))
#label_map = string_int_label_map_pb2.StringIntLabelMap()
#label_map.ParseFromString(label_map_str)
pass
"""
current_id = -1
current_label = ""
for line in f:
...
...
@@ -322,16 +491,12 @@ class Ava(object):
if "id:" in line:
current_id = int(line.split()[1])
if "}" in line:
label_map
[
current_id
]
=
bytes23
(
current_label
)
logging
.
info
(
label_map
)
label_map[current_id] = bytes(current_label, "utf8")"""
print
(
'label map dict'
)
logging
.
info
(
label_map_dict
)
assert
len
(
label_map
)
==
NUM_CLASSES
return
label_map
def
bytes23
(
string
):
"""Creates a bytes string in either Python 2 or 3."""
if
sys
.
version_info
>=
(
3
,
0
):
return
bytes
(
string
,
"utf8"
)
return
bytes
(
string
)
@
contextlib
.
contextmanager
def
_close_on_exit
(
writers
):
...
...
@@ -350,7 +515,8 @@ def main(argv):
flags
.
FLAGS
.
splits_to_process
,
flags
.
FLAGS
.
video_path_format_string
,
flags
.
FLAGS
.
seconds_per_sequence
,
flags
.
FLAGS
.
hop_between_sequences
)
flags
.
FLAGS
.
hop_between_sequences
,
flags
.
FLAGS
.
examples_for_context
)
if
__name__
==
"__main__"
:
flags
.
DEFINE_string
(
"path_to_download_data"
,
...
...
@@ -375,4 +541,8 @@ if __name__ == "__main__":
10
,
"The hop between sequences. If less than "
"seconds_per_sequence, will overlap."
)
flags
.
DEFINE_boolean
(
"examples_for_context"
,
False
,
"Whether to generate examples instead of sequence examples. "
"If true, will generate tf.Example objects for use in Context R-CNN."
)
app
.
run
(
main
)
research/object_detection/utils/label_map_util.py
View file @
b152ed9c
...
...
@@ -152,7 +152,7 @@ def load_labelmap(path):
Returns:
a StringIntLabelMapProto
"""
with
tf
.
io
.
gfile
.
GFile
(
path
,
'r'
)
as
fid
:
with
tf
.
io
.
gfile
.
GFile
(
path
,
'r
b
'
)
as
fid
:
label_map_string
=
fid
.
read
()
label_map
=
string_int_label_map_pb2
.
StringIntLabelMap
()
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment