Unverified commit 49d62b01, authored by rbsteinm and committed by GitHub

[Wav2Vec2] Fix None loss in doc examples (#19218)

* Pass the sampled_negative_indices parameter to the model so the contrastive loss is no longer None
* Applies to the doc examples for Wav2Vec2ForPreTraining and Wav2Vec2ConformerForPreTraining (see the sketch below)
parent 1a1893e5
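For context, here is a minimal, self-contained sketch of the behavior this commit addresses, not part of the diff itself. It assumes the same facebook/wav2vec2-base checkpoint as the doc example; the one-second silent waveform is a stand-in for the LibriSpeech sample used in the real docstring, and the rest mirrors the updated example in the diff below.

```python
import numpy as np
import torch

from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
model = model.train()  # contrastive loss is only meaningful in train mode

# stand-in input: one second of silence instead of the LibriSpeech sample from the docstring
dummy_speech = np.zeros(16000, dtype=np.float32)
input_values = feature_extractor(dummy_speech, sampling_rate=16000, return_tensors="pt").input_values

# mask a fraction of the feature frames, exactly as in the updated doc example
batch_size, raw_sequence_length = input_values.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
)
# sample negative feature indices to contrast against the masked (positive) frames
sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)
mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    sampled_negative_indices, device=input_values.device, dtype=torch.long
)

# before this fix the doc example never passed negatives, so no contrastive loss was computed
loss_without_negatives = model(input_values, mask_time_indices=mask_time_indices).loss
print(loss_without_negatives)  # None

# with sampled_negative_indices supplied, the model returns an actual contrastive loss
loss = model(
    input_values,
    mask_time_indices=mask_time_indices,
    sampled_negative_indices=sampled_negative_indices,
).loss
print(loss)  # a scalar tensor with grad_fn
```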
@@ -1421,7 +1421,7 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
 ```python
 >>> import torch
 >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
->>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
+>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 >>> from datasets import load_dataset

 >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
@@ -1432,9 +1432,19 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
 >>> # compute masked indices
 >>> batch_size, raw_sequence_length = input_values.shape
->>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
->>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
->>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+>>> mask_time_indices = _compute_mask_indices(
+...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+... )
+>>> sampled_negative_indices = _sample_negative_indices(
+...     features_shape=(batch_size, sequence_length),
+...     num_negatives=model.config.num_negatives,
+...     mask_time_indices=mask_time_indices,
+... )
+>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+>>> sampled_negative_indices = torch.tensor(
+...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+... )

 >>> with torch.no_grad():
 ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
@@ -1448,7 +1458,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
 >>> # for contrastive loss training model should be put into train mode
 >>> model = model.train()
->>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+>>> loss = model(
+...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+... ).loss
 ```"""
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1469,7 +1469,10 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
 ```python
 >>> import torch
 >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
->>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
+>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
+...     _compute_mask_indices,
+...     _sample_negative_indices,
+... )
 >>> from datasets import load_dataset

 >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
@@ -1480,9 +1483,19 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
 >>> # compute masked indices
 >>> batch_size, raw_sequence_length = input_values.shape
->>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length)
->>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2)
->>> mask_time_indices = torch.tensor(mask_time_indices, device=input_values.device, dtype=torch.long)
+>>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+>>> mask_time_indices = _compute_mask_indices(
+...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+... )
+>>> sampled_negative_indices = _sample_negative_indices(
+...     features_shape=(batch_size, sequence_length),
+...     num_negatives=model.config.num_negatives,
+...     mask_time_indices=mask_time_indices,
+... )
+>>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+>>> sampled_negative_indices = torch.tensor(
+...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+... )

 >>> with torch.no_grad():
 ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
@@ -1496,7 +1509,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
 >>> # for contrastive loss training model should be put into train mode
 >>> model = model.train()
->>> loss = model(input_values, mask_time_indices=mask_time_indices).loss
+>>> loss = model(
+...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+... ).loss
 ```"""
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict