Unverified Commit 0e83c966 authored by heatz123's avatar heatz123 Committed by GitHub
Browse files

Fix fairseq wav2vec2-xls-r pretrained weights conversion scripts (#19508)

* fix loading fairseq wav2vec2 pretrained weights

Specified fairseq task as "audio_pretraining" when loading fairseq weights,
since loading wav2vec2-xls-r weights fails if the task is unspecified.

Resolves: #19319

* fix style
parent 4212bb0d
...@@ -246,7 +246,10 @@ def convert_wav2vec2_checkpoint( ...@@ -246,7 +246,10 @@ def convert_wav2vec2_checkpoint(
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
) )
else: else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path]) task_arg = argparse.Namespace(task="audio_pretraining")
task = fairseq.tasks.setup_task(task_arg)
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
model = model[0].eval() model = model[0].eval()
......
...@@ -283,7 +283,10 @@ def convert_wav2vec2_conformer_checkpoint( ...@@ -283,7 +283,10 @@ def convert_wav2vec2_conformer_checkpoint(
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])} [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
) )
else: else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path]) task_arg = argparse.Namespace(task="audio_pretraining")
task = fairseq.tasks.setup_task(task_arg)
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
model = model[0].eval() model = model[0].eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment