Unverified Commit 58ef8fb6 authored by Aziz's avatar Aziz Committed by GitHub
Browse files

Remove deprecated transform from Dataset (#1120)

parent 47c2040e
import os import os
import warnings import warnings
from typing import Any, Tuple, Union
from pathlib import Path from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
...@@ -66,8 +67,6 @@ class VCTK(Dataset): ...@@ -66,8 +67,6 @@ class VCTK(Dataset):
Giving ``download=True`` will result in error as the dataset is no longer Giving ``download=True`` will result in error as the dataset is no longer
publicly available. publicly available.
downsample (bool, optional): Not used. downsample (bool, optional): Not used.
transform (callable, optional): Optional transform applied on waveform. (default: ``None``)
target_transform (callable, optional): Optional transform applied on utterance. (default: ``None``)
""" """
_folder_txt = "txt" _folder_txt = "txt"
...@@ -81,9 +80,7 @@ class VCTK(Dataset): ...@@ -81,9 +80,7 @@ class VCTK(Dataset):
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False,
downsample: bool = False, downsample: bool = False) -> None:
transform: Any = None,
target_transform: Any = None) -> None:
if downsample: if downsample:
warnings.warn( warnings.warn(
...@@ -92,17 +89,7 @@ class VCTK(Dataset): ...@@ -92,17 +89,7 @@ class VCTK(Dataset):
"and suppress this warning." "and suppress this warning."
) )
if transform is not None or target_transform is not None:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
"Please remove the option `transform=True` and "
"`target_transform=True` to suppress this warning."
)
self.downsample = downsample self.downsample = downsample
self.transform = transform
self.target_transform = target_transform
# Get string representation of 'root' in case Path object is passed # Get string representation of 'root' in case Path object is passed
root = os.fspath(root) root = os.fspath(root)
...@@ -149,10 +136,6 @@ class VCTK(Dataset): ...@@ -149,10 +136,6 @@ class VCTK(Dataset):
# return item # return item
waveform, sample_rate, utterance, speaker_id, utterance_id = item waveform, sample_rate, utterance, speaker_id, utterance_id = item
if self.transform is not None:
waveform = self.transform(waveform)
if self.target_transform is not None:
utterance = self.target_transform(utterance)
return waveform, sample_rate, utterance, speaker_id, utterance_id return waveform, sample_rate, utterance, speaker_id, utterance_id
def __len__(self) -> int: def __len__(self) -> int:
......
import os import os
import warnings
from typing import Any, List, Tuple, Union
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union
import torchaudio
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import ( from torchaudio.datasets.utils import (
download_url, download_url,
extract_archive, extract_archive,
...@@ -41,8 +41,6 @@ class YESNO(Dataset): ...@@ -41,8 +41,6 @@ class YESNO(Dataset):
The top-level directory of the dataset. (default: ``"waves_yesno"``) The top-level directory of the dataset. (default: ``"waves_yesno"``)
download (bool, optional): download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``). Whether to download the dataset if it is not found at root path. (default: ``False``).
transform (callable, optional): Optional transform applied on waveform. (default: ``None``)
target_transform (callable, optional): Optional transform applied on utterance. (default: ``None``)
""" """
_ext_audio = ".wav" _ext_audio = ".wav"
...@@ -51,19 +49,7 @@ class YESNO(Dataset): ...@@ -51,19 +49,7 @@ class YESNO(Dataset):
root: Union[str, Path], root: Union[str, Path],
url: str = URL, url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE, folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False, download: bool = False) -> None:
transform: Any = None,
target_transform: Any = None) -> None:
if transform is not None or target_transform is not None:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
"Please remove the option `transform=True` and "
"`target_transform=True` to suppress this warning."
)
self.transform = transform
self.target_transform = target_transform
# Get string representation of 'root' in case Path object is passed # Get string representation of 'root' in case Path object is passed
root = os.fspath(root) root = os.fspath(root)
...@@ -102,10 +88,6 @@ class YESNO(Dataset): ...@@ -102,10 +88,6 @@ class YESNO(Dataset):
# return item # return item
waveform, sample_rate, labels = item waveform, sample_rate, labels = item
if self.transform is not None:
waveform = self.transform(waveform)
if self.target_transform is not None:
labels = self.target_transform(labels)
return waveform, sample_rate, labels return waveform, sample_rate, labels
def __len__(self) -> int: def __len__(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment