"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9dd2c86033d1f883e80c8688dd370e5399c4a400"
Unverified Commit 8d518013 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Wav2Vec2 Conformer] Fix inference float16 (#25985)

* [Wav2Vec2 Conformer] Fix inference float16

* fix test

* fix test more

* clean pipe test
parent 6bc517cc
...@@ -406,13 +406,15 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): ...@@ -406,13 +406,15 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
return self.cached_rotary_positional_embedding return self.cached_rotary_positional_embedding
self.cached_sequence_length = sequence_length self.cached_sequence_length = sequence_length
# Embeddings are computed in the dtype of the inv_freq constant
time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) time_stamps = torch.arange(sequence_length).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
embeddings = torch.cat((freqs, freqs), dim=-1) embeddings = torch.cat((freqs, freqs), dim=-1)
cos_embeddings = embeddings.cos()[:, None, None, :] cos_embeddings = embeddings.cos()[:, None, None, :]
sin_embeddings = embeddings.sin()[:, None, None, :] sin_embeddings = embeddings.sin()[:, None, None, :]
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]) # Computed embeddings are cast to the dtype of the hidden state inputs
self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states)
return self.cached_rotary_positional_embedding return self.cached_rotary_positional_embedding
......
...@@ -13,15 +13,15 @@ ...@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2-Conformer model. """ """ Testing suite for the PyTorch Wav2Vec2-Conformer model. """
import math import math
import tempfile
import unittest import unittest
import numpy as np import numpy as np
from datasets import load_dataset from datasets import load_dataset
from transformers import Wav2Vec2ConformerConfig, is_torch_available from transformers import Wav2Vec2ConformerConfig, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_torch, slow, torch_device from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torch_gpu, slow, torch_device
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ( from ...test_modeling_common import (
...@@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester: ...@@ -215,6 +215,23 @@ class Wav2Vec2ConformerModelTester:
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size), (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
) )
def create_and_check_model_float16(self, config, input_values, attention_mask):
    """Run a forward pass with a float16 model and verify the output shape.

    The model is round-tripped through ``save_pretrained`` so that the
    half-precision weights come from ``from_pretrained(..., torch_dtype=torch.float16)``,
    matching how users load fp16 checkpoints.
    """
    model = Wav2Vec2ConformerModel(config=config)

    # Serialize and reload so the dtype conversion happens at load time.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        model = Wav2Vec2ConformerModel.from_pretrained(save_dir, torch_dtype=torch.float16)

    model.to(torch_device)
    model.eval()

    # Inputs must match the model dtype for the forward pass.
    with torch.no_grad():
        outputs = model(input_values.type(dtype=torch.float16), attention_mask=attention_mask)

    expected_shape = (self.batch_size, self.output_seq_length, self.hidden_size)
    self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)
def create_and_check_batch_inference(self, config, input_values, *args): def create_and_check_batch_inference(self, config, input_values, *args):
# test does not pass for models making use of `group_norm` # test does not pass for models making use of `group_norm`
# check: https://github.com/pytorch/fairseq/issues/3227 # check: https://github.com/pytorch/fairseq/issues/3227
...@@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest ...@@ -451,6 +468,16 @@ class Wav2Vec2ConformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs) self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
@require_torch_gpu
def test_model_float16_with_relative(self):
    # Exercise half-precision inference with relative position embeddings.
    inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="relative")
    self.model_tester.create_and_check_model_float16(*inputs)
@require_torch_gpu
def test_model_float16_with_rotary(self):
    # Exercise half-precision inference with rotary position embeddings.
    inputs = self.model_tester.prepare_config_and_inputs(position_embeddings_type="rotary")
    self.model_tester.create_and_check_model_float16(*inputs)
def test_ctc_loss_inference(self): def test_ctc_loss_inference(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_ctc_loss(*config_and_inputs) self.model_tester.check_ctc_loss(*config_and_inputs)
......
...@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ...@@ -901,6 +901,26 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output = speech_recognizer(filename) output = speech_recognizer(filename)
self.assertEqual(output, {"text": "a man said to the universe sir i exist"}) self.assertEqual(output, {"text": "a man said to the universe sir i exist"})
@slow
@require_torch_gpu
def test_wav2vec2_conformer_float16(self):
    """End-to-end ASR check: a rotary-embedding Conformer loaded in float16 on GPU
    must transcribe the LibriSpeech dummy sample correctly."""
    asr = pipeline(
        task="automatic-speech-recognition",
        model="facebook/wav2vec2-conformer-rope-large-960h-ft",
        device="cuda:0",
        torch_dtype=torch.float16,
        framework="pt",
    )

    # Single validation sample from the dummy LibriSpeech split.
    audio = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]

    transcription = asr(audio)
    self.assertEqual(
        transcription,
        {"text": "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL"},
    )
@require_torch @require_torch
def test_chunking_fast(self): def test_chunking_fast(self):
speech_recognizer = pipeline( speech_recognizer = pipeline(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment