Unverified Commit 10704e12 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Test] Fix W2V-Conformer integration test (#17303)

* [Test] Fix W2V-Conformer integration test

* correct w2v2

* up
parent 28a08116
......@@ -1414,7 +1414,6 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
>>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
>>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
......
......@@ -1442,7 +1442,7 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
@add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2-base->wav2vec2-conformer-rel-pos-large,wav2vec2->wav2vec2_conformer
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large
def forward(
self,
input_values: Optional[torch.Tensor],
......@@ -1470,14 +1470,9 @@ class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel):
>>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining
>>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import _compute_mask_indices
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> feature_extractor = AutoFeatureExtractor.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained(
... "facebook/wav2vec2_conformer-conformer-rel-pos-large"
... )
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
......
......@@ -581,6 +581,10 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
module.weight_v.data.fill_(3)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.fill_(3)
if hasattr(module, "pos_bias_u") and module.pos_bias_u is not None:
module.pos_bias_u.data.fill_(3)
if hasattr(module, "pos_bias_v") and module.pos_bias_v is not None:
module.pos_bias_v.data.fill_(3)
if hasattr(module, "codevectors") and module.codevectors is not None:
module.codevectors.data.fill_(3)
if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment