Unverified commit 7c205bf4, authored by Ceyda Cinarel and committed by GitHub
Browse files

wav2vec2 converter: create the proper vocab.json while converting fairseq...

wav2vec2 converter: create the proper vocab.json while converting fairseq wav2vec2 finetuned model (#11041)

* add vocab while converting wav2vec2 original finetuned model

* check save directory exists

* return_attention_mask fix

* quality
parent d49d3cf6
......@@ -16,11 +16,22 @@
import argparse
import json
import os
import fairseq
import torch
from fairseq.data import Dictionary
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging
from transformers import (
Wav2Vec2Config,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2Model,
Wav2Vec2Processor,
logging,
)
logging.set_verbosity_info()
......@@ -163,11 +174,46 @@ def convert_wav2vec2_checkpoint(
config = Wav2Vec2Config()
if is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
config.bos_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.pad_token_id = target_dict.pad_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
return
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
json.dump(target_dict.indices, vocab_handle)
tokenizer = Wav2Vec2CTCTokenizer(
vocab_path,
unk_token=target_dict.unk_word,
pad_token=target_dict.pad_word,
bos_token=target_dict.bos_word,
eos_token=target_dict.eos_word,
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2Model(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment