Unverified Commit 49d62b01 authored by rbsteinm's avatar rbsteinm Committed by GitHub
Browse files

[Wav2Vec2] Fix None loss in doc examples (#19218)

* pass sampled_negative_indices parameter to the model to avoid getting a None loss
* concerns doc examples for Wav2Vec2ForPreTraining and Wav2Vec2ConformerForPreTraining
parent 1a1893e5
......@@ -1421,7 +1421,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
......@@ -1432,9 +1432,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
>>> mask_time_indices = _compute_mask_indices(
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
... )
>>> sampled_negative_indices = _sample_negative_indices(
... features_shape=(batch_size, sequence_length),
... num_negatives=model.config.num_negatives,
... mask_time_indices=mask_time_indices,
... )
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sampled_negative_indices = torch.tensor(
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
... )
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
......@@ -1448,7 +1458,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> # for contrastive loss training model should be put into train mode
>>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
>>> loss = model(
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
... ).loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
......@@ -1469,7 +1469,10 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
... _compute_mask_indices,
... _sample_negative_indices,
... )
>>> from datasets import load_dataset
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
......@@ -1480,9 +1483,19 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> # compute masked indices
>>> batch_size, raw_sequence_length = input_values.shape
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
>>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
>>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
>>> mask_time_indices = _compute_mask_indices(
... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
... )
>>> sampled_negative_indices = _sample_negative_indices(
... features_shape=(batch_size, sequence_length),
... num_negatives=model.config.num_negatives,
... mask_time_indices=mask_time_indices,
... )
>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
>>> sampled_negative_indices = torch.tensor(
... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
... )
>>> with torch.no_grad():
... outputs = model(input_values, mask_time_indices=mask_time_indices)
......@@ -1496,7 +1509,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> # for contrastive loss training model should be put into train mode
>>> model = model.train()
>>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
>>> loss = model(
... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
... ).loss
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment