"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7a787c68c6a287ab186f3a099c6496aaee1e8aeb"
Unverified Commit 2faa0953 authored by Matthijs Hollemans's avatar Matthijs Hollemans Committed by GitHub
Browse files

fix Whisper tests on GPU (#23753)

* move input features to GPU

* skip these tests because undefined behavior

* unskip tests
parent ac224dee
...@@ -1477,7 +1477,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -1477,7 +1477,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device) model.to(torch_device)
input_speech = self._load_datasamples(4)[-1:] input_speech = self._load_datasamples(4)[-1:]
input_features = processor(input_speech, return_tensors="pt").input_features input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
output_without_prompt = model.generate(input_features) output_without_prompt = model.generate(input_features)
prompt_ids = processor.get_prompt_ids("Leighton") prompt_ids = processor.get_prompt_ids("Leighton")
...@@ -1494,7 +1494,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -1494,7 +1494,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device) model.to(torch_device)
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
task = "translate" task = "translate"
language = "de" language = "de"
expected_tokens = [f"<|{task}|>", f"<|{language}|>"] expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
...@@ -1513,7 +1513,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -1513,7 +1513,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to(torch_device) model.to(torch_device)
input_speech = self._load_datasamples(1) input_speech = self._load_datasamples(1)
input_features = processor(input_speech, return_tensors="pt").input_features input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
prompt = "test prompt" prompt = "test prompt"
prompt_ids = processor.get_prompt_ids(prompt) prompt_ids = processor.get_prompt_ids(prompt)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment