Unverified commit 375e4ab2, authored by Sahil Goyal and committed via GitHub
Browse files

update urls for kinetics dataset (#5578)



* update urls for kinetics dataset

* update urls for kinetics dataset

* remove errors

* update the changes and add test option to split

* added test to valid values for split arg

* change .txt to .csv for annotation url of k600
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 6c126d2e
import csv import csv
import os import os
import time import time
import urllib
import warnings import warnings
from functools import partial from functools import partial
from multiprocessing import Pool from multiprocessing import Pool
...@@ -53,7 +54,7 @@ class Kinetics(VisionDataset): ...@@ -53,7 +54,7 @@ class Kinetics(VisionDataset):
Note: split is appended automatically using the split argument. Note: split is appended automatically using the split argument.
frames_per_clip (int): number of frames in a clip frames_per_clip (int): number of frames in a clip
num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700 num_classes (int): select between Kinetics-400 (default), Kinetics-600, and Kinetics-700
split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` split (str): split of the dataset to consider; supports ``"train"`` (default) ``"val"`` ``"test"``
frame_rate (float): If omitted, interpolate different frame rate for each clip. frame_rate (float): If omitted, interpolate different frame rate for each clip.
step_between_clips (int): number of frames between each clip step_between_clips (int): number of frames between each clip
transform (callable, optional): A function/transform that takes in a TxHxWxC video transform (callable, optional): A function/transform that takes in a TxHxWxC video
...@@ -81,7 +82,7 @@ class Kinetics(VisionDataset): ...@@ -81,7 +82,7 @@ class Kinetics(VisionDataset):
} }
_ANNOTATION_URLS = { _ANNOTATION_URLS = {
"400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv", "400": "https://s3.amazonaws.com/kinetics/400/annotations/{split}.csv",
"600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.txt", "600": "https://s3.amazonaws.com/kinetics/600/annotations/{split}.csv",
"700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv", "700": "https://s3.amazonaws.com/kinetics/700_2020/annotations/{split}.csv",
} }
...@@ -122,7 +123,7 @@ class Kinetics(VisionDataset): ...@@ -122,7 +123,7 @@ class Kinetics(VisionDataset):
raise ValueError("Cannot download the videos using legacy_structure.") raise ValueError("Cannot download the videos using legacy_structure.")
else: else:
self.split_folder = path.join(root, split) self.split_folder = path.join(root, split)
self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"]) self.split = verify_str_arg(split, arg="split", valid_values=["train", "val", "test"])
if download: if download:
self.download_and_process_videos() self.download_and_process_videos()
...@@ -177,17 +178,16 @@ class Kinetics(VisionDataset): ...@@ -177,17 +178,16 @@ class Kinetics(VisionDataset):
split_url_filepath = path.join(file_list_path, path.basename(split_url)) split_url_filepath = path.join(file_list_path, path.basename(split_url))
if not check_integrity(split_url_filepath): if not check_integrity(split_url_filepath):
download_url(split_url, file_list_path) download_url(split_url, file_list_path)
list_video_urls = open(split_url_filepath) with open(split_url_filepath) as file:
list_video_urls = [urllib.parse.quote(line, safe="/,:") for line in file.read().splitlines()]
if self.num_download_workers == 1: if self.num_download_workers == 1:
for line in list_video_urls.readlines(): for line in list_video_urls:
line = str(line).replace("\n", "")
download_and_extract_archive(line, tar_path, self.split_folder) download_and_extract_archive(line, tar_path, self.split_folder)
else: else:
part = partial(_dl_wrap, tar_path, self.split_folder) part = partial(_dl_wrap, tar_path, self.split_folder)
lines = [str(line).replace("\n", "") for line in list_video_urls.readlines()]
poolproc = Pool(self.num_download_workers) poolproc = Pool(self.num_download_workers)
poolproc.map(part, lines) poolproc.map(part, list_video_urls)
def _make_ds_structure(self) -> None: def _make_ds_structure(self) -> None:
"""move videos from """move videos from
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment