Unverified Commit cb40dd72 authored by Caroline Chen, committed by GitHub

[DOC] Standardization and minor fixes (#1892)

parent 955cdbdc
@@ -90,7 +90,7 @@ class Wav2Vec2Model(Module):
     lengths (Tensor or None, optional):
         Indicates the valid length of each audio in the batch.
         Shape: `(batch, )`.
-        When the ``waveforms`` contains audios with different duration,
+        When the ``waveforms`` contains audios with different durations,
         by providing ``lengths`` argument, the model will compute
         the corresponding valid output lengths and apply proper mask in
         transformer attention layer.
@@ -104,7 +104,7 @@ class Wav2Vec2Model(Module):
             Shape: `(batch, frames, num labels)`.
         Tensor or None
             If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
-            is retuned.
+            is returned.
             It indicates the valid length in time axis of the output Tensor.
     """
     x, lengths = self.feature_extractor(waveforms, lengths)
...
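For readers of this diff, a minimal usage sketch of the `forward` signature documented above; the bundle choice, dummy shapes, and variable names are illustrative, not part of this commit:

>>> import torch, torchaudio
>>> bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
>>> model = bundle.get_model()
>>> waveforms = torch.randn(2, 16000)       # (batch, time), dummy audio
>>> lengths = torch.tensor([16000, 12000])  # valid samples per batch element
>>> emission, out_lengths = model(waveforms, lengths)
>>> emission.shape                          # (batch, frames, num labels)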
@@ -283,7 +283,7 @@ class WaveRNN(nn.Module):
         specgram: the input spectrogram to the WaveRNN layer (n_batch, 1, n_freq, n_time)
     Return:
-        Tensor shape: (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
+        Tensor: shape (n_batch, 1, (n_time - kernel_size + 1) * hop_length, n_classes)
     """
     assert waveform.size(1) == 1, 'Require the input channel of waveform is 1'
@@ -343,7 +343,7 @@ class WaveRNN(nn.Module):
     lengths (Tensor or None, optional):
         Indicates the valid length of each audio in the batch.
         Shape: `(batch, )`.
-        When the ``specgram`` contains spectrograms with different duration,
+        When the ``specgram`` contains spectrograms with different durations,
        by providing ``lengths`` argument, the model will compute
        the corresponding valid output lengths.
        If ``None``, it is assumed that all the audio in ``waveforms``
@@ -356,7 +356,7 @@ class WaveRNN(nn.Module):
             1 stands for a single channel.
         Tensor or None
             If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
-            is retuned.
+            is returned.
             It indicates the valid length in time axis of the output Tensor.
     """
...
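A minimal sketch of the WaveRNN `forward` contract described above; constructor arguments and shapes are illustrative (note `upsample_scales` must multiply to `hop_length`):

>>> import torch
>>> from torchaudio.models import WaveRNN
>>> model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
>>> specgram = torch.randn(2, 1, 128, 10)  # (n_batch, 1, n_freq, n_time)
>>> waveform = torch.randn(2, 1, 1200)     # (n_time - kernel_size + 1) * hop_length, with default kernel_size=5
>>> output = model(waveform, specgram)     # (2, 1, 1200, 512)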
@@ -25,7 +25,7 @@ class _TextProcessor(ABC):
         text (str or list of str): The input texts.
     Returns:
-        Tensor and Tensor:
+        (Tensor, Tensor):
         Tensor:
             The encoded texts. Shape: `(batch, max length)`
         Tensor:
@@ -56,7 +56,7 @@ class _Vocoder(ABC):
            The valid length of each sample in the batch. Shape: `(batch, )`.
     Returns:
-        Tensor and optional Tensor:
+        (Tensor, Optional[Tensor]):
         Tensor:
             The generated waveform. Shape: `(batch, max length)`
         Tensor or None:
...
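These abstract interfaces are what the TTS pipeline bundles expose; a hedged end-to-end sketch, assuming one of the Tacotron2 bundles (e.g. TACOTRON2_WAVERNN_CHAR_LJSPEECH) is available in this release:

>>> import torchaudio
>>> bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
>>> processor = bundle.get_text_processor()  # a _TextProcessor
>>> tacotron2 = bundle.get_tacotron2()
>>> vocoder = bundle.get_vocoder()           # a _Vocoder
>>> tokens, token_lengths = processor("Hello world!")
>>> specgram, spec_lengths, _ = tacotron2.infer(tokens, token_lengths)
>>> waveforms, wave_lengths = vocoder(specgram, spec_lengths)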
@@ -134,7 +134,7 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
         unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
     Returns:
-        Tuple of strings:
+        Tuple[str]:
         For models fine-tuned on ASR, returns the tuple of strings representing
         the output class labels.
...
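Usage of `get_labels` is straightforward; for instance (bundle choice illustrative):

>>> import torchaudio
>>> bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
>>> labels = bundle.get_labels()  # Tuple[str] of output class labels, usable for CTC decoding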
@@ -276,20 +276,20 @@ class _EmformerAttention(torch.nn.Module):
        M: number of memory elements.
    Args:
-       utterance (torch.Tensor): utterance frames, with shape (T, B, D).
+       utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
-       right_context (torch.Tensor): right context frames, with shape (R, B, D).
+       right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-       summary (torch.Tensor): summary elements, with shape (S, B, D).
+       summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
-       mems (torch.Tensor): memory elements, with shape (M, B, D).
+       mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
        attention_mask (torch.Tensor): attention mask for underlying attention module.
    Returns:
-       torch.Tensor and torch.Tensor:
+       (Tensor, Tensor):
-           torch.Tensor
+           Tensor
-               output frames corresponding to utterance and right_context, with shape (T + R, B, D).
+               output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
-           torch.Tensor
+           Tensor
-               updated memory elements, with shape (M, B, D).
+               updated memory elements, with shape `(M, B, D)`.
    """
    output, output_mems, _, _ = self._forward_impl(
        utterance, lengths, right_context, summary, mems, attention_mask
@@ -317,24 +317,24 @@ class _EmformerAttention(torch.nn.Module):
        M: number of memory elements.
    Args:
-       utterance (torch.Tensor): utterance frames, with shape (T, B, D).
+       utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
-       right_context (torch.Tensor): right context frames, with shape (R, B, D).
+       right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-       summary (torch.Tensor): summary elements, with shape (S, B, D).
+       summary (torch.Tensor): summary elements, with shape `(S, B, D)`.
-       mems (torch.Tensor): memory elements, with shape (M, B, D).
+       mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
        left_context_key (torch.Tensor): left context attention key computed from preceding invocation.
        left_context_val (torch.Tensor): left context attention value computed from preceding invocation.
    Returns:
-       torch.Tensor, torch.Tensor, torch.Tensor, and torch.Tensor:
+       (Tensor, Tensor, Tensor, and Tensor):
-           torch.Tensor
+           Tensor
-               output frames corresponding to utterance and right_context, with shape (T + R, B, D).
+               output frames corresponding to utterance and right_context, with shape `(T + R, B, D)`.
-           torch.Tensor
+           Tensor
-               updated memory elements, with shape (M, B, D).
+               updated memory elements, with shape `(M, B, D)`.
-           torch.Tensor
+           Tensor
               attention key computed for left context and utterance.
-           torch.Tensor
+           Tensor
               attention value computed for left context and utterance.
    """
    query_dim = right_context.size(0) + utterance.size(0) + summary.size(0)
@@ -575,21 +575,21 @@ class _EmformerLayer(torch.nn.Module):
        M: number of memory elements.
    Args:
-       utterance (torch.Tensor): utterance frames, with shape (T, B, D).
+       utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
-       right_context (torch.Tensor): right context frames, with shape (R, B, D).
+       right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
-       mems (torch.Tensor): memory elements, with shape (M, B, D).
+       mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
        attention_mask (torch.Tensor): attention mask for underlying attention module.
    Returns:
-       torch.Tensor, torch.Tensor, and torch.Tensor:
+       (Tensor, Tensor, Tensor):
-           torch.Tensor
+           Tensor
-               encoded utterance frames, with shape (T, B, D).
+               encoded utterance frames, with shape `(T, B, D)`.
-           torch.Tensor
+           Tensor
-               updated right context frames, with shape (R, B, D).
+               updated right context frames, with shape `(R, B, D)`.
-           torch.Tensor
+           Tensor
-               updated memory elements, with shape (M, B, D).
+               updated memory elements, with shape `(M, B, D)`.
    """
    (
        layer_norm_utterance,
@@ -625,25 +625,25 @@ class _EmformerLayer(torch.nn.Module):
        M: number of memory elements.
    Args:
-       utterance (torch.Tensor): utterance frames, with shape (T, B, D).
+       utterance (torch.Tensor): utterance frames, with shape `(T, B, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``utterance``.
-       right_context (torch.Tensor): right context frames, with shape (R, B, D).
+       right_context (torch.Tensor): right context frames, with shape `(R, B, D)`.
        state (List[torch.Tensor] or None): list of tensors representing layer internal state
            generated in preceding invocation of ``infer``.
-       mems (torch.Tensor): memory elements, with shape (M, B, D).
+       mems (torch.Tensor): memory elements, with shape `(M, B, D)`.
    Returns:
-       torch.Tensor, torch.Tensor, List[torch.Tensor], and torch.Tensor:
+       (Tensor, Tensor, List[torch.Tensor], Tensor):
-           torch.Tensor
+           Tensor
-               encoded utterance frames, with shape (T, B, D).
+               encoded utterance frames, with shape `(T, B, D)`.
-           torch.Tensor
+           Tensor
-               updated right context frames, with shape (R, B, D).
+               updated right context frames, with shape `(R, B, D)`.
-           List[torch.Tensor]
+           List[Tensor]
               list of tensors representing layer internal state
               generated in current invocation of ``infer``.
-           torch.Tensor
+           Tensor
-               updated memory elements, with shape (M, B, D).
+               updated memory elements, with shape `(M, B, D)`.
    """
    (
        layer_norm_utterance,
@@ -851,16 +851,16 @@ class Emformer(torch.nn.Module):
    Args:
        input (torch.Tensor): utterance frames right-padded with right context frames, with
-           shape (B, T, D).
+           shape `(B, T, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``input``.
    Returns:
-       torch.Tensor and torch.Tensor:
+       (Tensor, Tensor):
-           torch.Tensor
+           Tensor
-               output frames, with shape (B, T - ``right_context_length``, D).
+               output frames, with shape `(B, T - ``right_context_length``, D)`.
-           torch.Tensor
+           Tensor
-               output lengths, with shape (B,) and i-th element representing
+               output lengths, with shape `(B,)` and i-th element representing
               number of valid frames for i-th batch element in output frames.
    """
    input = input.permute(1, 0, 2)
@@ -894,20 +894,20 @@ class Emformer(torch.nn.Module):
    Args:
        input (torch.Tensor): utterance frames right-padded with right context frames, with
-           shape (B, T, D).
+           shape `(B, T, D)`.
-       lengths (torch.Tensor): with shape (B,) and i-th element representing
+       lengths (torch.Tensor): with shape `(B,)` and i-th element representing
            number of valid frames for i-th batch element in ``input``.
        states (List[List[torch.Tensor]] or None, optional): list of lists of tensors
            representing Emformer internal state generated in preceding invocation of ``infer``. (Default: ``None``)
    Returns:
-       torch.Tensor, torch.Tensor, and List[List[torch.Tensor]]:
+       (Tensor, Tensor, List[List[Tensor]]):
-           torch.Tensor
+           Tensor
-               output frames, with shape (B, T - ``right_context_length``, D).
+               output frames, with shape `(B, T - ``right_context_length``, D)`.
-           torch.Tensor
+           Tensor
-               output lengths, with shape (B,) and i-th element representing
+               output lengths, with shape `(B,)` and i-th element representing
               number of valid frames for i-th batch element in output frames.
-           List[List[torch.Tensor]]
+           List[List[Tensor]]
               output states; list of lists of tensors representing Emformer internal state
               generated in current invocation of ``infer``.
    """
...
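A minimal sketch of the batch (`forward`) and streaming (`infer`) entry points documented above, assuming the constructor signature `Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, ...)`; note the class has lived under `torchaudio.prototype` in some releases before moving to `torchaudio.models`:

>>> import torch
>>> from torchaudio.models import Emformer
>>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
>>> input = torch.rand(128, 400, 512)        # (B, T, D)
>>> lengths = torch.randint(1, 200, (128,))  # (B,)
>>> output, out_lengths = emformer(input, lengths)
>>> # streaming: feed segment_length + right_context_length frames per call
>>> chunk = torch.rand(128, 5, 512)
>>> chunk_lengths = torch.full((128,), 5)
>>> output, out_lengths, states = emformer.infer(chunk, chunk_lengths, None)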
@@ -73,10 +73,10 @@ def apply_effects_tensor(
        sample_rate (int): Sample rate
        effects (List[List[str]]): List of effects.
        channels_first (bool, optional): Indicates if the input Tensor's dimension is
-           ``[channels, time]`` or ``[time, channels]``
+           `[channels, time]` or `[time, channels]`
    Returns:
-       Tuple[torch.Tensor, int]: Resulting Tensor and sample rate.
+       (Tensor, int): Resulting Tensor and sample rate.
        The resulting Tensor has the same ``dtype`` as the input Tensor, and
        the same channels order. The shape of the Tensor can be different based on the
        effects applied. Sample rate can also be different based on the effects applied.
@@ -191,20 +191,20 @@ def apply_effects_file(
            If input file is integer WAV, giving ``False`` will change the resulting Tensor type to
            integer type. This argument has no effect for formats other
            than integer WAV type.
-       channels_first (bool, optional): When True, the returned Tensor has dimension ``[channel, time]``.
+       channels_first (bool, optional): When True, the returned Tensor has dimension `[channel, time]`.
-           Otherwise, the returned Tensor's dimension is ``[time, channel]``.
+           Otherwise, the returned Tensor's dimension is `[time, channel]`.
        format (str or None, optional):
            Override the format detection with the given format.
            Providing the argument might help when libsox can not infer the format
            from header or extension,
    Returns:
-       Tuple[torch.Tensor, int]: Resulting Tensor and sample rate.
+       (Tensor, int): Resulting Tensor and sample rate.
        If ``normalize=True``, the resulting Tensor is always ``float32`` type.
        If ``normalize=False`` and the input audio file is of integer WAV file, then the
        resulting Tensor has corresponding integer type. (Note 24 bit integer type is not supported)
-       If ``channels_first=True``, the resulting Tensor has dimension ``[channel, time]``,
+       If ``channels_first=True``, the resulting Tensor has dimension `[channel, time]`,
-           otherwise ``[time, channel]``.
+           otherwise `[time, channel]`.
    Example - Basic usage
    >>>
...
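A minimal sketch of `apply_effects_tensor` with the `(Tensor, int)` return convention adopted above; the effect chain and shapes are illustrative:

>>> import torch
>>> import torchaudio
>>> waveform = torch.rand(2, 16000)               # `[channels, time]`, dummy signal
>>> effects = [["gain", "-n"], ["rate", "8000"]]  # normalize gain, then resample
>>> out, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
...     waveform, 16000, effects, channels_first=True)
>>> out.shape, sample_rate                        # roughly (2, 8000), 8000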
@@ -787,7 +787,7 @@ class MuLawEncoding(torch.nn.Module):
        x (Tensor): A signal to be encoded.
    Returns:
-       x_mu (Tensor): An encoded signal.
+       Tensor: An encoded signal.
    """
    return F.mu_law_encoding(x, self.quantization_channels)
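For reference, a round-trip sketch of the mu-law transforms; shapes are illustrative, and the input is expected in [-1, 1]:

>>> import torch
>>> import torchaudio
>>> waveform = torch.rand(2, 16000) * 2 - 1  # dummy signal in [-1, 1]
>>> encode = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
>>> x_mu = encode(waveform)                  # integer classes in [0, 255]
>>> decode = torchaudio.transforms.MuLawDecoding(quantization_channels=256)
>>> reconstructed = decode(x_mu)             # approximate inverse of the encoding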
@@ -1629,7 +1629,7 @@ class PSD(torch.nn.Module):
            of dimension `(..., channel, freq, time)` if multi_mask is ``True``
    Returns:
-       torch.Tensor: PSD matrix of the input STFT matrix.
+       Tensor: PSD matrix of the input STFT matrix.
        Tensor of dimension `(..., freq, channel, channel)`
    """
    # outer product:
@@ -1773,7 +1773,7 @@ class MVDR(torch.nn.Module):
        eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8)
    Returns:
-       torch.Tensor: the mvdr beamforming weight matrix
+       Tensor: the mvdr beamforming weight matrix
    """
    if self.multi_mask:
        # Averaging mask along channel dimension
...
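PSD and MVDR are typically used together for mask-based beamforming; a hedged sketch, assuming both transforms are exposed under `torchaudio.transforms` in this release (shapes and mask construction are illustrative):

>>> import torch
>>> import torchaudio
>>> specgram = torch.randn(1, 6, 201, 100, dtype=torch.cfloat)  # (batch, channel, freq, time) STFT
>>> mask = torch.rand(1, 201, 100)             # time-frequency mask for the target speech
>>> psd = torchaudio.transforms.PSD()
>>> psd_speech = psd(specgram, mask)           # (batch, freq, channel, channel)
>>> mvdr = torchaudio.transforms.MVDR(ref_channel=0)
>>> enhanced = mvdr(specgram, mask, 1 - mask)  # (batch, freq, time), single-channel output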