Unverified Commit ac4b9f87 authored by Fernando Pérez-García's avatar Fernando Pérez-García Committed by GitHub
Browse files

Improve code readability and docstring (#2020)

* Improve code readability and docstring

* Remove unused argument

* Improve make_dataset() readability
parent 3c254fb7
......@@ -32,26 +32,28 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None):
images = []
dir = os.path.expanduser(dir)
if not ((extensions is None) ^ (is_valid_file is None)):
def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
for target in sorted(class_to_idx.keys()):
d = os.path.join(dir, target)
if not os.path.isdir(d):
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(d, followlinks=True)):
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = (path, class_to_idx[target])
images.append(item)
return images
item = path, class_index
instances.append(item)
return instances
class DatasetFolder(VisionDataset):
......
......@@ -27,13 +27,13 @@ class HMDB51(VisionDataset):
Args:
root (string): Root directory of the HMDB51 Dataset.
annotation_path (str): path to the folder containing the split files
frames_per_clip (int): number of frames in a clip.
step_between_clips (int): number of frames between each clip.
fold (int, optional): which fold to use. Should be between 1 and 3.
train (bool, optional): if ``True``, creates a dataset from the train split,
annotation_path (str): Path to the folder containing the split files.
frames_per_clip (int): Number of frames in a clip.
step_between_clips (int): Number of frames between each clip.
fold (int, optional): Which fold to use. Should be between 1 and 3.
train (bool, optional): If ``True``, creates a dataset from the train split,
otherwise from the ``test`` split.
transform (callable, optional): A function/transform that takes in a TxHxWxC video
transform (callable, optional): A function/transform that takes in a TxHxWxC video
and returns a transformed version.
Returns:
......@@ -48,26 +48,29 @@ class HMDB51(VisionDataset):
"url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
"md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
}
TRAIN_TAG = 1
TEST_TAG = 2
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0):
super(HMDB51, self).__init__(root)
if not 1 <= fold <= 3:
if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',)
self.fold = fold
self.train = train
classes = sorted(list_dir(root))
class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.samples = make_dataset(
self.root,
class_to_idx,
extensions,
)
classes = list(sorted(list_dir(root)))
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_paths = [path for (path, _) in self.samples]
video_clips = VideoClips(
video_list,
video_paths,
frames_per_clip,
step_between_clips,
frame_rate,
......@@ -78,8 +81,11 @@ class HMDB51(VisionDataset):
_video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples,
)
self.fold = fold
self.train = train
self.classes = classes
self.video_clips_metadata = video_clips.metadata
self.indices = self._select_fold(video_list, annotation_path, fold, train)
self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices)
self.transform = transform
......@@ -87,29 +93,38 @@ class HMDB51(VisionDataset):
def metadata(self):
# NOTE(review): diff fragment — body of HMDB51's `metadata` accessor (the @property
# decorator sits on the preceding diff-header line, outside this chunk — confirm in full file).
# Returns the VideoClips metadata cached on the instance as `video_clips_metadata` in __init__.
return self.video_clips_metadata
def _select_fold(self, video_list, annotation_path, fold, train):
target_tag = 1 if train else 2
name = "*test_split{}.txt".format(fold)
files = glob.glob(os.path.join(annotation_path, name))
def _select_fold(self, video_list, annotations_dir, fold, train):
target_tag = self.TRAIN_TAG if train else self.TEST_TAG
split_pattern_name = "*test_split{}.txt".format(fold)
split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path)
selected_files = []
for f in files:
with open(f, "r") as fid:
data = fid.readlines()
data = [x.strip().split(" ") for x in data]
data = [x[0] for x in data if int(x[1]) == target_tag]
selected_files.extend(data)
for filepath in annotation_paths:
with open(filepath) as fid:
lines = fid.readlines()
for line in lines:
video_filename, tag_string = line.split()
tag = int(tag_string)
if tag == target_tag:
selected_files.append(video_filename)
selected_files = set(selected_files)
indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files]
indices = []
for video_index, video_path in enumerate(video_list):
if os.path.basename(video_path) in selected_files:
indices.append(video_index)
return indices
def __len__(self):
# Dataset length is the number of CLIPS (not videos): delegates to VideoClips.num_clips()
# on `self.video_clips`, which was subset to the selected fold/split indices in __init__.
return self.video_clips.num_clips()
def __getitem__(self, idx):
video, audio, info, video_idx = self.video_clips.get_clip(idx)
label = self.samples[self.indices[video_idx]][1]
video, audio, _, video_idx = self.video_clips.get_clip(idx)
sample_index = self.indices[video_idx]
_, class_index = self.samples[sample_index]
if self.transform is not None:
video = self.transform(video)
return video, audio, label
return video, audio, class_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment