Unverified commit b972125c authored by Patrick von Platen, committed by GitHub

Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089)

* add Wav2Vec2ForCTC and deprecate Wav2Vec2ForMaskedLM

* remove Wav2Vec2ForMaskedLM from the docs
parent ba542ffb
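
For readers tracking the rename, here is a minimal migration sketch (assuming a PyTorch environment and the facebook/wav2vec2-base-960h checkpoint used throughout this diff); the old class still works but now emits a FutureWarning:

    # Before this commit (now deprecated):
    from transformers import Wav2Vec2ForMaskedLM
    model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")

    # After this commit (preferred):
    from transformers import Wav2Vec2ForCTC
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")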
docs/source/model_doc/wav2vec2.rst
@@ -58,8 +58,8 @@ Wav2Vec2Model
     :members: forward


-Wav2Vec2ForMaskedLM
+Wav2Vec2ForCTC
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-.. autoclass:: transformers.Wav2Vec2ForMaskedLM
+.. autoclass:: transformers.Wav2Vec2ForCTC
     :members: forward
src/transformers/__init__.py
@@ -367,6 +367,7 @@ if is_torch_available():
     _import_structure["models.wav2vec2"].extend(
         [
             "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForCTC",
             "Wav2Vec2ForMaskedLM",
             "Wav2Vec2Model",
             "Wav2Vec2PreTrainedModel",
@@ -1813,6 +1814,7 @@ if TYPE_CHECKING:
     )
     from .models.wav2vec2 import (
         WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+        Wav2Vec2ForCTC,
         Wav2Vec2ForMaskedLM,
         Wav2Vec2Model,
         Wav2Vec2PreTrainedModel,
...
src/transformers/models/wav2vec2/__init__.py
@@ -29,6 +29,7 @@ if is_torch_available():
     _import_structure["modeling_wav2vec2"] = [
         "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
         "Wav2Vec2ForMaskedLM",
+        "Wav2Vec2ForCTC",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_wav2vec2 import (
             WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Wav2Vec2ForCTC,
             Wav2Vec2ForMaskedLM,
             Wav2Vec2Model,
             Wav2Vec2PreTrainedModel,
...
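
Both `__init__.py` hunks follow the library's lazy-import pattern: class names are registered as strings in `_import_structure` so nothing heavy is imported at module load, while the real imports run only under `TYPE_CHECKING` for static analyzers. A minimal, generic sketch of the idea (not the library's exact `_LazyModule` machinery):

    import importlib
    from typing import TYPE_CHECKING

    # Map submodule -> public names it provides; nothing is imported yet.
    _import_structure = {"modeling_wav2vec2": ["Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM"]}

    if TYPE_CHECKING:
        # Type checkers see the real symbols...
        from .modeling_wav2vec2 import Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM
    else:
        # ...while at runtime, attribute access triggers the import (PEP 562).
        def __getattr__(name):
            for module, names in _import_structure.items():
                if name in names:
                    return getattr(importlib.import_module("." + module, __name__), name)
            raise AttributeError(f"module {__name__!r} has no attribute {name!r}")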
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
@@ -20,7 +20,7 @@ import argparse
 import fairseq
 import torch

-from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
+from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging


 logging.set_verbosity_info()
@@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_
     """
     Copy/paste/tweak model's weights to transformers design.
     """
-    hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
+    hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())

     model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
         [checkpoint_path], arg_overrides={"data": dict_path}
...
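
With the converter now building `Wav2Vec2ForCTC`, converted fairseq checkpoints load directly into the CTC model. A hedged sketch of a call, using the parameter names visible in the hunk above (the file paths are hypothetical placeholders, and any parameters truncated out of the hunk header are left untouched):

    # Sketch only: the paths below are placeholders, not files shipped with the repo.
    convert_wav2vec2_checkpoint(
        checkpoint_path="/path/to/fairseq/wav2vec_small_960h.pt",
        pytorch_dump_folder_path="/path/to/output/wav2vec2-base-960h",
        dict_path="/path/to/fairseq/dict.ltr.txt",
    )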
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2 model. """

+import warnings
 from typing import Optional, Tuple

 import torch
@@ -24,7 +25,7 @@ from torch import nn
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

+        warnings.warn(
+            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
+        )
+
         self.wav2vec2 = Wav2Vec2Model(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
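
The deprecation is warning-only, so existing code keeps running unchanged. A small sketch of what callers now observe (standard library only, assuming warnings are not globally filtered):

    import warnings

    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        model = Wav2Vec2ForMaskedLM(Wav2Vec2Config())  # still builds a working model

    assert any(issubclass(w.category, FutureWarning) for w in caught)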
@@ -729,3 +734,77 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
             return output

         return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
+    WAV_2_VEC_2_START_DOCSTRING,
+)
+class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            TODO(PVP): Fill out when adding training
+
+        Returns:
+
+        Example::
+
+            >>> import torch
+            >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC
+            >>> from datasets import load_dataset
+            >>> import soundfile as sf
+
+            >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+            >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+            >>> def map_to_array(batch):
+            ...     speech, _ = sf.read(batch["file"])
+            ...     batch["speech"] = speech
+            ...     return batch
+
+            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+            >>> ds = ds.map(map_to_array)
+
+            >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+            >>> predicted_ids = torch.argmax(logits, dim=-1)
+
+            >>> transcription = tokenizer.decode(predicted_ids[0])
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
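
The docstring example decodes a single utterance; the integration tests further down in this diff run the same greedy CTC decoding on batches. A condensed sketch of the batched flow (assuming the tokenizer's `padding=True` and `batch_decode`, which the tests in this commit rely on; `speech_batch` is a hypothetical list of raw 16 kHz waveforms):

    import torch
    from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

    tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
    model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    # speech_batch: list of 1-D float arrays (placeholder variable).
    inputs = tokenizer(speech_batch, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_values).logits  # (batch, time, vocab)

    # Greedy CTC: argmax per frame; decoding collapses repeats and blanks.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = tokenizer.batch_decode(predicted_ids)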
src/transformers/utils/dummy_pt_objects.py
@@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class Wav2Vec2ForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class Wav2Vec2ForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
...
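
These dummy classes keep `from transformers import Wav2Vec2ForCTC` importable in environments without PyTorch and fail only on use. A behavior sketch (the exact error type and message come from the `requires_pytorch` helper, so both are assumptions here):

    # In an environment without PyTorch:
    from transformers import Wav2Vec2ForCTC  # resolves to the dummy class above

    try:
        Wav2Vec2ForCTC()
    except ImportError as err:  # assumed error type raised by requires_pytorch
        print(err)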
tests/test_modeling_wav2vec2.py
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
 if is_torch_available():
     import torch

-    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
+    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer


 class Wav2Vec2ModelTester:
@@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
+    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         return ds["speech"][:num_samples]

     def test_inference_masked_lm_normal(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_normal_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_robust_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

         input_speech = self._load_datasamples(4)
...
utils/check_repo.py
@@ -118,6 +118,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "TFMT5EncoderModel",
     "TFOpenAIGPTDoubleHeadsModel",
     "TFT5EncoderModel",
+    "Wav2Vec2ForCTC",
     "XLMForQuestionAnswering",
     "XLMProphetNetDecoder",
     "XLMProphetNetEncoder",
@@ -370,6 +371,7 @@ DEPRECATED_OBJECTS = [
     "TFBartPretrainedModel",
     "TextDataset",
     "TextDatasetForNextSentencePrediction",
+    "Wav2Vec2ForMaskedLM",
     "glue_compute_metrics",
     "glue_convert_examples_to_features",
     "glue_output_modes",
...