Unverified commit 8e4ee28e authored by amyeroberts, committed by GitHub

Update TF whisper doc tests (#19484)

parent 6c66c6c8
@@ -1033,9 +1033,7 @@ class TFWhisperMainLayer(tf.keras.layers.Layer):
 >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
 >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
->>> inputs = feature_extractor(
-...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
-... )
+>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
 >>> input_features = inputs.input_features
 >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
 >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
@@ -1160,9 +1158,7 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
 >>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
 >>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
->>> inputs = feature_extractor(
-...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
-... )
+>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
 >>> input_features = inputs.input_features
 >>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
 >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
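The two hunks above make the same edit to the TFWhisperModel usage example: the feature extractor is now called on a single line, without the explicit `sampling_rate` argument. A standalone sketch of the updated example is below; it mirrors the `+` side of the diff, with only the imports and the final shape print added here for illustration.

```python
# Standalone version of the updated doc test; the final print is added here
# only to show the encoder output shape.
import tensorflow as tf
from datasets import load_dataset
from transformers import TFWhisperModel, WhisperFeatureExtractor

model = TFWhisperModel.from_pretrained("openai/whisper-base")
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Single-line call without the explicit sampling_rate argument, as in the new doc test
inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
input_features = inputs.input_features

# Start decoding from the model's decoder_start_token_id
decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
print(last_hidden_state.shape)
```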
@@ -1288,7 +1284,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
 >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
 >>> input_features = inputs.input_features
->>> generated_ids = model.generate(inputs=input_features)
+>>> generated_ids = model.generate(input_ids=input_features)
 >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 >>> transcription
...
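The last hunk changes the keyword passed to `generate` from `inputs` to `input_ids`. The model and processor setup is not visible in this hunk, so the sketch below assumes both are loaded from the same "openai/whisper-base" checkpoint used in the earlier examples; the actual docstring may use a different checkpoint.

```python
# Sketch of the updated transcription example. The processor/model loading is
# assumed (not shown in the hunk); the checkpoint name here is illustrative.
from datasets import load_dataset
from transformers import TFWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-base")  # assumed checkpoint
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-base")  # assumed checkpoint
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
input_features = inputs.input_features

# The updated doc test passes the log-mel features via the `input_ids` keyword
generated_ids = model.generate(input_ids=input_features)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)
```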