Unverified commit c448c01f, authored by Patrick von Platen and committed by GitHub
Browse files

[Wav2Vec2] Fix convert (#11562)

* push

* small change

* correct other typo
parent 623281aa
......@@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
if dict_path:
target_dict = Dictionary.load(dict_path)
config.bos_token_id = target_dict.bos_index
# important: swap the bos & pad token ids, since the CTC symbol is <pad>
# and not <s> as in fairseq
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.pad_token_id = target_dict.pad_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
......@@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2Model(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment