Unverified Commit ac4b9f87 authored by Fernando Pérez-García's avatar Fernando Pérez-García Committed by GitHub
Browse files

Improve code readability and docstring (#2020)

* Improve code readability and docstring

* Remove unused argument

* Improve make_dataset() readability
parent 3c254fb7
...@@ -32,26 +32,28 @@ def is_image_file(filename): ...@@ -32,26 +32,28 @@ def is_image_file(filename):
return has_file_allowed_extension(filename, IMG_EXTENSIONS) return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
    """Build a list of ``(sample path, class index)`` pairs from a class-per-folder tree.

    Args:
        directory (str): Root directory containing one subdirectory per class.
        class_to_idx (dict): Mapping from class (subdirectory) name to class index.
        extensions (tuple, optional): Allowed file extensions; mutually exclusive
            with ``is_valid_file``.
        is_valid_file (callable, optional): Predicate deciding whether a path is a
            valid sample file; mutually exclusive with ``extensions``.

    Returns:
        list: ``(path, class_index)`` tuples, ordered by class name, then by
        walk/file order within each class directory.

    Raises:
        ValueError: If neither or both of ``extensions`` and ``is_valid_file``
            are given.
    """
    instances = []
    directory = os.path.expanduser(directory)
    # Exactly one of `extensions` / `is_valid_file` must be provided.
    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
    if extensions is not None:
        def is_valid_file(x):
            return has_file_allowed_extension(x, extensions)
    for target_class in sorted(class_to_idx.keys()):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        # Walk each class folder recursively (following symlinks) in sorted
        # order so the resulting dataset ordering is deterministic.
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    item = path, class_index
                    instances.append(item)
    return instances
class DatasetFolder(VisionDataset): class DatasetFolder(VisionDataset):
......
...@@ -27,13 +27,13 @@ class HMDB51(VisionDataset): ...@@ -27,13 +27,13 @@ class HMDB51(VisionDataset):
Args:
    root (string): Root directory of the HMDB51 Dataset.
    annotation_path (str): Path to the folder containing the split files.
    frames_per_clip (int): Number of frames in a clip.
    step_between_clips (int): Number of frames between each clip.
    fold (int, optional): Which fold to use. Should be between 1 and 3.
    train (bool, optional): If ``True``, creates a dataset from the train split,
        otherwise from the ``test`` split.
    transform (callable, optional): A function/transform that takes in a TxHxWxC video
        and returns a transformed version.
Returns:
...@@ -48,26 +48,29 @@ class HMDB51(VisionDataset): ...@@ -48,26 +48,29 @@ class HMDB51(VisionDataset):
"url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar", "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar",
"md5": "15e67781e70dcfbdce2d7dbb9b3344b5" "md5": "15e67781e70dcfbdce2d7dbb9b3344b5"
} }
# Tag values used in the HMDB51 split annotation files: each line maps a
# video filename to 1 (train split) or 2 (test split).
TRAIN_TAG = 1
TEST_TAG = 2
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1, def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
frame_rate=None, fold=1, train=True, transform=None, frame_rate=None, fold=1, train=True, transform=None,
_precomputed_metadata=None, num_workers=1, _video_width=0, _precomputed_metadata=None, num_workers=1, _video_width=0,
_video_height=0, _video_min_dimension=0, _audio_samples=0): _video_height=0, _video_min_dimension=0, _audio_samples=0):
super(HMDB51, self).__init__(root) super(HMDB51, self).__init__(root)
if not 1 <= fold <= 3: if fold not in (1, 2, 3):
raise ValueError("fold should be between 1 and 3, got {}".format(fold)) raise ValueError("fold should be between 1 and 3, got {}".format(fold))
extensions = ('avi',) extensions = ('avi',)
self.fold = fold classes = sorted(list_dir(root))
self.train = train class_to_idx = {class_: i for (i, class_) in enumerate(classes)}
self.samples = make_dataset(
self.root,
class_to_idx,
extensions,
)
classes = list(sorted(list_dir(root))) video_paths = [path for (path, _) in self.samples]
class_to_idx = {classes[i]: i for i in range(len(classes))}
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
self.classes = classes
video_list = [x[0] for x in self.samples]
video_clips = VideoClips( video_clips = VideoClips(
video_list, video_paths,
frames_per_clip, frames_per_clip,
step_between_clips, step_between_clips,
frame_rate, frame_rate,
...@@ -78,8 +81,11 @@ class HMDB51(VisionDataset): ...@@ -78,8 +81,11 @@ class HMDB51(VisionDataset):
_video_min_dimension=_video_min_dimension, _video_min_dimension=_video_min_dimension,
_audio_samples=_audio_samples, _audio_samples=_audio_samples,
) )
self.fold = fold
self.train = train
self.classes = classes
self.video_clips_metadata = video_clips.metadata self.video_clips_metadata = video_clips.metadata
self.indices = self._select_fold(video_list, annotation_path, fold, train) self.indices = self._select_fold(video_paths, annotation_path, fold, train)
self.video_clips = video_clips.subset(self.indices) self.video_clips = video_clips.subset(self.indices)
self.transform = transform self.transform = transform
...@@ -87,29 +93,38 @@ class HMDB51(VisionDataset): ...@@ -87,29 +93,38 @@ class HMDB51(VisionDataset):
def metadata(self):
    """Metadata of the underlying precomputed video clips."""
    return self.video_clips_metadata
def _select_fold(self, video_list, annotation_path, fold, train): def _select_fold(self, video_list, annotations_dir, fold, train):
target_tag = 1 if train else 2 target_tag = self.TRAIN_TAG if train else self.TEST_TAG
name = "*test_split{}.txt".format(fold) split_pattern_name = "*test_split{}.txt".format(fold)
files = glob.glob(os.path.join(annotation_path, name)) split_pattern_path = os.path.join(annotations_dir, split_pattern_name)
annotation_paths = glob.glob(split_pattern_path)
selected_files = [] selected_files = []
for f in files: for filepath in annotation_paths:
with open(f, "r") as fid: with open(filepath) as fid:
data = fid.readlines() lines = fid.readlines()
data = [x.strip().split(" ") for x in data] for line in lines:
data = [x[0] for x in data if int(x[1]) == target_tag] video_filename, tag_string = line.split()
selected_files.extend(data) tag = int(tag_string)
if tag == target_tag:
selected_files.append(video_filename)
selected_files = set(selected_files) selected_files = set(selected_files)
indices = [i for i in range(len(video_list)) if os.path.basename(video_list[i]) in selected_files]
indices = []
for video_index, video_path in enumerate(video_list):
if os.path.basename(video_path) in selected_files:
indices.append(video_index)
return indices return indices
def __len__(self): def __len__(self):
return self.video_clips.num_clips() return self.video_clips.num_clips()
def __getitem__(self, idx): def __getitem__(self, idx):
video, audio, info, video_idx = self.video_clips.get_clip(idx) video, audio, _, video_idx = self.video_clips.get_clip(idx)
label = self.samples[self.indices[video_idx]][1] sample_index = self.indices[video_idx]
_, class_index = self.samples[sample_index]
if self.transform is not None: if self.transform is not None:
video = self.transform(video) video = self.transform(video)
return video, audio, label return video, audio, class_index
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment