Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
8e4ee28e
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1ac2463dfee0a133311c3c585bf7253b0400d6d3"
Unverified
Commit
8e4ee28e
authored
Oct 11, 2022
by
amyeroberts
Committed by
GitHub
Oct 11, 2022
Browse files
Update TF whisper doc tests (#19484)
parent
6c66c6c8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
7 deletions
+3
-7
src/transformers/models/whisper/modeling_tf_whisper.py
src/transformers/models/whisper/modeling_tf_whisper.py
+3
-7
No files found.
src/transformers/models/whisper/modeling_tf_whisper.py
View file @
8e4ee28e
...
@@ -1033,9 +1033,7 @@ class TFWhisperMainLayer(tf.keras.layers.Layer):
...
@@ -1033,9 +1033,7 @@ class TFWhisperMainLayer(tf.keras.layers.Layer):
>>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
>>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
... )
>>> input_features = inputs.input_features
>>> input_features = inputs.input_features
>>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
...
@@ -1160,9 +1158,7 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
...
@@ -1160,9 +1158,7 @@ class TFWhisperModel(TFWhisperPreTrainedModel):
>>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
>>> model = TFWhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="tf")
... ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
... )
>>> input_features = inputs.input_features
>>> input_features = inputs.input_features
>>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
...
@@ -1288,7 +1284,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
...
@@ -1288,7 +1284,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="tf")
>>> input_features = inputs.input_features
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> generated_ids = model.generate(input_ids=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
>>> transcription
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment