Commit 5807078c authored by moto, committed by Facebook GitHub Bot

Split extract_archive into dedicated functions. (#2927)

Summary:
`extract_archive` in `datasets.utils` does not distinguish the archive type of its input: it blindly tries to open the file as a tar archive and, if that fails, falls back to treating it as a zip archive.

This is an anti-pattern; every dataset implementation already knows which archive type its downloaded files use.

This commit splits the `extract_archive` function into dedicated functions (`_extract_tar` and `_extract_zip`) and makes each dataset call the correct one.
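The hunks below contain the actual implementations. As a rough, simplified sketch of the shape of the new helpers (without the overwrite handling and logging of the real functions; the names mirror the real ones), they are thin wrappers over the standard `tarfile` and `zipfile` modules:

```python
import os
import tarfile
import zipfile
from typing import List, Optional


def _extract_tar(from_path: str, to_path: Optional[str] = None) -> List[str]:
    # Extract a (possibly compressed) tar archive next to it by default.
    to_path = to_path or os.path.dirname(from_path)
    with tarfile.open(from_path, "r") as tar:  # mode "r" transparently handles .gz/.bz2/.xz
        tar.extractall(to_path)
        return [os.path.join(to_path, m.name) for m in tar.getmembers() if m.isfile()]


def _extract_zip(from_path: str, to_path: Optional[str] = None) -> List[str]:
    # Extract a zip archive next to it by default.
    to_path = to_path or os.path.dirname(from_path)
    with zipfile.ZipFile(from_path, "r") as zfile:
        zfile.extractall(to_path)
        return [os.path.join(to_path, name) for name in zfile.namelist()]
```

Because the caller now states the format explicitly, a corrupt or mislabeled download surfaces as a `tarfile.ReadError` or `zipfile.BadZipFile` instead of being silently retried as the other format.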

Pull Request resolved: https://github.com/pytorch/audio/pull/2927

Reviewed By: carolineechen

Differential Revision: D42154069

Pulled By: mthrok

fbshipit-source-id: bc46cc2af26aa086ef49aa1f9a94b6dedb55f85e
parent d744f33f
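On the dataset side, every call site keeps its existing download-then-extract flow and simply calls the helper that matches its archive format. A condensed, hypothetical sketch of the pattern the hunks below follow (the URL, checksum, and function name are placeholders, and `_extract_tar` is an internal helper, not public API):

```python
import os

from torch.hub import download_url_to_file
from torchaudio.datasets.utils import _extract_tar

# Placeholder values for illustration; real datasets define their own URL,
# checksum, and archive layout (see the per-dataset hunks below).
_URL = "https://example.com/some_corpus.tar.gz"
_CHECKSUM = None  # real datasets pass a sha256 prefix here


def _download_some_corpus(root: str) -> None:
    archive = os.path.join(root, os.path.basename(_URL))
    if not os.path.isfile(archive):
        download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
    # The dataset knows its archive is a tarball, so it calls _extract_tar
    # directly instead of letting a generic helper guess the format.
    _extract_tar(archive)
```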
@@ -7,7 +7,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
URL = "aew"
FOLDER_IN_ARCHIVE = "ARCTIC"
@@ -119,7 +119,7 @@ class CMUARCTIC(Dataset):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
else:
if not os.path.exists(self._path):
raise RuntimeError(
@@ -5,7 +5,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_zip
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
@@ -52,7 +52,7 @@ class DR_VCTK(Dataset):
if not download:
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
download_url_to_file(url, archive, hash_prefix=_CHECKSUM)
extract_archive(archive, root)
_extract_zip(archive, root)
self._config = self._load_config(self._config_filepath)
self._filename_list = sorted(self._config)
@@ -6,7 +6,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
# The following lists prefixed with `filtered_` provide a filtered split
# that:
@@ -1052,7 +1052,7 @@ class GTZAN(Dataset):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
@@ -7,7 +7,7 @@ from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.librispeech import _get_librispeech_metadata
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
_ARCHIVE_NAME = "librispeech_finetuning"
@@ -78,7 +78,7 @@ class LibriLightLimited(Dataset):
raise RuntimeError("Dataset not found. Please use `download=True` to download")
if not os.path.isfile(archive):
download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
extract_archive(archive)
_extract_tar(archive)
self._fileids_paths = _get_fileids_paths(self._path, folders, self._ext_audio)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
@@ -5,7 +5,7 @@ from typing import Tuple, Union
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive
from torchaudio.datasets.utils import _extract_tar, _load_waveform
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
@@ -40,7 +40,7 @@ def _download_librispeech(root, url):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(download_url, None)
download_url_to_file(download_url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
def _get_librispeech_metadata(
@@ -6,7 +6,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
@@ -121,7 +121,7 @@ class LIBRITTS(Dataset):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
else:
if not os.path.exists(self._path):
raise RuntimeError(
@@ -7,7 +7,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
_RELEASE_CONFIGS = {
@@ -59,7 +59,7 @@ class LJSPEECH(Dataset):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
else:
if not os.path.exists(self._path):
raise RuntimeError(
@@ -6,7 +6,7 @@ import torch
import torchaudio
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_zip
_URL = "https://zenodo.org/record/3338373/files/musdb18hq.zip"
_CHECKSUM = "baac80d0483c61d74b2e5f3be75fa557eec52898339e6aa45c1fa48833c5d21d"
@@ -74,7 +74,7 @@ class MUSDB_HQ(Dataset):
raise RuntimeError("Dataset not found. Please use `download=True` to download")
download_url_to_file(_URL, archive, hash_prefix=_CHECKSUM)
os.makedirs(base_path, exist_ok=True)
extract_archive(archive, base_path)
_extract_zip(archive, base_path)
self.names = self._collect_songs()
@@ -6,7 +6,7 @@ from typing import Optional, Tuple, Union
import torch
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive
from torchaudio.datasets.utils import _extract_tar, _load_waveform
URL = "https://speech.fit.vutbr.cz/files/quesst14Database.tgz"
@@ -62,7 +62,7 @@ class QUESST14(Dataset):
if not download:
raise RuntimeError("Dataset not found. Please use `download=True` to download")
download_url_to_file(URL, archive, hash_prefix=_CHECKSUM)
extract_archive(archive, root)
_extract_tar(archive, root)
if subset == "docs":
self.data = filter_audio_paths(self._path, language, "language_key_utterances.lst")
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, Union
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive
from torchaudio.datasets.utils import _extract_tar, _load_waveform
FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
@@ -107,7 +107,7 @@ class SPEECHCOMMANDS(Dataset):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive, self._path)
_extract_tar(archive, self._path)
else:
if not os.path.exists(self._path):
raise RuntimeError(
@@ -6,7 +6,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
_RELEASE_CONFIGS = {
@@ -106,7 +106,7 @@ class TEDLIUM(Dataset):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS[release]["checksum"]
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
else:
if not os.path.exists(self._path):
raise RuntimeError(
@@ -7,61 +7,39 @@ from typing import Any, List, Optional
 import torchaudio
-def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
-    """Extract archive.
-    Args:
-        from_path (str): the path of the archive.
-        to_path (str or None, optional): the root path of the extraced files (directory of from_path)
-            (Default: ``None``)
-        overwrite (bool, optional): overwrite existing files (Default: ``False``)
-    Returns:
-        List[str]: List of paths to extracted files even if not overwritten.
-    Examples:
-        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
-        >>> from_path = './validation.tar.gz'
-        >>> to_path = './'
-        >>> torchaudio.datasets.utils.download_from_url(url, from_path)
-        >>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
-    """
+def _extract_tar(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
     if to_path is None:
         to_path = os.path.dirname(from_path)
-    try:
-        with tarfile.open(from_path, "r") as tar:
-            logging.info("Opened tar file {}.".format(from_path))
-            files = []
-            for file_ in tar:  # type: Any
-                file_path = os.path.join(to_path, file_.name)
-                if file_.isfile():
-                    files.append(file_path)
-                    if os.path.exists(file_path):
-                        logging.info("{} already extracted.".format(file_path))
-                        if not overwrite:
-                            continue
-                tar.extract(file_, to_path)
-            return files
-    except tarfile.ReadError:
-        pass
-    try:
-        with zipfile.ZipFile(from_path, "r") as zfile:
-            logging.info("Opened zip file {}.".format(from_path))
-            files = zfile.namelist()
-            for file_ in files:
-                file_path = os.path.join(to_path, file_)
+    with tarfile.open(from_path, "r") as tar:
+        logging.info("Opened tar file {}.", from_path)
+        files = []
+        for file_ in tar:  # type: Any
+            file_path = os.path.join(to_path, file_.name)
+            if file_.isfile():
+                files.append(file_path)
                 if os.path.exists(file_path):
                     logging.info("{} already extracted.".format(file_path))
                     if not overwrite:
                         continue
-                zfile.extract(file_, to_path)
+            tar.extract(file_, to_path)
         return files
-    except zipfile.BadZipFile:
-        pass
-    raise NotImplementedError("We currently only support tar.gz, tgz, and zip achives.")
+
+
+def _extract_zip(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+    with zipfile.ZipFile(from_path, "r") as zfile:
+        logging.info("Opened zip file {}.", from_path)
+        files = zfile.namelist()
+        for file_ in files:
+            file_path = os.path.join(to_path, file_)
+            if os.path.exists(file_path):
+                logging.info("{} already extracted.".format(file_path))
+                if not overwrite:
+                    continue
+            zfile.extract(file_, to_path)
+    return files
 def _load_waveform(
@@ -5,7 +5,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_zip
URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
_CHECKSUMS = {
@@ -59,7 +59,7 @@ class VCTK_092(Dataset):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive, self._path)
_extract_zip(archive, self._path)
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
@@ -5,7 +5,7 @@ from typing import List, Tuple, Union
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _load_waveform, extract_archive
from torchaudio.datasets.utils import _extract_zip, _load_waveform
SAMPLE_RATE = 16000
@@ -54,7 +54,7 @@ def _download_extract_wavs(root: str):
url = _ARCHIVE_CONFIGS[archive]["url"]
checksum = _ARCHIVE_CONFIGS[archive]["checksum"]
download_url_to_file(url, archive_path, hash_prefix=checksum)
extract_archive(archive_path)
_extract_zip(archive_path)
def _get_flist(root: str, file_path: str, subset: str) -> List[str]:
@@ -6,7 +6,7 @@ import torchaudio
from torch import Tensor
from torch.hub import download_url_to_file
from torch.utils.data import Dataset
from torchaudio.datasets.utils import extract_archive
from torchaudio.datasets.utils import _extract_tar
_RELEASE_CONFIGS = {
@@ -52,7 +52,7 @@ class YESNO(Dataset):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url_to_file(url, archive, hash_prefix=checksum)
extract_archive(archive)
_extract_tar(archive)
if not os.path.isdir(self._path):
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")