Unverified commit 7fc1cb15, authored by Sylvain Gugger and committed by GitHub

Remove all hf-internal-testing checkpoints that can be removed (#21199)

* Remove all hf-internal-testing checkpoints that can be removed

* Fix copies

* Put back processor_class in TF example

* Address review comment
parent 142ad1a1
src/transformers/models/hubert/modeling_hubert.py
@@ -41,14 +41,10 @@ from .configuration_hubert import HubertConfig
 logger = logging.get_logger(__name__)

-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 _HIDDEN_STATES_START_POSITION = 1

 # General docstring
 _CONFIG_FOR_DOC = "HubertConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
@@ -59,7 +55,6 @@ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND
 _CTC_EXPECTED_LOSS = 22.68

 # Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
 _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
 _SEQ_CLASS_EXPECTED_LOSS = 8.53
@@ -1145,7 +1140,6 @@ class HubertForCTC(HubertPreTrainedModel):
     @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1280,7 +1274,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
     @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
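With `processor_class` gone from the decorator, the generated usage sample loads its preprocessor through the Auto API instead of a hard-coded class. For reference, here is roughly what the regenerated CTC sample for the Hubert checkpoint above boils down to, a sketch assuming the standard speech-recognition doc template and its usual LibriSpeech demo clip (the `_CTC_*` constants kept in the file pin the expected transcription and loss):

```python
import torch
from datasets import load_dataset
from transformers import AutoProcessor, HubertForCTC

# the doc samples' demo data; only model checkpoints were migrated off hf-internal-testing
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")

processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

inputs = processor(dataset[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# greedy CTC decoding: pick the most likely token per frame, then collapse
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)[0]
# _CTC_EXPECTED_OUTPUT pins this transcription; _CTC_EXPECTED_LOSS (22.68) pins the
# loss obtained when the reference transcription is passed back as labels
```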
src/transformers/models/longformer/modeling_longformer.py
@@ -41,7 +41,6 @@ logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "allenai/longformer-base-4096"
 _CONFIG_FOR_DOC = "LongformerConfig"
-_TOKENIZER_FOR_DOC = "LongformerTokenizer"

 LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "allenai/longformer-base-4096",
@@ -1903,7 +1902,6 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint="jpwahle/longformer-base-plagiarism-detection",
         output_type=LongformerSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -2172,7 +2170,6 @@ class LongformerForTokenClassification(LongformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint="brad1141/Longformer-finetuned-norm",
         output_type=LongformerTokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -2260,7 +2257,6 @@ class LongformerForMultipleChoice(LongformerPreTrainedModel):
         LONGFORMER_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
     )
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=LongformerMultipleChoiceModelOutput,
         config_class=_CONFIG_FOR_DOC,
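On the PyTorch side the classification heads now point at genuinely fine-tuned community checkpoints rather than tiny-random ones, so the generated samples produce real predictions. A sketch of what the regenerated `LongformerForSequenceClassification` sample exercises (the printed label comes from that checkpoint's `id2label` map and is not pinned here):

```python
import torch
from transformers import AutoTokenizer, LongformerForSequenceClassification

checkpoint = "jpwahle/longformer-base-plagiarism-detection"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = LongformerForSequenceClassification.from_pretrained(checkpoint)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_class_id = int(logits.argmax(dim=-1))
print(model.config.id2label[predicted_class_id])  # label names are defined by the checkpoint's config
```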
src/transformers/models/longformer/modeling_tf_longformer.py
@@ -2381,11 +2381,9 @@ class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSeque
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint="hf-internal-testing/tiny-random-longformer",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TFLongformerSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output="'LABEL_1'",
-        expected_loss=0.69,
     )
     def call(
         self,
@@ -2636,15 +2634,9 @@ class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenCla
     @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint="hf-internal-testing/tiny-random-longformer",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TFLongformerTokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=(
-            "['LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
-            " 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1', 'LABEL_1',"
-            " 'LABEL_1', 'LABEL_1']"
-        ),
-        expected_loss=0.59,
     )
     def call(
         self,
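The TF variants keep `processor_class` (the "Put back processor_class in TF example" item in the commit message) but swap the tiny-random checkpoints for `_CHECKPOINT_FOR_DOC`. Since `allenai/longformer-base-4096` ships no fine-tuned classification head, `from_pretrained` initializes one randomly, which is why the pinned `expected_output`/`expected_loss` values had to go. A minimal sketch of the resulting sample:

```python
from transformers import AutoTokenizer, TFLongformerForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
# the classification head is newly initialized on top of the base weights,
# so logits vary from run to run and no output or loss can be doctest-pinned
model = TFLongformerForSequenceClassification.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
logits = model(**inputs).logits
```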
src/transformers/models/mbart/modeling_mbart.py
@@ -49,22 +49,10 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
 _CONFIG_FOR_DOC = "MBartConfig"
-_TOKENIZER_FOR_DOC = "MBartTokenizer"

 # Base model docstring
 _EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

-# SequenceClassification docstring
-_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-random-mbart"
-_SEQ_CLASS_EXPECTED_LOSS = 0.69
-_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
-
-# QuestionAsnwering docstring
-_CHECKPOINT_FOR_QA = "hf-internal-testing/tiny-random-mbart"
-_QA_EXPECTED_LOSS = 3.55
-_QA_EXPECTED_OUTPUT = "'? Jim Henson was a'"
-
 MBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/mbart-large-cc25",
     # See all MBART models at https://huggingface.co/models?filter=mbart
@@ -1187,7 +1175,6 @@ class MBartModel(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1467,12 +1454,9 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
@@ -1596,12 +1580,9 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_QA,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqQuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_loss=_QA_EXPECTED_LOSS,
-        expected_output=_QA_EXPECTED_OUTPUT,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
     def forward(
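Same trade-off for MBart: the sequence-classification and QA samples now run on `facebook/mbart-large-cc25`, whose task heads are untrained, so the pinned `'? Jim Henson was a'` answer and the loss values disappear along with the tiny checkpoint. A sketch of the extractive-QA pattern the generated sample follows (without fine-tuning the decoded span will not be meaningful, which is exactly why nothing is pinned):

```python
import torch
from transformers import AutoTokenizer, MBartForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
model = MBartForQuestionAnswering.from_pretrained("facebook/mbart-large-cc25")  # QA head freshly initialized

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
inputs = tokenizer(question, text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# pick the most likely start/end positions and decode the span between them
answer_start = int(outputs.start_logits.argmax())
answer_end = int(outputs.end_logits.argmax())
answer = tokenizer.decode(inputs.input_ids[0, answer_start : answer_end + 1], skip_special_tokens=True)
```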
src/transformers/models/plbart/modeling_plbart.py
@@ -48,16 +48,6 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "uclanlp/plbart-base"
 _CONFIG_FOR_DOC = "PLBartConfig"
-_TOKENIZER_FOR_DOC = "PLBartTokenizer"
-
-# Base model docstring
-_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
-
-# SequenceClassification docstring
-_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "hf-internal-testing/tiny-plbart"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
-_SEQ_CLASS_EXPECTED_LOSS = 0.69

 PLBART_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "uclanlp/plbart-base",
@@ -1161,7 +1151,6 @@ class PLBartModel(PLBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1438,12 +1427,9 @@ class PLBartForSequenceClassification(PLBartPreTrainedModel):
     @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqSequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
     def forward(
src/transformers/models/reformer/modeling_reformer.py
@@ -49,7 +49,6 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "google/reformer-crime-and-punishment"
 _CONFIG_FOR_DOC = "ReformerConfig"
-_TOKENIZER_FOR_DOC = "ReformerTokenizer"

 REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "google/reformer-crime-and-punishment",
@@ -2009,7 +2008,6 @@ class ReformerModel(ReformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=ReformerModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -2220,7 +2218,6 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -2360,13 +2357,20 @@ class ReformerForMaskedLM(ReformerPreTrainedModel):
         Returns:

+        <Tip warning={true}>
+
+        This example uses a false checkpoint since we don't have any available pretrained model for the masked language
+        modeling task with the Reformer architecture.
+
+        </Tip>
+
         Example:

         ```python
         >>> import torch
-        >>> from transformers import ReformerTokenizer, ReformerForMaskedLM
+        >>> from transformers import AutoTokenizer, ReformerForMaskedLM

-        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
         >>> model = ReformerForMaskedLM.from_pretrained("hf-internal-testing/tiny-random-reformer")

         >>> # add mask_token
@@ -2479,10 +2483,10 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         ```python
         >>> import torch
-        >>> from transformers import ReformerTokenizer, ReformerForSequenceClassification
+        >>> from transformers import AutoTokenizer, ReformerForSequenceClassification

-        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
-        >>> model = ReformerForSequenceClassification.from_pretrained("hf-internal-testing/tiny-random-reformer")
+        >>> tokenizer = AutoTokenizer.from_pretrained("google/reformer-crime-and-punishment")
+        >>> model = ReformerForSequenceClassification.from_pretrained("google/reformer-crime-and-punishment")

         >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
@@ -2491,59 +2495,20 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         >>> predicted_class_id = logits.argmax().item()
         >>> model.config.id2label[predicted_class_id]
-        'LABEL_1'
+        'LABEL_0'
         ```

         ```python
         >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
         >>> num_labels = len(model.config.id2label)
         >>> model = ReformerForSequenceClassification.from_pretrained(
-        ...     "hf-internal-testing/tiny-random-reformer", num_labels=num_labels
+        ...     "google/reformer-crime-and-punishment", num_labels=num_labels
         ... )

         >>> labels = torch.tensor(1)
         >>> loss = model(**inputs, labels=labels).loss
         >>> round(loss.item(), 2)
-        0.69
-        ```
-
-        Example of multi-label classification:
-
-        ```python
-        >>> import torch
-        >>> from transformers import ReformerTokenizer, ReformerForSequenceClassification
-
-        >>> tokenizer = ReformerTokenizer.from_pretrained("hf-internal-testing/tiny-random-reformer")
-        >>> model = ReformerForSequenceClassification.from_pretrained(
-        ...     "hf-internal-testing/tiny-random-reformer", problem_type="multi_label_classification"
-        ... )
-
-        >>> # add pad_token
-        >>> tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # doctest: +IGNORE_RESULT
-        >>> inputs = tokenizer("Hello, my dog is cute", max_length=100, padding="max_length", return_tensors="pt")
-
-        >>> with torch.no_grad():
-        ...     logits = model(**inputs).logits
-
-        >>> predicted_class_id = logits.argmax().item()
-        >>> model.config.id2label[predicted_class_id]
-        'LABEL_1'
-        ```
-
-        ```python
-        >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
-        >>> num_labels = len(model.config.id2label)
-        >>> model = ReformerForSequenceClassification.from_pretrained(
-        ...     "hf-internal-testing/tiny-random-reformer", num_labels=num_labels
-        ... )
-
-        >>> model.train()  # doctest: +IGNORE_RESULT
-        >>> num_labels = len(model.config.id2label)
-        >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
-        ...     torch.float
-        ... )
-        >>> loss = model(**inputs, labels=labels).loss
-        >>> loss.backward()  # doctest: +IGNORE_RESULT
+        0.68
         ```
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -2641,12 +2606,9 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint="hf-internal-testing/tiny-random-reformer",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
-        expected_output="''",
-        expected_loss=3.28,
    )
     def forward(
         self,
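The Reformer hunks also drop the whole multi-label example rather than migrate it, since no public multi-label Reformer fine-tune exists. The pattern it demonstrated is generic, though: with `problem_type="multi_label_classification"` the head trains against one-hot float labels with a BCE-with-logits loss. A condensed sketch of the deleted example, kept on the tiny testing checkpoint precisely because that is the only one available:

```python
import torch
from transformers import AutoTokenizer, ReformerForSequenceClassification

checkpoint = "hf-internal-testing/tiny-random-reformer"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = ReformerForSequenceClassification.from_pretrained(
    checkpoint, problem_type="multi_label_classification"
)

# the tiny tokenizer ships without a pad token, so register one before padding
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
inputs = tokenizer("Hello, my dog is cute", max_length=100, padding="max_length", return_tensors="pt")

num_labels = len(model.config.id2label)
labels = torch.nn.functional.one_hot(torch.tensor([0]), num_classes=num_labels).to(torch.float)
loss = model(**inputs, labels=labels).loss  # BCEWithLogitsLoss under the problem_type above
loss.backward()
```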
src/transformers/models/sew/modeling_sew.py
@@ -36,16 +36,11 @@ from .configuration_sew import SEWConfig
 logger = logging.get_logger(__name__)

-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-
 _HIDDEN_STATES_START_POSITION = 1

 # General docstring
 _CONFIG_FOR_DOC = "SEWConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
@@ -58,7 +53,6 @@ _CTC_EXPECTED_OUTPUT = (
 _CTC_EXPECTED_LOSS = 0.42

 # Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 _SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
 _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
 _SEQ_CLASS_EXPECTED_LOSS = 9.52
@@ -916,7 +910,6 @@ class SEWModel(SEWPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1020,7 +1013,6 @@ class SEWForCTC(SEWPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1155,7 +1147,6 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
src/transformers/models/sew_d/modeling_sew_d.py
@@ -42,7 +42,6 @@ _HIDDEN_STATES_START_POSITION = 1
 # General docstring
 _CONFIG_FOR_DOC = "SEWDConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
@@ -53,7 +52,6 @@ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND
 _CTC_EXPECTED_LOSS = 0.21

 # Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 _SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
 _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
 _SEQ_CLASS_EXPECTED_LOSS = 3.16
@@ -1453,7 +1451,6 @@ class SEWDModel(SEWDPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1557,7 +1554,6 @@ class SEWDForCTC(SEWDPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1692,7 +1688,6 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
     @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
src/transformers/models/unispeech/modeling_unispeech.py
@@ -48,7 +48,6 @@ _HIDDEN_STATES_START_POSITION = 2
 # General docstring
 _CONFIG_FOR_DOC = "UniSpeechConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "patrickvonplaten/unispeech-large-1500h-cv-timit"
@@ -58,12 +57,6 @@ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
 _CTC_EXPECTED_OUTPUT = "'mister quilter is the apposl of the midle classes and weare glad to welcom his gosepl'"
 _CTC_EXPECTED_LOSS = 17.17

-# Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-_SEQ_CLASS_EXPECTED_LOSS = 0.66  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-
 UNISPEECH_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/unispeech-large-1500h-cv",
     "microsoft/unispeech-large-multi-lingual-1500h-cv",
@@ -1143,7 +1136,6 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1286,12 +1278,9 @@ class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
         ```python
         >>> import torch
-        >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechForPreTraining
-        >>> from transformers.models.unispeech.modeling_unispeech import _compute_mask_indices
+        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

-        >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-        ...     "hf-internal-testing/tiny-random-unispeech-sat"
-        ... )
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
         >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
         >>> # TODO: Add full pretraining example
         ```"""
@@ -1395,7 +1384,6 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1482,7 +1470,6 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
     """,
     UNISPEECH_START_DOCSTRING,
 )
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
 class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1501,6 +1488,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
     def freeze_feature_extractor(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1513,6 +1501,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         )
         self.freeze_feature_encoder()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech
     def freeze_feature_encoder(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1520,6 +1509,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
         """
         self.unispeech.feature_extractor._freeze_parameters()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech
     def freeze_base_model(self):
         """
         Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1530,14 +1520,12 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeech, wav2vec2->unispeech
     def forward(
         self,
         input_values: Optional[torch.Tensor],
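The "Fix copies" item in the commit message shows up in these UniSpeech hunks: once a class's doc-sample decorator no longer matches Wav2Vec2's, the class-level `# Copied from` marker has to be split into method-level markers, which `make fix-copies` (backed by `utils/check_copies.py`) then keeps in sync with the Wav2Vec2 source, applying the name substitution given after `with`. The resulting layout in the patched `modeling_unispeech.py`, quoted from the diff above:

```python
class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
    ...

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.unispeech.feature_extractor._freeze_parameters()
```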
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -55,7 +55,6 @@ _HIDDEN_STATES_START_POSITION = 2
 # General docstring
 _CONFIG_FOR_DOC = "UniSpeechSatConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-100h-libri-ft"
@@ -65,12 +64,6 @@ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
 _CTC_EXPECTED_OUTPUT = "'MISTER QUILDER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
 _CTC_EXPECTED_LOSS = 39.88

-# Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-unispeech-sat"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-_SEQ_CLASS_EXPECTED_LOSS = 0.71  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-
 # Frame class docstring
 _FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
 _FRAME_EXPECTED_OUTPUT = [0, 0]
@@ -1158,7 +1151,6 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1299,10 +1291,10 @@ class UniSpeechSatForPreTraining(UniSpeechSatPreTrainedModel):
         ```python
         >>> import torch
-        >>> from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatForPreTraining
+        >>> from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining
         >>> from transformers.models.unispeech_sat.modeling_unispeech_sat import _compute_mask_indices

-        >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
         >>> model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")

         >>> # TODO: Add full pretraining example
         ```"""
@@ -1399,7 +1391,6 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1486,7 +1477,6 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
     """,
     UNISPEECH_SAT_START_DOCSTRING,
 )
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
 class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1505,6 +1495,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
     def freeze_feature_extractor(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1517,6 +1508,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         )
         self.freeze_feature_encoder()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->unispeech_sat
     def freeze_feature_encoder(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1524,6 +1516,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
         """
         self.unispeech_sat.feature_extractor._freeze_parameters()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->unispeech_sat
     def freeze_base_model(self):
         """
         Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1534,14 +1527,12 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat
     def forward(
         self,
         input_values: Optional[torch.Tensor],
@@ -1658,7 +1649,6 @@ class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_FRAME_CLASS_CHECKPOINT,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1841,7 +1831,6 @@ class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
     @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_XVECTOR_CHECKPOINT,
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
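Both pretraining examples now feed the real microsoft checkpoints through `AutoFeatureExtractor`. The docstring still ends in a TODO, but the step it stops short of is just the standard feature-extraction call; a sketch with a synthetic one-second clip standing in for real 16 kHz speech:

```python
import torch
from transformers import AutoFeatureExtractor, UniSpeechSatForPreTraining

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-sat-base")
model = UniSpeechSatForPreTraining.from_pretrained("microsoft/unispeech-sat-base")

raw_speech = torch.randn(16_000).numpy()  # stand-in for one second of 16 kHz audio
inputs = feature_extractor(raw_speech, sampling_rate=16_000, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # projected states; the full pretraining loop is still TODO upstream
```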
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -56,7 +56,6 @@ _HIDDEN_STATES_START_POSITION = 2
 # General docstring
 _CONFIG_FOR_DOC = "Wav2Vec2Config"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
@@ -67,7 +66,6 @@ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND
 _CTC_EXPECTED_LOSS = 53.48

 # Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 _SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
 _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
 _SEQ_CLASS_EXPECTED_LOSS = 6.54
@@ -1279,7 +1277,6 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1655,7 +1652,6 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1789,7 +1785,6 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_SEQ_CLASS_CHECKPOINT,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1911,7 +1906,6 @@ class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_FRAME_CLASS_CHECKPOINT,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -2091,7 +2085,6 @@ class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_XVECTOR_CHECKPOINT,
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
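Wav2Vec2 itself only loses the `processor_class`/`_FEAT_EXTRACTOR_FOR_DOC` indirection; its sample checkpoints were already real ones. For instance, the audio-classification sample behind `_SEQ_CLASS_CHECKPOINT` amounts to the following, a sketch assuming the standard audio-classification doc template (`'_unknown_'` is the pinned expected output above, since LibriSpeech speech is out of domain for a keyword-spotting head):

```python
import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
feature_extractor = AutoFeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks")

inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

predicted_label = model.config.id2label[int(logits.argmax(dim=-1))]
# matches _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" for this clip
```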
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -54,7 +54,6 @@ _HIDDEN_STATES_START_POSITION = 2
 # General docstring
 _CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft"
@@ -64,20 +63,6 @@ _EXPECTED_OUTPUT_SHAPE = [1, 292, 1024]
 _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
 _CTC_EXPECTED_LOSS = 64.21

-# Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-seq-class"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"
-_SEQ_CLASS_EXPECTED_LOSS = 0.68
-
-# Frame class docstring
-_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-frame-class"
-_FRAME_EXPECTED_OUTPUT = [1, 0]
-
-# Speaker Verification docstring
-_XVECTOR_CHECKPOINT = "hf-internal-testing/wav2vec2-conformer-xvector"
-_XVECTOR_EXPECTED_OUTPUT = 1.0
-
 WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/wav2vec2-conformer-rel-pos-large",
@@ -1324,7 +1309,6 @@ class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1643,7 +1627,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1769,13 +1752,10 @@ class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedMode
     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
     def forward(
@@ -1884,12 +1864,10 @@ class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedMo
     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_FRAME_CLASS_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_FRAME_EXPECTED_OUTPUT,
     )
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer
     def forward(
@@ -2058,12 +2036,10 @@ class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_XVECTOR_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_XVECTOR_EXPECTED_OUTPUT,
    )
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER
     def forward(
src/transformers/models/wavlm/modeling_wavlm.py
@@ -48,7 +48,6 @@ _HIDDEN_STATES_START_POSITION = 2
 # General docstring
 _CONFIG_FOR_DOC = "WavLMConfig"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"

 # Base docstring
 _CHECKPOINT_FOR_DOC = "patrickvonplaten/wavlm-libri-clean-100h-base-plus"
@@ -58,12 +57,6 @@ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
 _CTC_EXPECTED_OUTPUT = "'mister quilter is the aposle of the middle classes and we are glad to welcome his gospel'"
 _CTC_EXPECTED_LOSS = 12.51

-# Audio class docstring
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-wavlm"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'no'"  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-_SEQ_CLASS_EXPECTED_LOSS = 0.7  # TODO(anton) - could you quickly fine-tune a KS WavLM Model
-
 # Frame class docstring
 _FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd"
 _FRAME_EXPECTED_OUTPUT = [0, 0]
@@ -1212,7 +1205,6 @@ class WavLMModel(WavLMPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Wav2Vec2BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1320,7 +1312,6 @@ class WavLMForCTC(WavLMPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_PROCESSOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1407,7 +1398,6 @@ class WavLMForCTC(WavLMPreTrainedModel):
     """,
     WAVLM_START_DOCSTRING,
 )
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM
 class WavLMForSequenceClassification(WavLMPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
@@ -1426,6 +1416,7 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
     def freeze_feature_extractor(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
@@ -1438,6 +1429,7 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
         )
         self.freeze_feature_encoder()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
     def freeze_feature_encoder(self):
         """
         Calling this function will disable the gradient computation for the feature encoder so that its parameter will
@@ -1445,6 +1437,7 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
         """
         self.wavlm.feature_extractor._freeze_parameters()

+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
     def freeze_base_model(self):
         """
         Calling this function will disable the gradient computation for the base model so that its parameters will not
@@ -1455,14 +1448,12 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
-        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
-        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
-        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->WavLM, wav2vec2->wavlm
     def forward(
         self,
         input_values: Optional[torch.Tensor],
@@ -1578,7 +1569,6 @@ class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_FRAME_CLASS_CHECKPOINT,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
@@ -1761,7 +1751,6 @@ class WavLMForXVector(WavLMPreTrainedModel):
     @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
-        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_XVECTOR_CHECKPOINT,
         output_type=XVectorOutput,
         config_class=_CONFIG_FOR_DOC,
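WavLM's frame-classification (speaker diarization) sample keeps its real checkpoint, `microsoft/wavlm-base-plus-sd`, and `_FRAME_EXPECTED_OUTPUT = [0, 0]` pins its first two per-frame decisions. A sketch of the template behind that sample, assuming the usual LibriSpeech demo clip:

```python
import torch
from datasets import load_dataset
from transformers import AutoFeatureExtractor, WavLMForAudioFrameClassification

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd")
model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd")

inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16_000, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# one independent speaker-activity decision per frame and per speaker
probabilities = torch.sigmoid(logits[0])
labels = (probabilities > 0.5).long()
print(labels[0].tolist())  # the doc constant pins this to [0, 0]
```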