Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b152ed9c
Commit
b152ed9c
authored
Aug 15, 2020
by
Kaushik Shivakumar
Browse files
fix
parent
2bd53cf1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
187 additions
and
17 deletions
+187
-17
research/object_detection/dataset_tools/create_ava_actions_tf_record.py
...t_detection/dataset_tools/create_ava_actions_tf_record.py
+186
-16
research/object_detection/utils/label_map_util.py
research/object_detection/utils/label_map_util.py
+1
-1
No files found.
research/object_detection/dataset_tools/create_ava_actions_tf_record.py
View file @
b152ed9c
...
@@ -60,6 +60,7 @@ import sys
...
@@ -60,6 +60,7 @@ import sys
import
zipfile
import
zipfile
import
collections
import
collections
import
glob
import
glob
import
hashlib
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
...
@@ -69,8 +70,10 @@ from six.moves import urllib
...
@@ -69,8 +70,10 @@ from six.moves import urllib
import
tensorflow.compat.v1
as
tf
import
tensorflow.compat.v1
as
tf
import
cv2
import
cv2
from
object_detection.utils
import
dataset_util
from
object_detection.dataset_tools
import
seq_example_util
from
object_detection.dataset_tools
import
seq_example_util
from
object_detection.protos
import
string_int_label_map_pb2
from
object_detection.utils
import
dataset_util
from
object_detection.utils
import
label_map_util
POSSIBLE_TIMESTAMPS
=
range
(
902
,
1798
)
POSSIBLE_TIMESTAMPS
=
range
(
902
,
1798
)
ANNOTATION_URL
=
"https://research.google.com/ava/download/ava_v2.2.zip"
ANNOTATION_URL
=
"https://research.google.com/ava/download/ava_v2.2.zip"
...
@@ -116,7 +119,8 @@ class Ava(object):
...
@@ -116,7 +119,8 @@ class Ava(object):
splits_to_process
=
"train,val,test"
,
splits_to_process
=
"train,val,test"
,
video_path_format_string
=
None
,
video_path_format_string
=
None
,
seconds_per_sequence
=
10
,
seconds_per_sequence
=
10
,
hop_between_sequences
=
10
):
hop_between_sequences
=
10
,
examples_for_context
=
False
):
"""Downloads data and generates sharded TFRecords.
"""Downloads data and generates sharded TFRecords.
Downloads the data files, generates metadata, and processes the metadata
Downloads the data files, generates metadata, and processes the metadata
...
@@ -133,11 +137,15 @@ class Ava(object):
...
@@ -133,11 +137,15 @@ class Ava(object):
hop_between_sequences: The gap between the centers of
hop_between_sequences: The gap between the centers of
successive sequences.
successive sequences.
"""
"""
example_function
=
self
.
_generate_sequence_examples
if
examples_for_context
:
example_function
=
self
.
_generate_examples
logging
.
info
(
"Downloading data."
)
logging
.
info
(
"Downloading data."
)
download_output
=
self
.
_download_data
()
download_output
=
self
.
_download_data
()
for
key
in
splits_to_process
.
split
(
","
):
for
key
in
splits_to_process
.
split
(
","
):
logging
.
info
(
"Generating examples for split: %s"
,
key
)
logging
.
info
(
"Generating examples for split: %s"
,
key
)
all_metadata
=
list
(
self
.
_generate_examples
(
all_metadata
=
list
(
example_function
(
download_output
[
0
][
key
][
0
],
download_output
[
0
][
key
][
1
],
download_output
[
0
][
key
][
0
],
download_output
[
0
][
key
][
1
],
download_output
[
1
],
seconds_per_sequence
,
hop_between_sequences
,
download_output
[
1
],
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
))
video_path_format_string
))
...
@@ -155,7 +163,7 @@ class Ava(object):
...
@@ -155,7 +163,7 @@ class Ava(object):
writers
[
i
%
len
(
writers
)].
write
(
seq_ex
.
SerializeToString
())
writers
[
i
%
len
(
writers
)].
write
(
seq_ex
.
SerializeToString
())
logging
.
info
(
"Data extraction complete."
)
logging
.
info
(
"Data extraction complete."
)
def
_generate_examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
def
_generate_
sequence_
examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
seconds_per_sequence
,
hop_between_sequences
,
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
):
video_path_format_string
):
"""For each row in the annotation CSV, generates the corresponding examples.
"""For each row in the annotation CSV, generates the corresponding examples.
...
@@ -275,6 +283,154 @@ class Ava(object):
...
@@ -275,6 +283,154 @@ class Ava(object):
cur_vid
.
release
()
cur_vid
.
release
()
def
_generate_examples
(
self
,
annotation_file
,
excluded_file
,
label_map
,
seconds_per_sequence
,
hop_between_sequences
,
video_path_format_string
):
"""For each row in the annotation CSV, generates the corresponding
examples. When iterating through frames for a single example, skips
over excluded frames. Generates equal-length sequence examples, each with
length seconds_per_sequence (1 fps) and gaps of hop_between_sequences
frames (and seconds) between them, possible greater due to excluded frames.
Args:
annotation_file: path to the file of AVA CSV annotations.
excluded_path: path to a CSV file of excluded timestamps for each video.
label_map: an {int: string} label map.
seconds_per_sequence: The number of seconds per example in each example.
hop_between_sequences: The hop between sequences. If less than
seconds_per_sequence, will overlap.
Yields:
Each prepared tf.Example of metadata also containing video frames
"""
fieldnames
=
[
"id"
,
"timestamp_seconds"
,
"xmin"
,
"ymin"
,
"xmax"
,
"ymax"
,
"action_label"
]
frame_excluded
=
{}
# create a sparse, nested map of videos and frame indices.
with
open
(
excluded_file
,
"r"
)
as
excluded
:
reader
=
csv
.
reader
(
excluded
)
for
row
in
reader
:
frame_excluded
[(
row
[
0
],
int
(
float
(
row
[
1
])))]
=
True
with
open
(
annotation_file
,
"r"
)
as
annotations
:
reader
=
csv
.
DictReader
(
annotations
,
fieldnames
)
frame_annotations
=
collections
.
defaultdict
(
list
)
ids
=
set
()
# aggreggate by video and timestamp:
for
row
in
reader
:
ids
.
add
(
row
[
"id"
])
key
=
(
row
[
"id"
],
int
(
float
(
row
[
"timestamp_seconds"
])))
frame_annotations
[
key
].
append
(
row
)
# for each video, find aggreggates near each sampled frame.:
logging
.
info
(
"Generating metadata..."
)
media_num
=
1
for
media_id
in
ids
:
logging
.
info
(
"%d/%d, ignore warnings.
\n
"
%
(
media_num
,
len
(
ids
)))
media_num
+=
1
filepath
=
glob
.
glob
(
video_path_format_string
.
format
(
media_id
)
+
"*"
)[
0
]
filename
=
filepath
.
split
(
"/"
)[
-
1
]
cur_vid
=
cv2
.
VideoCapture
(
filepath
)
width
=
cur_vid
.
get
(
cv2
.
CAP_PROP_FRAME_WIDTH
)
height
=
cur_vid
.
get
(
cv2
.
CAP_PROP_FRAME_HEIGHT
)
middle_frame_time
=
POSSIBLE_TIMESTAMPS
[
0
]
total_non_excluded
=
0
;
while
middle_frame_time
<
POSSIBLE_TIMESTAMPS
[
-
1
]:
if
(
media_id
,
middle_frame_time
)
not
in
frame_excluded
:
total_non_excluded
+=
1
middle_frame_time
+=
1
middle_frame_time
=
POSSIBLE_TIMESTAMPS
[
0
]
cur_frame_num
=
0
while
middle_frame_time
<
POSSIBLE_TIMESTAMPS
[
-
1
]:
cur_vid
.
set
(
cv2
.
CAP_PROP_POS_MSEC
,
(
middle_frame_time
)
*
SECONDS_TO_MILLI
)
success
,
image
=
cur_vid
.
read
()
success
,
buffer
=
cv2
.
imencode
(
'.jpg'
,
image
)
bufstring
=
buffer
.
tostring
()
if
(
media_id
,
middle_frame_time
)
in
frame_excluded
:
middle_frame_time
+=
1
logging
.
info
(
"Ignoring and skipping excluded frame."
)
continue
cur_frame_num
+=
1
source_id
=
str
(
middle_frame_time
)
+
"_"
+
media_id
xmins
=
[]
xmaxs
=
[]
ymins
=
[]
ymaxs
=
[]
areas
=
[]
labels
=
[]
label_strings
=
[]
confidences
=
[]
for
row
in
frame_annotations
[(
media_id
,
middle_frame_time
)]:
if
len
(
row
)
>
2
and
int
(
row
[
"action_label"
])
in
label_map
:
xmins
.
append
(
float
(
row
[
"xmin"
]))
xmaxs
.
append
(
float
(
row
[
"xmax"
]))
ymins
.
append
(
float
(
row
[
"ymin"
]))
ymaxs
.
append
(
float
(
row
[
"ymax"
]))
areas
.
append
(
float
((
xmaxs
[
-
1
]
-
xmins
[
-
1
])
*
(
ymaxs
[
-
1
]
-
ymins
[
-
1
]))
/
2
)
labels
.
append
(
int
(
row
[
"action_label"
]))
label_strings
.
append
(
label_map
[
int
(
row
[
"action_label"
])])
confidences
.
append
(
1
)
else
:
logging
.
warning
(
"Unknown label: %s"
,
row
[
"action_label"
])
middle_frame_time
+=
1
/
3
if
abs
(
middle_frame_time
-
round
(
middle_frame_time
)
<
0.0001
):
middle_frame_time
=
round
(
middle_frame_time
)
key
=
hashlib
.
sha256
(
bufstring
).
hexdigest
()
date_captured_feature
=
(
"2020-06-17 00:%02d:%02d"
%
((
middle_frame_time
-
900
)
*
3
//
60
,
(
middle_frame_time
-
900
)
*
3
%
60
))
context_feature_dict
=
{
'image/height'
:
dataset_util
.
int64_feature
(
int
(
height
)),
'image/width'
:
dataset_util
.
int64_feature
(
int
(
width
)),
'image/format'
:
dataset_util
.
bytes_feature
(
'jpeg'
.
encode
(
'utf8'
)),
'image/source_id'
:
dataset_util
.
bytes_feature
(
source_id
.
encode
(
"utf8"
)),
'image/filename'
:
dataset_util
.
bytes_feature
(
source_id
.
encode
(
"utf8"
)),
'image/encoded'
:
dataset_util
.
bytes_feature
(
bufstring
),
'image/key/sha256'
:
dataset_util
.
bytes_feature
(
key
.
encode
(
'utf8'
)),
'image/object/bbox/xmin'
:
dataset_util
.
float_list_feature
(
xmins
),
'image/object/bbox/xmax'
:
dataset_util
.
float_list_feature
(
xmaxs
),
'image/object/bbox/ymin'
:
dataset_util
.
float_list_feature
(
ymins
),
'image/object/bbox/ymax'
:
dataset_util
.
float_list_feature
(
ymaxs
),
'image/object/area'
:
dataset_util
.
float_list_feature
(
areas
),
'image/object/class/label'
:
dataset_util
.
int64_list_feature
(
labels
),
'image/object/class/text'
:
dataset_util
.
bytes_list_feature
(
label_strings
),
'image/location'
:
dataset_util
.
bytes_feature
(
media_id
.
encode
(
'utf8'
)),
'image/date_captured'
:
dataset_util
.
bytes_feature
(
date_captured_feature
.
encode
(
'utf8'
)),
'image/seq_num_frames'
:
dataset_util
.
int64_feature
(
total_non_excluded
),
'image/seq_frame_num'
:
dataset_util
.
int64_feature
(
cur_frame_num
),
'image/seq_id'
:
dataset_util
.
bytes_feature
(
media_id
.
encode
(
'utf8'
)),
}
yield
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
context_feature_dict
))
cur_vid
.
release
()
def
_download_data
(
self
):
def
_download_data
(
self
):
"""Downloads and extracts data if not already available."""
"""Downloads and extracts data if not already available."""
if
sys
.
version_info
>=
(
3
,
0
):
if
sys
.
version_info
>=
(
3
,
0
):
...
@@ -300,14 +456,27 @@ class Ava(object):
...
@@ -300,14 +456,27 @@ class Ava(object):
SPLITS
[
split
][
"excluded-csv"
]
=
excluded_csv_path
SPLITS
[
split
][
"excluded-csv"
]
=
excluded_csv_path
paths
[
split
]
=
(
csv_path
,
excluded_csv_path
)
paths
[
split
]
=
(
csv_path
,
excluded_csv_path
)
label_map
=
self
.
get_label_map
(
os
.
path
.
join
(
self
.
path_to_data_download
,
#label_map = self.get_label_map(os.path.join(self.path_to_data_download, "ava_action_list_v2.2.pbtxt"))
"ava_action_list_v2.2.pbtxt"
))
#
label_map
=
self
.
get_label_map
(
"object_detection/data/mscoco_label_map.pbtxt"
)
return
paths
,
label_map
return
paths
,
label_map
def
get_label_map
(
self
,
path
):
def
get_label_map
(
self
,
path
):
"""Parses
s
a label map into {integer:string} format."""
"""Parses a label map into {integer:string} format."""
label_map
=
{}
label_map
=
{}
with
open
(
path
,
"r"
)
as
f
:
label_map
=
label_map_util
.
load_labelmap
(
path
)
print
(
label_map
)
label_map_dict
=
{}
for
item
in
label_map
.
item
:
label_map_dict
[
item
.
name
]
=
item
.
label_id
with
open
(
path
,
"rb"
)
as
f
:
#label_map_util.load_labelmap()
#label_map_str = f.read()
#print(str(label_map_str))
#label_map = string_int_label_map_pb2.StringIntLabelMap()
#label_map.ParseFromString(label_map_str)
pass
"""
current_id = -1
current_id = -1
current_label = ""
current_label = ""
for line in f:
for line in f:
...
@@ -322,16 +491,12 @@ class Ava(object):
...
@@ -322,16 +491,12 @@ class Ava(object):
if "id:" in line:
if "id:" in line:
current_id = int(line.split()[1])
current_id = int(line.split()[1])
if "}" in line:
if "}" in line:
label_map
[
current_id
]
=
bytes23
(
current_label
)
label_map[current_id] = bytes(current_label, "utf8")"""
logging
.
info
(
label_map
)
print
(
'label map dict'
)
logging
.
info
(
label_map_dict
)
assert
len
(
label_map
)
==
NUM_CLASSES
assert
len
(
label_map
)
==
NUM_CLASSES
return
label_map
return
label_map
def
bytes23
(
string
):
"""Creates a bytes string in either Python 2 or 3."""
if
sys
.
version_info
>=
(
3
,
0
):
return
bytes
(
string
,
"utf8"
)
return
bytes
(
string
)
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
_close_on_exit
(
writers
):
def
_close_on_exit
(
writers
):
...
@@ -350,7 +515,8 @@ def main(argv):
...
@@ -350,7 +515,8 @@ def main(argv):
flags
.
FLAGS
.
splits_to_process
,
flags
.
FLAGS
.
splits_to_process
,
flags
.
FLAGS
.
video_path_format_string
,
flags
.
FLAGS
.
video_path_format_string
,
flags
.
FLAGS
.
seconds_per_sequence
,
flags
.
FLAGS
.
seconds_per_sequence
,
flags
.
FLAGS
.
hop_between_sequences
)
flags
.
FLAGS
.
hop_between_sequences
,
flags
.
FLAGS
.
examples_for_context
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
flags
.
DEFINE_string
(
"path_to_download_data"
,
flags
.
DEFINE_string
(
"path_to_download_data"
,
...
@@ -375,4 +541,8 @@ if __name__ == "__main__":
...
@@ -375,4 +541,8 @@ if __name__ == "__main__":
10
,
10
,
"The hop between sequences. If less than "
"The hop between sequences. If less than "
"seconds_per_sequence, will overlap."
)
"seconds_per_sequence, will overlap."
)
flags
.
DEFINE_boolean
(
"examples_for_context"
,
False
,
"Whether to generate examples instead of sequence examples. "
"If true, will generate tf.Example objects for use in Context R-CNN."
)
app
.
run
(
main
)
app
.
run
(
main
)
research/object_detection/utils/label_map_util.py
View file @
b152ed9c
...
@@ -152,7 +152,7 @@ def load_labelmap(path):
...
@@ -152,7 +152,7 @@ def load_labelmap(path):
Returns:
Returns:
a StringIntLabelMapProto
a StringIntLabelMapProto
"""
"""
with
tf
.
io
.
gfile
.
GFile
(
path
,
'r'
)
as
fid
:
with
tf
.
io
.
gfile
.
GFile
(
path
,
'r
b
'
)
as
fid
:
label_map_string
=
fid
.
read
()
label_map_string
=
fid
.
read
()
label_map
=
string_int_label_map_pb2
.
StringIntLabelMap
()
label_map
=
string_int_label_map_pb2
.
StringIntLabelMap
()
try
:
try
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment