Unverified Commit 087436c9 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

Fix-ci-whisper (#21767)

* fix history

* input_features instead of input ids for TFWhisport doctest

* use translate intead of transcribe
parent c8545d2a
...@@ -1283,7 +1283,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua ...@@ -1283,7 +1283,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf") >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
>>> input_features = inputs.input_features >>> input_features = inputs.input_features
>>> generated_ids = model.generate(input_ids=input_features) >>> generated_ids = model.generate(input_features=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription >>> transcription
......
...@@ -1187,7 +1187,7 @@ class WhisperModelIntegrationTests(unittest.TestCase): ...@@ -1187,7 +1187,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_speech = self._load_datasamples(4) input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features
generated_ids = model.generate(input_features, max_length=20) generated_ids = model.generate(input_features, max_length=20, task="translate")
# fmt: off # fmt: off
EXPECTED_LOGITS = torch.tensor( EXPECTED_LOGITS = torch.tensor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment