Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
cb40dd72
Unverified
Commit
cb40dd72
authored
Oct 18, 2021
by
Caroline Chen
Committed by
GitHub
Oct 18, 2021
Browse files
[DOC] Standardization and minor fixes (#1892)
parent
955cdbdc
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
46 additions
and
40 deletions
+46
-40
torchaudio/backend/soundfile_backend.py
torchaudio/backend/soundfile_backend.py
+7
-7
torchaudio/backend/sox_io_backend.py
torchaudio/backend/sox_io_backend.py
+7
-7
torchaudio/datasets/cmuarctic.py
torchaudio/datasets/cmuarctic.py
+1
-1
torchaudio/datasets/cmudict.py
torchaudio/datasets/cmudict.py
+1
-1
torchaudio/datasets/commonvoice.py
torchaudio/datasets/commonvoice.py
+2
-2
torchaudio/datasets/dr_vctk.py
torchaudio/datasets/dr_vctk.py
+3
-2
torchaudio/datasets/gtzan.py
torchaudio/datasets/gtzan.py
+1
-1
torchaudio/datasets/librimix.py
torchaudio/datasets/librimix.py
+1
-1
torchaudio/datasets/librispeech.py
torchaudio/datasets/librispeech.py
+2
-1
torchaudio/datasets/libritts.py
torchaudio/datasets/libritts.py
+2
-2
torchaudio/datasets/ljspeech.py
torchaudio/datasets/ljspeech.py
+2
-1
torchaudio/datasets/speechcommands.py
torchaudio/datasets/speechcommands.py
+2
-1
torchaudio/datasets/tedlium.py
torchaudio/datasets/tedlium.py
+2
-1
torchaudio/datasets/utils.py
torchaudio/datasets/utils.py
+1
-1
torchaudio/datasets/vctk.py
torchaudio/datasets/vctk.py
+2
-1
torchaudio/datasets/yesno.py
torchaudio/datasets/yesno.py
+1
-1
torchaudio/functional/filtering.py
torchaudio/functional/filtering.py
+3
-3
torchaudio/functional/functional.py
torchaudio/functional/functional.py
+2
-2
torchaudio/models/conv_tasnet.py
torchaudio/models/conv_tasnet.py
+3
-3
torchaudio/models/tacotron2.py
torchaudio/models/tacotron2.py
+1
-1
No files found.
torchaudio/backend/soundfile_backend.py
View file @
cb40dd72
...
@@ -146,7 +146,7 @@ def load(
...
@@ -146,7 +146,7 @@ def load(
* SPHERE
* SPHERE
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
``float32`` dtype and the shape of
`
`[channel, time]`
`
.
``float32`` dtype and the shape of `[channel, time]`.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
...
@@ -182,16 +182,16 @@ def load(
...
@@ -182,16 +182,16 @@ def load(
integer type.
integer type.
This argument has no effect for formats other than integer WAV type.
This argument has no effect for formats other than integer WAV type.
channels_first (bool, optional):
channels_first (bool, optional):
When True, the returned Tensor has dimension
`
`[channel, time]`
`
.
When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is
`
`[time, channel]`
`
.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
format (str or None, optional):
Not used. PySoundFile does not accept format hint.
Not used. PySoundFile does not accept format hint.
Returns:
Returns:
Tuple[
torch.Tensor, int
]
: Resulting Tensor and sample rate.
(
torch.Tensor, int
)
: Resulting Tensor and sample rate.
If the input file has integer wav format and normalization is off, then it has
If the input file has integer wav format and normalization is off, then it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
`
`[channel, time]`
`
else
`
`[time, channel]`
`
.
`[channel, time]` else `[time, channel]`.
"""
"""
with
soundfile
.
SoundFile
(
filepath
,
"r"
)
as
file_
:
with
soundfile
.
SoundFile
(
filepath
,
"r"
)
as
file_
:
if
file_
.
format
!=
"WAV"
or
normalize
:
if
file_
.
format
!=
"WAV"
or
normalize
:
...
@@ -335,8 +335,8 @@ def save(
...
@@ -335,8 +335,8 @@ def save(
filepath (str or pathlib.Path): Path to audio file.
filepath (str or pathlib.Path): Path to audio file.
src (torch.Tensor): Audio data to save. must be 2D tensor.
src (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
sample_rate (int): sampling rate
channels_first (bool, optional): If ``True``, the given tensor is interpreted as
`
`[channel, time]`
`
,
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
otherwise
`
`[time, channel]`
`
.
otherwise `[time, channel]`.
compression (float of None, optional): Not used.
compression (float of None, optional): Not used.
It is here only for interface compatibility reson with "sox_io" backend.
It is here only for interface compatibility reson with "sox_io" backend.
format (str or None, optional): Override the audio format.
format (str or None, optional): Override the audio format.
...
...
torchaudio/backend/sox_io_backend.py
View file @
cb40dd72
...
@@ -89,7 +89,7 @@ def load(
...
@@ -89,7 +89,7 @@ def load(
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
and corresponding codec libraries such as ``libmad`` or ``libmp3lame`` etc.
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
By default (``normalize=True``, ``channels_first=True``), this function returns Tensor with
``float32`` dtype and the shape of
`
`[channel, time]`
`
.
``float32`` dtype and the shape of `[channel, time]`.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
The samples are normalized to fit in the range of ``[-1.0, 1.0]``.
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
When the input format is WAV with integer type, such as 32-bit signed integer, 16-bit
...
@@ -131,18 +131,18 @@ def load(
...
@@ -131,18 +131,18 @@ def load(
integer type.
integer type.
This argument has no effect for formats other than integer WAV type.
This argument has no effect for formats other than integer WAV type.
channels_first (bool, optional):
channels_first (bool, optional):
When True, the returned Tensor has dimension
`
`[channel, time]`
`
.
When True, the returned Tensor has dimension `[channel, time]`.
Otherwise, the returned Tensor's dimension is
`
`[time, channel]`
`
.
Otherwise, the returned Tensor's dimension is `[time, channel]`.
format (str or None, optional):
format (str or None, optional):
Override the format detection with the given format.
Override the format detection with the given format.
Providing the argument might help when libsox can not infer the format
Providing the argument might help when libsox can not infer the format
from header or extension,
from header or extension,
Returns:
Returns:
Tuple[
torch.Tensor, int
]
: Resulting Tensor and sample rate.
(
torch.Tensor, int
)
: Resulting Tensor and sample rate.
If the input file has integer wav format and normalization is off, then it has
If the input file has integer wav format and normalization is off, then it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
integer type, else ``float32`` type. If ``channels_first=True``, it has
`
`[channel, time]`
`
else
`
`[time, channel]`
`
.
`[channel, time]` else `[time, channel]`.
"""
"""
if
not
torch
.
jit
.
is_scripting
():
if
not
torch
.
jit
.
is_scripting
():
if
hasattr
(
filepath
,
'read'
):
if
hasattr
(
filepath
,
'read'
):
...
@@ -172,8 +172,8 @@ def save(
...
@@ -172,8 +172,8 @@ def save(
as ``str`` for TorchScript compiler compatibility.
as ``str`` for TorchScript compiler compatibility.
src (torch.Tensor): Audio data to save. must be 2D tensor.
src (torch.Tensor): Audio data to save. must be 2D tensor.
sample_rate (int): sampling rate
sample_rate (int): sampling rate
channels_first (bool, optional): If ``True``, the given tensor is interpreted as
`
`[channel, time]`
`
,
channels_first (bool, optional): If ``True``, the given tensor is interpreted as `[channel, time]`,
otherwise
`
`[time, channel]`
`
.
otherwise `[time, channel]`.
compression (float or None, optional): Used for formats other than WAV.
compression (float or None, optional): Used for formats other than WAV.
This corresponds to ``-C`` option of ``sox`` command.
This corresponds to ``-C`` option of ``sox`` command.
...
...
torchaudio/datasets/cmuarctic.py
View file @
cb40dd72
...
@@ -164,7 +164,7 @@ class CMUARCTIC(Dataset):
...
@@ -164,7 +164,7 @@ class CMUARCTIC(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple
: ``(waveform, sample_rate, transcript, utterance_id)``
(Tensor, int, str, str)
: ``(waveform, sample_rate, transcript, utterance_id)``
"""
"""
line
=
self
.
_walker
[
n
]
line
=
self
.
_walker
[
n
]
return
load_cmuarctic_item
(
line
,
self
.
_path
,
self
.
_folder_audio
,
self
.
_ext_audio
)
return
load_cmuarctic_item
(
line
,
self
.
_path
,
self
.
_folder_audio
,
self
.
_ext_audio
)
...
...
torchaudio/datasets/cmudict.py
View file @
cb40dd72
...
@@ -167,7 +167,7 @@ class CMUDict(Dataset):
...
@@ -167,7 +167,7 @@ class CMUDict(Dataset):
n (int): The index of the sample to be loaded.
n (int): The index of the sample to be loaded.
Returns:
Returns:
tuple
: The corresponding word and phonemes ``(word, [phonemes])``.
(str, List[str])
: The corresponding word and phonemes ``(word, [phonemes])``.
"""
"""
return
self
.
_dictionary
[
n
]
return
self
.
_dictionary
[
n
]
...
...
torchaudio/datasets/commonvoice.py
View file @
cb40dd72
...
@@ -65,8 +65,8 @@ class COMMONVOICE(Dataset):
...
@@ -65,8 +65,8 @@ class COMMONVOICE(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple
: ``(waveform, sample_rate, dictionary)``, where dictionary
is built
(Tensor, int, Dict[str, str])
: ``(waveform, sample_rate, dictionary)``, where dictionary
from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``,
is built
from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``,
``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``.
``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``.
"""
"""
line
=
self
.
_walker
[
n
]
line
=
self
.
_walker
[
n
]
...
...
torchaudio/datasets/dr_vctk.py
View file @
cb40dd72
...
@@ -107,8 +107,9 @@ class DR_VCTK(Dataset):
...
@@ -107,8 +107,9 @@ class DR_VCTK(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform_clean, sample_rate_clean, waveform_noisy, sample_rate_noisy, speaker_id, utterance_id,
\
(Tensor, int, Tensor, int, str, str, str, int):
source, channel_id)``
``(waveform_clean, sample_rate_clean, waveform_noisy, sample_rate_noisy, speaker_id,
\
utterance_id, source, channel_id)``
"""
"""
filename
=
self
.
_filename_list
[
n
]
filename
=
self
.
_filename_list
[
n
]
return
self
.
_load_dr_vctk_item
(
filename
)
return
self
.
_load_dr_vctk_item
(
filename
)
...
...
torchaudio/datasets/gtzan.py
View file @
cb40dd72
...
@@ -1102,7 +1102,7 @@ class GTZAN(Dataset):
...
@@ -1102,7 +1102,7 @@ class GTZAN(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple
: ``(waveform, sample_rate, label)``
(Tensor, int, str)
: ``(waveform, sample_rate, label)``
"""
"""
fileid
=
self
.
_walker
[
n
]
fileid
=
self
.
_walker
[
n
]
item
=
load_gtzan_item
(
fileid
,
self
.
_path
,
self
.
_ext_audio
)
item
=
load_gtzan_item
(
fileid
,
self
.
_path
,
self
.
_ext_audio
)
...
...
torchaudio/datasets/librimix.py
View file @
cb40dd72
...
@@ -84,6 +84,6 @@ class LibriMix(Dataset):
...
@@ -84,6 +84,6 @@ class LibriMix(Dataset):
Args:
Args:
key (int): The index of the sample to be loaded
key (int): The index of the sample to be loaded
Returns:
Returns:
tuple
: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
(int, Tensor, List[Tensor])
: ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
"""
return
self
.
_load_sample
(
self
.
files
[
key
])
return
self
.
_load_sample
(
self
.
files
[
key
])
torchaudio/datasets/librispeech.py
View file @
cb40dd72
...
@@ -133,7 +133,8 @@ class LIBRISPEECH(Dataset):
...
@@ -133,7 +133,8 @@ class LIBRISPEECH(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
(Tensor, int, str, int, int, int):
``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
"""
"""
fileid
=
self
.
_walker
[
n
]
fileid
=
self
.
_walker
[
n
]
return
load_librispeech_item
(
fileid
,
self
.
_path
,
self
.
_ext_audio
,
self
.
_ext_txt
)
return
load_librispeech_item
(
fileid
,
self
.
_path
,
self
.
_ext_audio
,
self
.
_ext_txt
)
...
...
torchaudio/datasets/libritts.py
View file @
cb40dd72
...
@@ -134,8 +134,8 @@ class LIBRITTS(Dataset):
...
@@ -134,8 +134,8 @@ class LIBRITTS(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform, sample_rate, original_text, normalized_text, speaker_id,
(Tensor, int, str, str, str, int, int, str):
chapter_id, utterance_id)``
``(waveform, sample_rate, original_text, normalized_text, speaker_id,
chapter_id, utterance_id)``
"""
"""
fileid
=
self
.
_walker
[
n
]
fileid
=
self
.
_walker
[
n
]
return
load_libritts_item
(
return
load_libritts_item
(
...
...
torchaudio/datasets/ljspeech.py
View file @
cb40dd72
...
@@ -68,7 +68,8 @@ class LJSPEECH(Dataset):
...
@@ -68,7 +68,8 @@ class LJSPEECH(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform, sample_rate, transcript, normalized_transcript)``
(Tensor, int, str, str):
``(waveform, sample_rate, transcript, normalized_transcript)``
"""
"""
line
=
self
.
_flist
[
n
]
line
=
self
.
_flist
[
n
]
fileid
,
transcript
,
normalized_transcript
=
line
fileid
,
transcript
,
normalized_transcript
=
line
...
...
torchaudio/datasets/speechcommands.py
View file @
cb40dd72
...
@@ -138,7 +138,8 @@ class SPEECHCOMMANDS(Dataset):
...
@@ -138,7 +138,8 @@ class SPEECHCOMMANDS(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform, sample_rate, label, speaker_id, utterance_number)``
(Tensor, int, str, str, int):
``(waveform, sample_rate, label, speaker_id, utterance_number)``
"""
"""
fileid
=
self
.
_walker
[
n
]
fileid
=
self
.
_walker
[
n
]
return
load_speechcommands_item
(
fileid
,
self
.
_path
)
return
load_speechcommands_item
(
fileid
,
self
.
_path
)
...
...
torchaudio/datasets/tedlium.py
View file @
cb40dd72
...
@@ -127,7 +127,8 @@ class TEDLIUM(Dataset):
...
@@ -127,7 +127,8 @@ class TEDLIUM(Dataset):
path (str): Dataset root path
path (str): Dataset root path
Returns:
Returns:
tuple: ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
(Tensor, int, str, int, int, int):
``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
"""
"""
transcript_path
=
os
.
path
.
join
(
path
,
"stm"
,
fileid
)
transcript_path
=
os
.
path
.
join
(
path
,
"stm"
,
fileid
)
with
open
(
transcript_path
+
".stm"
)
as
f
:
with
open
(
transcript_path
+
".stm"
)
as
f
:
...
...
torchaudio/datasets/utils.py
View file @
cb40dd72
...
@@ -151,7 +151,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
...
@@ -151,7 +151,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
overwrite (bool, optional): overwrite existing files (Default: ``False``)
overwrite (bool, optional): overwrite existing files (Default: ``False``)
Returns:
Returns:
l
ist: List of paths to extracted files even if not overwritten.
L
ist
[str]
: List of paths to extracted files even if not overwritten.
Examples:
Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
...
...
torchaudio/datasets/vctk.py
View file @
cb40dd72
...
@@ -134,7 +134,8 @@ class VCTK_092(Dataset):
...
@@ -134,7 +134,8 @@ class VCTK_092(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple: ``(waveform, sample_rate, transcript, speaker_id, utterance_id)``
(Tensor, int, str, str, str):
``(waveform, sample_rate, transcript, speaker_id, utterance_id)``
"""
"""
speaker_id
,
utterance_id
=
self
.
_sample_ids
[
n
]
speaker_id
,
utterance_id
=
self
.
_sample_ids
[
n
]
return
self
.
_load_sample
(
speaker_id
,
utterance_id
,
self
.
_mic_id
)
return
self
.
_load_sample
(
speaker_id
,
utterance_id
,
self
.
_mic_id
)
...
...
torchaudio/datasets/yesno.py
View file @
cb40dd72
...
@@ -77,7 +77,7 @@ class YESNO(Dataset):
...
@@ -77,7 +77,7 @@ class YESNO(Dataset):
n (int): The index of the sample to be loaded
n (int): The index of the sample to be loaded
Returns:
Returns:
tuple
: ``(waveform, sample_rate, labels)``
(Tensor, int, List[int])
: ``(waveform, sample_rate, labels)``
"""
"""
fileid
=
self
.
_walker
[
n
]
fileid
=
self
.
_walker
[
n
]
item
=
self
.
_load_item
(
fileid
,
self
.
_path
)
item
=
self
.
_load_item
(
fileid
,
self
.
_path
)
...
...
torchaudio/functional/filtering.py
View file @
cb40dd72
...
@@ -663,7 +663,7 @@ def filtfilt(
...
@@ -663,7 +663,7 @@ def filtfilt(
Returns:
Returns:
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or `(..., time)` otherwise.
are 2D Tensors, or `(..., time)` otherwise.
"""
"""
forward_filtered
=
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
,
clamp
=
False
,
batching
=
True
)
forward_filtered
=
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
,
clamp
=
False
,
batching
=
True
)
backward_filtered
=
lfilter
(
backward_filtered
=
lfilter
(
...
@@ -987,7 +987,7 @@ def lfilter(
...
@@ -987,7 +987,7 @@ def lfilter(
Returns:
Returns:
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or `(..., time)` otherwise.
are 2D Tensors, or `(..., time)` otherwise.
"""
"""
assert
a_coeffs
.
size
()
==
b_coeffs
.
size
()
assert
a_coeffs
.
size
()
==
b_coeffs
.
size
()
assert
a_coeffs
.
ndim
<=
2
assert
a_coeffs
.
ndim
<=
2
...
@@ -1474,7 +1474,7 @@ def vad(
...
@@ -1474,7 +1474,7 @@ def vad(
in the detector algorithm. (Default: 2000.0)
in the detector algorithm. (Default: 2000.0)
Returns:
Returns:
Tensor: Tensor of audio of dimension (..., time).
Tensor: Tensor of audio of dimension
`
(..., time)
`
.
Reference:
Reference:
- http://sox.sourceforge.net/sox.html
- http://sox.sourceforge.net/sox.html
...
...
torchaudio/functional/functional.py
View file @
cb40dd72
...
@@ -263,7 +263,7 @@ def griffinlim(
...
@@ -263,7 +263,7 @@ def griffinlim(
rand_init (bool): Initializes phase randomly if True, to zero otherwise.
rand_init (bool): Initializes phase randomly if True, to zero otherwise.
Returns:
Returns:
torch.
Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
"""
"""
assert
momentum
<
1
,
'momentum={} > 1 can be unstable'
.
format
(
momentum
)
assert
momentum
<
1
,
'momentum={} > 1 can be unstable'
.
format
(
momentum
)
assert
momentum
>=
0
,
'momentum={} < 0'
.
format
(
momentum
)
assert
momentum
>=
0
,
'momentum={} < 0'
.
format
(
momentum
)
...
@@ -1369,7 +1369,7 @@ def apply_codec(
...
@@ -1369,7 +1369,7 @@ def apply_codec(
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
Returns:
Returns:
torch.
Tensor: Resulting Tensor.
Tensor: Resulting Tensor.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
"""
"""
bytes
=
io
.
BytesIO
()
bytes
=
io
.
BytesIO
()
...
...
torchaudio/models/conv_tasnet.py
View file @
cb40dd72
...
@@ -154,7 +154,7 @@ class MaskGenerator(torch.nn.Module):
...
@@ -154,7 +154,7 @@ class MaskGenerator(torch.nn.Module):
input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
input (torch.Tensor): 3D Tensor with shape [batch, features, frames]
Returns:
Returns:
torch.
Tensor: shape [batch, num_sources, features, frames]
Tensor: shape [batch, num_sources, features, frames]
"""
"""
batch_size
=
input
.
shape
[
0
]
batch_size
=
input
.
shape
[
0
]
feats
=
self
.
input_norm
(
input
)
feats
=
self
.
input_norm
(
input
)
...
@@ -264,7 +264,7 @@ class ConvTasNet(torch.nn.Module):
...
@@ -264,7 +264,7 @@ class ConvTasNet(torch.nn.Module):
input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
input (torch.Tensor): 3D Tensor with shape (batch_size, channels==1, frames)
Returns:
Returns:
torch.
Tensor: Padded Tensor
Tensor: Padded Tensor
int: Number of paddings performed
int: Number of paddings performed
"""
"""
batch_size
,
num_channels
,
num_frames
=
input
.
shape
batch_size
,
num_channels
,
num_frames
=
input
.
shape
...
@@ -291,7 +291,7 @@ class ConvTasNet(torch.nn.Module):
...
@@ -291,7 +291,7 @@ class ConvTasNet(torch.nn.Module):
input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
input (torch.Tensor): 3D Tensor with shape [batch, channel==1, frames]
Returns:
Returns:
torch.
Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
Tensor: 3D Tensor with shape [batch, channel==num_sources, frames]
"""
"""
if
input
.
ndim
!=
3
or
input
.
shape
[
1
]
!=
1
:
if
input
.
ndim
!=
3
or
input
.
shape
[
1
]
!=
1
:
raise
ValueError
(
raise
ValueError
(
...
...
torchaudio/models/tacotron2.py
View file @
cb40dd72
...
@@ -1031,7 +1031,7 @@ class Tacotron2(nn.Module):
...
@@ -1031,7 +1031,7 @@ class Tacotron2(nn.Module):
mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
Returns:
Returns:
Tensor, Tensor, Tensor,
and
Tensor:
[
Tensor, Tensor, Tensor, Tensor
]
:
Tensor
Tensor
Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
Mel spectrogram before Postnet with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
Tensor
Tensor
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment