"doc/git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad"
Commit 7ac3e2e2 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Use double quotes for string in functional and transforms (#2618)

Summary:
To make the code consistent, we should use double quotation marks for all strings. This PR makes these changes in `functional` and `transforms`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2618

Reviewed By: carolineechen

Differential Revision: D38744137

Pulled By: nateanl

fbshipit-source-id: 74213a24d9f66c306cc92019d77dcb2a877f94bd
parent 05545791
@@ -758,10 +758,10 @@ def flanger(
     """
     if modulation not in ("sinusoidal", "triangular"):
-        raise ValueError("Only 'sinusoidal' or 'triangular' modulation allowed")
+        raise ValueError('Only "sinusoidal" or "triangular" modulation allowed')
     if interpolation not in ("linear", "quadratic"):
-        raise ValueError("Only 'linear' or 'quadratic' interpolation allowed")
+        raise ValueError('Only "linear" or "quadratic" interpolation allowed')
     actual_shape = waveform.shape
     device, dtype = waveform.device, waveform.dtype
...
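The hunk above also shows why the outer quotes flip to single here: the option names inside the message keep their double quotes, so the enclosing string literal uses single quotes. A minimal stand-alone sketch of that validation pattern (the `check_modulation` helper is hypothetical, for illustration only):

```python
def check_modulation(modulation: str) -> None:
    # Membership test against the allowed options, as in the flanger() hunk above.
    if modulation not in ("sinusoidal", "triangular"):
        # Single-quoted outer string so the option names can carry double quotes,
        # matching the convention this PR applies.
        raise ValueError('Only "sinusoidal" or "triangular" modulation allowed')

check_modulation("sinusoidal")  # valid option: no exception
try:
    check_modulation("square")
except ValueError as err:
    print(err)
```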
@@ -533,7 +533,7 @@ def melscale_fbanks(
         f_max (float): Maximum frequency (Hz)
         n_mels (int): Number of mel filterbanks
         sample_rate (int): Sample rate of the audio waveform
-        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
+        norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
             (area normalization). (Default: ``None``)
         mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
@@ -547,7 +547,7 @@ def melscale_fbanks(
     """
     if norm is not None and norm != "slaney":
-        raise ValueError("norm must be one of None or 'slaney'")
+        raise ValueError('norm must be one of None or "slaney"')
     # freq bins
     all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
@@ -634,7 +634,7 @@ def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
     Args:
         n_mfcc (int): Number of mfc coefficients to retain
         n_mels (int): Number of mel filterbanks
-        norm (str or None): Norm to use (either 'ortho' or None)
+        norm (str or None): Norm to use (either "ortho" or None)
     Returns:
         Tensor: The transformation matrix, to be right-multiplied to
@@ -642,7 +642,7 @@ def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> Tensor:
     """
     if norm is not None and norm != "ortho":
-        raise ValueError("norm must be either 'ortho' or None")
+        raise ValueError('norm must be either "ortho" or None')
     # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
     n = torch.arange(float(n_mels))
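`create_dct` builds a DCT-II basis per the Wikipedia link in the hunk. A dependency-free sketch of that matrix, with the same `"ortho"` check, might look like the following; the `dct_matrix` name and plain-list return type are illustrative, and the scaling follows the common DCT-II convention rather than being a claim about torchaudio's exact internals:

```python
import math

def dct_matrix(n_mfcc: int, n_mels: int, norm=None):
    # DCT-II basis: entry [j][k] = cos(pi / n_mels * (j + 0.5) * k)
    if norm is not None and norm != "ortho":
        raise ValueError('norm must be either "ortho" or None')
    basis = [
        [math.cos(math.pi / n_mels * (j + 0.5) * k) for k in range(n_mfcc)]
        for j in range(n_mels)
    ]
    if norm == "ortho":
        # Orthonormal variant: every column is scaled by sqrt(2 / n_mels),
        # and the k = 0 column gets an extra 1 / sqrt(2).
        scale = math.sqrt(2.0 / n_mels)
        basis = [
            [v * scale * (1 / math.sqrt(2) if k == 0 else 1.0)
             for k, v in enumerate(row)]
            for row in basis
        ]
    # Shape (n_mels, n_mfcc): right-multiplied onto mel features, as the docstring says.
    return basis
```

With `norm="ortho"` each column of the matrix has unit L2 norm, which is what makes the transform energy-preserving.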
@@ -1571,7 +1571,7 @@ def resample(
         rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
             Lower values reduce anti-aliasing, but also reduce some of the highest frequencies. (Default: ``0.99``)
         resampling_method (str, optional): The resampling method to use.
-            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
+            Options: [``"sinc_interpolation"``, ``"kaiser_window"``] (Default: ``"sinc_interpolation"``)
         beta (float or None, optional): The shape parameter used for kaiser window.
     Returns:
@@ -1859,13 +1859,13 @@ def rnnt_loss(
         blank (int, optional): blank label (Default: ``-1``)
         clamp (float, optional): clamp for gradients (Default: ``-1``)
         reduction (string, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
+            ``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
     Returns:
-        Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size `(batch)`,
+        Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size `(batch)`,
             otherwise scalar.
     """
     if reduction not in ["none", "mean", "sum"]:
-        raise ValueError("reduction should be one of 'none', 'mean', or 'sum'")
+        raise ValueError('reduction should be one of "none", "mean", or "sum"')
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
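The `reduction` option validated here follows the usual PyTorch loss convention. A plain-Python sketch of how such a reduction is typically applied to a batch of per-example losses (a stand-in for intuition, not the actual `rnnt_loss` internals):

```python
def apply_reduction(losses, reduction="mean"):
    # losses: per-example loss values, size (batch).
    if reduction not in ["none", "mean", "sum"]:
        raise ValueError('reduction should be one of "none", "mean", or "sum"')
    if reduction == "none":
        return losses                     # keep per-example values, size (batch)
    if reduction == "sum":
        return sum(losses)                # collapse to a scalar
    return sum(losses) / len(losses)      # "mean": scalar average

apply_reduction([1.0, 3.0], "mean")  # 2.0
```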
@@ -2059,7 +2059,7 @@ def mvdr_weights_souden(
         # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
         beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
     else:
-        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
+        raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
     return beamform_weights
@@ -2142,7 +2142,7 @@ def mvdr_weights_rtf(
         reference_channel = reference_channel.to(psd_n.dtype)
         scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
     else:
-        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
+        raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
     beamform_weights = beamform_weights * scale[..., None]
@@ -2220,7 +2220,7 @@ def rtf_power(
         reference_channel = reference_channel.to(psd_n.dtype)
         rtf = torch.einsum("...c,...c->...", [phi, reference_channel[..., None, None, :]])
     else:
-        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
+        raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
     rtf = rtf.unsqueeze(-1)  # (..., freq, channel, 1)
     if n_iter >= 2:
         # The number of iterations in the for loop is `n_iter - 2`
...
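All three hunks above share the same `reference_channel` dispatch: an `int` selects one channel, a Tensor supplies per-channel weights, and anything else raises `TypeError`. The shape of that dispatch, sketched with plain lists standing in for Tensors (`pick_reference` is an illustrative name, not a torchaudio function):

```python
def pick_reference(weights, reference_channel):
    # weights: rows of per-channel values (lists stand in for Tensors here).
    if isinstance(reference_channel, int):
        # An int indexes a single reference channel.
        return [row[reference_channel] for row in weights]
    if isinstance(reference_channel, (list, tuple)):
        # A weight vector combines channels, mirroring the einsum in the hunks above.
        return [sum(w * r for w, r in zip(row, reference_channel)) for row in weights]
    raise TypeError(f'Expected "int" or "Tensor" for reference_channel. Found: {type(reference_channel)}.')
```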
@@ -178,7 +178,7 @@ class MVDR(torch.nn.Module):
             "stv_power",
         ]:
             raise ValueError(
-                "`solution` must be one of ['ref_channel', 'stv_evd', 'stv_power']. Given {}".format(solution)
+                '`solution` must be one of ["ref_channel", "stv_evd", "stv_power"]. Given {}'.format(solution)
             )
         self.ref_channel = ref_channel
         self.solution = solution
...
@@ -52,7 +52,7 @@ class Spectrogram(torch.nn.Module):
             Deprecated and not used.
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = torchaudio.transforms.Spectrogram(n_fft=800)
         >>> spectrogram = transform(waveform)
@@ -306,8 +306,8 @@ class AmplitudeToDB(torch.nn.Module):
     a full clip.
     Args:
-        stype (str, optional): scale of input tensor (``'power'`` or ``'magnitude'``). The
-            power being the elementwise square of the magnitude. (Default: ``'power'``)
+        stype (str, optional): scale of input tensor (``"power"`` or ``"magnitude"``). The
+            power being the elementwise square of the magnitude. (Default: ``"power"``)
         top_db (float or None, optional): minimum negative cut-off in decibels. A reasonable
             number is 80. (Default: ``None``)
@@ -356,7 +356,7 @@ class MelScale(torch.nn.Module):
         f_min (float, optional): Minimum frequency. (Default: ``0.``)
         f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
         n_stft (int, optional): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. (Default: ``201``)
-        norm (str or None, optional): If ``'slaney'``, divide the triangular mel weights by the width of the mel band
+        norm (str or None, optional): If ``"slaney"``, divide the triangular mel weights by the width of the mel band
             (area normalization). (Default: ``None``)
         mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
@@ -430,7 +430,7 @@ class InverseMelScale(torch.nn.Module):
         tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
         tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
         sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
-        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
+        norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
             (area normalization). (Default: ``None``)
         mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
     """
@@ -561,12 +561,12 @@ class MelSpectrogram(torch.nn.Module):
             :attr:`center` is ``True``. (Default: ``"reflect"``)
         onesided (bool, optional): controls whether to return half of results to
             avoid redundancy. (Default: ``True``)
-        norm (str or None, optional): If 'slaney', divide the triangular mel weights by the width of the mel band
+        norm (str or None, optional): If "slaney", divide the triangular mel weights by the width of the mel band
             (area normalization). (Default: ``None``)
         mel_scale (str, optional): Scale to use: ``htk`` or ``slaney``. (Default: ``htk``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.MelSpectrogram(sample_rate)
         >>> mel_specgram = transform(waveform)  # (channel, n_mels, time)
@@ -656,7 +656,7 @@ class MFCC(torch.nn.Module):
         sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
         n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
         dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
-        norm (str, optional): norm to use. (Default: ``'ortho'``)
+        norm (str, optional): norm to use. (Default: ``"ortho"``)
         log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
         melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
@@ -737,7 +737,7 @@ class LFCC(torch.nn.Module):
         f_min (float, optional): Minimum frequency. (Default: ``0.``)
         f_max (float or None, optional): Maximum frequency. (Default: ``None``)
         dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
-        norm (str, optional): norm to use. (Default: ``'ortho'``)
+        norm (str, optional): norm to use. (Default: ``"ortho"``)
         log_lf (bool, optional): whether to use log-lf spectrograms instead of db-scaled. (Default: ``False``)
         speckwargs (dict or None, optional): arguments for Spectrogram. (Default: ``None``)
@@ -834,7 +834,7 @@ class MuLawEncoding(torch.nn.Module):
         quantization_channels (int, optional): Number of channels. (Default: ``256``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = torchaudio.transforms.MuLawEncoding(quantization_channels=512)
         >>> mulawtrans = transform(waveform)
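The mu-law companding behind `MuLawEncoding`/`MuLawDecoding` fits in a few lines of standard-library Python. A scalar sketch of the usual formula, hedged as an illustration rather than torchaudio's tensor implementation (function names are illustrative):

```python
import math

def mu_law_encode(x: float, quantization_channels: int = 256) -> int:
    # Compand to [-1, 1] with the mu-law curve, then quantize to an integer code.
    mu = quantization_channels - 1
    compressed = math.copysign(math.log1p(mu * abs(x)) / math.log1p(mu), x)
    return int((compressed + 1) / 2 * mu + 0.5)

def mu_law_decode(code: int, quantization_channels: int = 256) -> float:
    # Invert the quantization, then expand with the inverse mu-law curve.
    mu = quantization_channels - 1
    y = code / mu * 2 - 1
    return math.copysign(math.expm1(abs(y) * math.log1p(mu)) / mu, y)
```

Round-tripping a sample through encode then decode recovers it up to quantization error, which is what makes mu-law useful for 8-bit storage of audio in [-1, 1].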
@@ -873,7 +873,7 @@ class MuLawDecoding(torch.nn.Module):
         quantization_channels (int, optional): Number of channels. (Default: ``256``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = torchaudio.transforms.MuLawDecoding(quantization_channels=512)
         >>> mulawtrans = transform(waveform)
     """
@@ -911,7 +911,7 @@ class Resample(torch.nn.Module):
         orig_freq (int, optional): The original frequency of the signal. (Default: ``16000``)
         new_freq (int, optional): The desired frequency. (Default: ``16000``)
         resampling_method (str, optional): The resampling method to use.
-            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``'sinc_interpolation'``)
+            Options: [``sinc_interpolation``, ``kaiser_window``] (Default: ``"sinc_interpolation"``)
         lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
             but less efficient. (Default: ``6``)
         rolloff (float, optional): The roll-off frequency of the filter, as a fraction of the Nyquist.
@@ -926,7 +926,7 @@ class Resample(torch.nn.Module):
         carried out on ``torch.float64``.
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.Resample(sample_rate, sample_rate/10)
         >>> waveform = transform(waveform)
     """
@@ -989,7 +989,7 @@ class ComputeDeltas(torch.nn.Module):
     Args:
         win_length (int, optional): The window length used for computing delta. (Default: ``5``)
-        mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
+        mode (str, optional): Mode parameter passed to padding. (Default: ``"replicate"``)
     """
     __constants__ = ["win_length"]
@@ -1093,8 +1093,8 @@ class Fade(torch.nn.Module):
             (Default: ``"linear"``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
-        >>> transform = transforms.Fade(fade_in_len=sample_rate, fade_out_len=2 * sample_rate, fade_shape='linear')
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+        >>> transform = transforms.Fade(fade_in_len=sample_rate, fade_out_len=2 * sample_rate, fade_shape="linear")
         >>> faded_waveform = transform(waveform)
     """
@@ -1312,7 +1312,7 @@ class Vol(torch.nn.Module):
         gain_type (str, optional): Type of gain. One of: ``amplitude``, ``power``, ``db`` (Default: ``amplitude``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.Vol(gain=0.5, gain_type="amplitude")
         >>> quieter_waveform = transform(waveform)
     """
@@ -1362,7 +1362,7 @@ class SlidingWindowCmn(torch.nn.Module):
         norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.SlidingWindowCmn(cmn_window=1000)
         >>> cmn_waveform = transform(waveform)
     """
@@ -1442,12 +1442,12 @@ class Vad(torch.nn.Module):
             in the detector algorithm. (Default: 2000.0)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
-        >>> waveform_reversed, sample_rate = apply_effects_tensor(waveform, sample_rate, [['reverse']])
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
+        >>> waveform_reversed, sample_rate = apply_effects_tensor(waveform, sample_rate, [["reverse"]])
         >>> transform = transforms.Vad(sample_rate=sample_rate, trigger_level=7.5)
         >>> waveform_reversed_front_trim = transform(waveform_reversed)
         >>> waveform_end_trim, sample_rate = apply_effects_tensor(
-        >>>     waveform_reversed_front_trim, sample_rate, [['reverse']]
+        >>>     waveform_reversed_front_trim, sample_rate, [["reverse"]]
         >>> )
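The `Vad` example trims only leading silence, hence the reverse/trim/reverse sequence to also remove trailing silence. The idea in miniature, on a plain list standing in for a waveform (the threshold-based trimming is a toy substitute for the actual voice activity detector):

```python
def trim_leading(samples, threshold=0.1):
    # Toy stand-in for Vad: drop samples until one exceeds the threshold.
    for i, s in enumerate(samples):
        if abs(s) > threshold:
            return samples[i:]
    return []

def trim_both(samples, threshold=0.1):
    # Reverse / trim / reverse removes trailing silence with the same tool,
    # mirroring the apply_effects_tensor(..., [["reverse"]]) pattern above.
    front = trim_leading(samples, threshold)
    return trim_leading(front[::-1], threshold)[::-1]

trim_both([0.0, 0.0, 0.5, 0.4, 0.0, 0.0])  # [0.5, 0.4]
```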
     Reference:
@@ -1545,7 +1545,7 @@ class SpectralCentroid(torch.nn.Module):
         wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.SpectralCentroid(sample_rate)
        >>> spectral_centroid = transform(waveform)  # (channel, time)
     """
@@ -1604,7 +1604,7 @@ class PitchShift(LazyModuleMixin, torch.nn.Module):
         If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``).
     Example
-        >>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
+        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
         >>> transform = transforms.PitchShift(sample_rate, 4)
         >>> waveform_shift = transform(waveform)  # (channel, time)
     """
@@ -1709,7 +1709,7 @@ class RNNTLoss(torch.nn.Module):
         blank (int, optional): blank label (Default: ``-1``)
         clamp (float, optional): clamp for gradients (Default: ``-1``)
         reduction (string, optional): Specifies the reduction to apply to the output:
-            ``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
+            ``"none"`` | ``"mean"`` | ``"sum"``. (Default: ``"mean"``)
     Example
         >>> # Hypothetical values
@@ -1755,7 +1755,7 @@ class RNNTLoss(torch.nn.Module):
             logit_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of each sequence from encoder
             target_lengths (Tensor): Tensor of dimension `(batch)` containing lengths of targets for each sequence
         Returns:
-            Tensor: Loss with the reduction option applied. If ``reduction`` is ``'none'``, then size (batch),
+            Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size (batch),
                 otherwise scalar.
         """
         return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, self.blank, self.clamp, self.reduction)