Unverified Commit 9a34e7c0 authored by kingyiusuen, committed by GitHub

Add DR-VCTK dataset (#1819)

parent e40c9c3c
@@ -105,6 +105,14 @@ VCTK_092
:special-members: __getitem__
DR_VCTK
~~~~~~~~
.. autoclass:: DR_VCTK
:members:
:special-members: __getitem__
YESNO
~~~~~
......
from pathlib import Path
import pytest
from torchaudio.datasets import dr_vctk
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
get_whitenoise,
save_wav,
)
_SUBSETS = ["train", "test"]
_CONDITIONS = ["clean", "device-recorded"]
_SOURCES = ["DR-VCTK_Office1_ClosedWindow", "DR-VCTK_Office1_OpenedWindow"]
_SPEAKER_IDS = range(226, 230)
_CHANNEL_IDS = range(1, 6)
def get_mock_dataset(root_dir):
"""
root_dir: root directory of the mocked data
"""
mocked_samples = {}
dataset_dir = Path(root_dir) / "DR-VCTK" / "DR-VCTK"
dataset_dir.mkdir(parents=True, exist_ok=True)
config_dir = dataset_dir / "configurations"
config_dir.mkdir(parents=True, exist_ok=True)
sample_rate = 16000
seed = 0
for subset in _SUBSETS:
mocked_samples[subset] = []
for condition in _CONDITIONS:
audio_dir = dataset_dir / f"{condition}_{subset}set_wav_16k"
audio_dir.mkdir(parents=True, exist_ok=True)
config_filepath = config_dir / f"{subset}_ch_log.txt"
with open(config_filepath, "w") as f:
if subset == "train":
f.write("\n")
f.write("File Name\tMain Source\tChannel Idx\n")
for speaker_id in _SPEAKER_IDS:
utterance_id = 1
for source in _SOURCES:
for channel_id in _CHANNEL_IDS:
filename = f"p{speaker_id}_{utterance_id:03d}.wav"
f.write(f"{filename}\t{source}\t{channel_id}\n")
data = {}
for condition in _CONDITIONS:
data[condition] = get_whitenoise(
sample_rate=sample_rate,
duration=0.01,
n_channels=1,
dtype='float32',
seed=seed
)
audio_dir = dataset_dir / f"{condition}_{subset}set_wav_16k"
audio_file_path = audio_dir / filename
save_wav(audio_file_path, data[condition], sample_rate)
seed += 1
sample = (
data[_CONDITIONS[0]],
sample_rate,
data[_CONDITIONS[1]],
sample_rate,
"p" + str(speaker_id),
f"{utterance_id:03d}",
source,
channel_id,
)
mocked_samples[subset].append(sample)
utterance_id += 1
return mocked_samples
class TestDRVCTK(TempDirMixin, TorchaudioTestCase):
backend = 'default'
root_dir = None
samples = {}
@classmethod
def setUpClass(cls):
cls.root_dir = cls.get_base_temp_dir()
cls.samples = get_mock_dataset(cls.root_dir)
def _test_dr_vctk(self, dataset, subset):
num_samples = 0
for i, (
waveform_clean,
sample_rate_clean,
waveform_dr,
sample_rate_dr,
speaker_id,
utterance_id,
source,
channel_id,
) in enumerate(dataset):
self.assertEqual(waveform_clean, self.samples[subset][i][0], atol=5e-5, rtol=1e-8)
assert sample_rate_clean == self.samples[subset][i][1]
self.assertEqual(waveform_dr, self.samples[subset][i][2], atol=5e-5, rtol=1e-8)
assert sample_rate_dr == self.samples[subset][i][3]
assert speaker_id == self.samples[subset][i][4]
assert utterance_id == self.samples[subset][i][5]
assert source == self.samples[subset][i][6]
assert channel_id == self.samples[subset][i][7]
num_samples += 1
assert num_samples == len(self.samples[subset])
def test_dr_vctk_train_str(self):
subset = "train"
dataset = dr_vctk.DR_VCTK(self.root_dir, subset=subset)
self._test_dr_vctk(dataset, subset)
def test_dr_vctk_test_str(self):
subset = "test"
dataset = dr_vctk.DR_VCTK(self.root_dir, subset=subset)
self._test_dr_vctk(dataset, subset)
def test_dr_vctk_train_path(self):
subset = "train"
dataset = dr_vctk.DR_VCTK(Path(self.root_dir), subset=subset)
self._test_dr_vctk(dataset, subset)
def test_dr_vctk_test_path(self):
subset = "test"
dataset = dr_vctk.DR_VCTK(Path(self.root_dir), subset=subset)
self._test_dr_vctk(dataset, subset)
def test_dr_vctk_invalid_subset(self):
subset = "invalid"
with pytest.raises(RuntimeError, match=f"The subset '{subset}' does not match any of the supported subsets"):
dr_vctk.DR_VCTK(self.root_dir, subset=subset)
@@ -2,6 +2,7 @@ from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .vctk import VCTK_092
from .dr_vctk import DR_VCTK
from .gtzan import GTZAN
from .yesno import YESNO
from .ljspeech import LJSPEECH
@@ -16,6 +17,7 @@ __all__ = [
"LIBRISPEECH", "LIBRISPEECH",
"SPEECHCOMMANDS", "SPEECHCOMMANDS",
"VCTK_092", "VCTK_092",
"DR_VCTK",
"YESNO", "YESNO",
"LJSPEECH", "LJSPEECH",
"GTZAN", "GTZAN",
......
from pathlib import Path
from typing import Dict, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import (
download_url,
extract_archive,
validate_file,
)
_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"
_CHECKSUM = "29e93debeb0e779986542229a81ff29b"
_SUPPORTED_SUBSETS = {"train", "test"}
class DR_VCTK(Dataset):
"""Create a dataset for Device Recorded VCTK (Small subset version).
Args:
root (str or Path): Root directory where the dataset's top level directory is found.
subset (str): The subset to use. Can be one of ``"train"`` or ``"test"``. (default: ``"train"``).
download (bool):
Whether to download the dataset if it is not found at root path. (default: ``False``).
url (str): The URL to download the dataset from.
(default: ``"https://datashare.ed.ac.uk/bitstream/handle/10283/3038/DR-VCTK.zip"``)
"""
def __init__(
self,
root: Union[str, Path],
subset: str = "train",
*,
download: bool = False,
url: str = _URL,
) -> None:
if subset not in _SUPPORTED_SUBSETS:
raise RuntimeError(
f"The subset '{subset}' does not match any of the supported subsets: {_SUPPORTED_SUBSETS}"
)
root = Path(root).expanduser()
archive = root / "DR-VCTK.zip"
self._subset = subset
self._path = root / "DR-VCTK" / "DR-VCTK"
self._clean_audio_dir = self._path / f"clean_{self._subset}set_wav_16k"
self._noisy_audio_dir = self._path / f"device-recorded_{self._subset}set_wav_16k"
self._config_filepath = self._path / "configurations" / f"{self._subset}_ch_log.txt"
if not self._path.is_dir():
if not archive.is_file():
if not download:
raise RuntimeError("Dataset not found. Please use `download=True` to download it.")
download_url(url, root)
self._validate_checksum(archive)
extract_archive(archive, root)
self._config = self._load_config(self._config_filepath)
self._filename_list = sorted(self._config)
def _validate_checksum(self, archive):
with open(archive, "rb") as file_obj:
if not validate_file(file_obj, _CHECKSUM, "md5"):
raise RuntimeError(
f"The hash of {str(archive)} does not match. Delete the file manually and retry."
)
def _load_config(self, filepath: str) -> Dict[str, Tuple[str, int]]:
# Skip header
skip_rows = 2 if self._subset == "train" else 1
config = {}
with open(filepath) as f:
for i, line in enumerate(f):
if i < skip_rows or not line:
continue
filename, source, channel_id = line.strip().split("\t")
config[filename] = (source, int(channel_id))
return config
def _load_dr_vctk_item(self, filename: str) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
speaker_id, utterance_id = filename.split(".")[0].split("_")
source, channel_id = self._config[filename]
file_clean_audio = self._clean_audio_dir / filename
file_noisy_audio = self._noisy_audio_dir / filename
waveform_clean, sample_rate_clean = torchaudio.load(file_clean_audio)
waveform_noisy, sample_rate_noisy = torchaudio.load(file_noisy_audio)
return (
waveform_clean,
sample_rate_clean,
waveform_noisy,
sample_rate_noisy,
speaker_id,
utterance_id,
source,
channel_id,
)
def __getitem__(self, n: int) -> Tuple[Tensor, int, Tensor, int, str, str, str, int]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform_clean, sample_rate_clean, waveform_noisy, sample_rate_noisy, speaker_id, utterance_id,\
source, channel_id)``
"""
filename = self._filename_list[n]
return self._load_dr_vctk_item(filename)
def __len__(self) -> int:
return len(self._filename_list)
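
A minimal usage sketch for the new dataset, assuming the DR-VCTK archive is (or will be) placed under a local ``./data`` directory; the directory name and the printed fields are illustrative only.

from pathlib import Path

from torchaudio.datasets import DR_VCTK

# Download (if needed) and load the training subset; "./data" is an arbitrary root.
dataset = DR_VCTK(Path("./data"), subset="train", download=True)

# Each item is an 8-tuple pairing the clean and device-recorded recordings.
(
    waveform_clean,
    sample_rate_clean,
    waveform_dr,
    sample_rate_dr,
    speaker_id,
    utterance_id,
    source,
    channel_id,
) = dataset[0]

print(speaker_id, utterance_id, source, channel_id)
print(waveform_clean.shape, sample_rate_clean, waveform_dr.shape, sample_rate_dr)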