"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2f06f2bcd66917d878db04a3b8c79968b530255c"
Unverified Commit a4ee463d authored by Sanchit Gandhi, committed by GitHub

[Docs] Fix Speech Encoder Decoder doc sample (#18346)

* [Docs] Fix Speech Encoder Decoder doc sample

* improve pre-processing comment

* make style
parent da503ea0
@@ -85,25 +85,26 @@ As you can see, only 2 inputs are required for the model in order to compute a loss: `input_values` (which are the
 speech inputs) and `labels` (which are the `input_ids` of the encoded target sequence).

 ```python
->>> from transformers import Wav2Vec2Processor, SpeechEncoderDecoderModel
+>>> from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel
 >>> from datasets import load_dataset

->>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
->>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
->>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
-...     "facebook/wav2vec2-base-960h", "bert-base-uncased"
-... )
+>>> encoder_id = "facebook/wav2vec2-base-960h"  # acoustic model encoder
+>>> decoder_id = "bert-base-uncased"  # text decoder

->>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
->>> model.config.pad_token_id = processor.tokenizer.pad_token_id
+>>> feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
+>>> tokenizer = AutoTokenizer.from_pretrained(decoder_id)
+>>> # Combine pre-trained encoder and pre-trained decoder to form a Seq2Seq model
+>>> model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)

->>> # load a speech input
+>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
+>>> model.config.pad_token_id = tokenizer.pad_token_id
+
+>>> # load an audio input and pre-process (normalise mean/std to 0/1)
 >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
->>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
+>>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values

->>> # load its corresponding transcription
->>> with processor.as_target_processor():
-...     labels = processor(ds[0]["text"], return_tensors="pt").input_ids
+>>> # load its corresponding transcription and tokenize to generate labels
+>>> labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids

 >>> # the forward function automatically creates the correct decoder_input_ids
 >>> loss = model(input_values, labels=labels).loss
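For reference, the fixed sample can be taken one step further to run inference with the warm-started model. Below is a minimal sketch, not part of this commit, that reuses `model`, `feature_extractor`, `tokenizer` and `input_values` from the snippet above; since `from_encoder_decoder_pretrained` initialises the cross-attention weights randomly, the decoded transcription is only meaningful after fine-tuning.

```python
>>> # minimal inference sketch (not part of this commit), reusing the objects above
>>> generated_ids = model.generate(input_values, max_length=50)
>>> transcription = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

>>> # the warm-started cross-attention weights are random, so the transcription
>>> # is only meaningful once the model has been fine-tuned on labelled speech
```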