Output type of :class:`~transformers.Wav2Vec2ForXVector`.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Classification hidden states before AMSoftmax.
embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
Utterance embeddings used for vector similarity-based retrieval.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
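The `embeddings` field above is what x-vector speaker verification consumes: two utterances are compared by the cosine similarity of their embedding vectors. A minimal, dependency-free sketch of that scoring step (the vectors and the 0.85 threshold here are illustrative placeholders, not values from the model):

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors (plain Python lists)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Stand-ins for two `embeddings` rows of shape (config.xvector_output_dim,)
emb_a = [1.0, 0.0, 1.0, 0.0]
emb_b = [1.0, 0.1, 0.9, 0.0]

score = cosine_similarity(emb_a, emb_b)
# The decision threshold is task- and dataset-dependent; 0.85 is arbitrary here.
same_speaker = score > 0.85
```

In practice the threshold is tuned on a held-out verification set (e.g. to a target equal-error rate) rather than fixed ahead of time.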
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: Tuple[int, int],
...
@@ -1447,3 +1483,285 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
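The `_compute_mask_indices` helper referenced above computes SpecAugment-style span masks over the time axis. A toy sketch of the idea, simplified relative to the real implementation (no attention-mask handling, no minimum-mask guarantees; `mask_prob` and `mask_length` mirror the upstream parameter names):

```python
import random


def compute_mask_indices_sketch(shape, mask_prob, mask_length):
    """Toy SpecAugment-style masking: choose random span starts and mask
    `mask_length` consecutive time steps from each start.

    Returns a (batch_size, sequence_length) boolean mask as nested lists.
    """
    batch_size, sequence_length = shape
    # Number of spans so that roughly mask_prob of the sequence is covered.
    num_spans = int(mask_prob * sequence_length / mask_length)
    mask = [[False] * sequence_length for _ in range(batch_size)]
    for b in range(batch_size):
        starts = random.sample(range(sequence_length - mask_length + 1), num_spans)
        for s in starts:
            for t in range(s, s + mask_length):
                mask[b][t] = True
    return mask
```

Because sampled spans may overlap, the fraction actually masked is at most `mask_prob`; the real helper additionally respects per-example lengths from the attention mask.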
@add_start_docstrings(
"""
WavLM Model with a frame classification head on top for tasks like Speaker Diarization.
""",
WAVLM_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM