Commit f1a0b605 (unverified)
Authored Jun 01, 2021 by moto; committed by GitHub, Jun 01, 2021
Parent: 07d9bc21

Add wav2vec2 fairseq importer (#1531)

Showing 15 changed files with 1285 additions and 0 deletions (+1285, -0)
.circleci/unittest/linux/scripts/install.sh (+5, -0)
.circleci/unittest/windows/scripts/install.sh (+5, -0)
docs/source/models.rst (+2, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_fairseq_model_config.py (+106, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/libri960_big.json (+54, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_960h.json (+146, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_960h.json (+146, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_self_960h.json (+146, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small.json (+54, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small_960h.json (+146, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_vox_new.json (+54, -0)
test/torchaudio_unittest/assets/wav2vec2/fairseq/xlsr_53_56k.json (+51, -0)
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py (+167, -0)
torchaudio/models/wav2vec2/utils/__init__.py (+2, -0)
torchaudio/models/wav2vec2/utils/import_fairseq.py (+201, -0)
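In short, this commit adds a fairseq-to-torchaudio importer (`import_fairseq_model`), pins fairseq in CI, documents the new function, and adds JSON config assets plus an integration test. A minimal usage sketch, adapted from the docstring of `import_fairseq.py` below; the checkpoint `wav2vec_small.pt`, the `audio.wav` file, and `num_out=28` are assumptions for illustration:

```
import fairseq
import torchaudio

from torchaudio.models.wav2vec2.utils import import_fairseq_model

# Load the original checkpoint with fairseq, then convert it to torchaudio's Wav2Vec2Model.
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(['wav2vec_small.pt'])
imported = import_fairseq_model(model[0], num_out=28).eval()

# Feature extraction with the imported model.
waveform, _ = torchaudio.load('audio.wav')
features, _ = imported.extract_features(waveform)
```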
.circleci/unittest/linux/scripts/install.sh

@@ -58,3 +58,8 @@ fi
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
     pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
 )
+# Install fairseq
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+git checkout e6eddd80
+pip install .
.circleci/unittest/windows/scripts/install.sh

@@ -46,3 +46,8 @@ fi
     conda install -y -c conda-forge ${NUMBA_DEV_CHANNEL} 'librosa>=0.8.0' parameterized 'requests>=2.20'
     pip install kaldi-io SoundFile coverage pytest pytest-cov scipy transformers
 )
+# Install fairseq
+git clone https://github.com/pytorch/fairseq
+cd fairseq
+git checkout e6eddd80
+pip install .
docs/source/models.rst

@@ -62,6 +62,8 @@ Utility Functions
 .. autofunction:: import_huggingface_model
+.. autofunction:: import_fairseq_model
 .. currentmodule:: torchaudio.models

 :hidden:`WaveRNN`
test/torchaudio_unittest/assets/wav2vec2/fairseq/generate_fairseq_model_config.py (new file, mode 100644)

#!/usr/bin/env python3
"""Generate the config JSON, consumed by the unit tests, from a fairseq pretrained weight file.

Usage:
1. Download pretrained parameters from https://github.com/pytorch/fairseq/tree/master/examples/wav2vec
2. Download the dict from https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt
   and put it in the same directory as the parameter files.
3. Run this script and save the resulting JSON configuration in the assets directory.

Example:

```
# Pretrained
python generate_fairseq_model_config.py \
    --model-file wav2vec_small.pt \
    > wav2vec_small.json
python generate_fairseq_model_config.py \
    --model-file libri960_big.pt \
    > libri960_big.json
python generate_fairseq_model_config.py \
    --model-file wav2vec_vox_new.pt \
    > wav2vec_vox_new.json

# Fine-tuned
python generate_fairseq_model_config.py \
    --model-file wav2vec_small_960h.pt \
    > wav2vec_small_960h.json
python generate_fairseq_model_config.py \
    --model-file wav2vec_big_960h.pt \
    > wav2vec_large_960h.json
python generate_fairseq_model_config.py \
    --model-file wav2vec2_vox_960h_new.pt \
    > wav2vec_large_lv60_960h.json
python generate_fairseq_model_config.py \
    --model-file wav2vec_vox_960h_pl.pt \
    > wav2vec_large_lv60_self_960h.json
```
"""
import os
import json
import argparse


def _parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        '--model-file',
        required=True,
        help=(
            'A point file from '
            'https://github.com/pytorch/fairseq/tree/master/examples/wav2vec'
        )
    )
    parser.add_argument(
        '--dict-dir',
        help=(
            'Directory where `dict.ltr.txt` file is found. '
            'Default: the directory of the given model.'
        )
    )
    args = parser.parse_args()
    if args.dict_dir is None:
        args.dict_dir = os.path.dirname(args.model_file)
    return args


def _to_json(conf):
    import yaml
    from omegaconf import OmegaConf
    return yaml.safe_load(OmegaConf.to_yaml(conf))


def _load(model_file, dict_dir):
    import fairseq
    overrides = {'data': dict_dir}
    _, args, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [model_file], arg_overrides=overrides)
    return _to_json(args['model'])


def _main():
    args = _parse_args()
    conf = _load(args.model_file, args.dict_dir)
    if conf['_name'] == 'wav2vec_ctc':
        del conf['data']
        del conf['w2v_args']['task']['data']
        conf['w2v_args'] = {
            key: conf['w2v_args'][key] for key in ['model', 'task']
        }
    print(json.dumps(conf, indent=4, sort_keys=True))


if __name__ == '__main__':
    _main()
test/torchaudio_unittest/assets/wav2vec2/fairseq/libri960_big.json (new file, mode 100644)

{
    "_name": "wav2vec2",
    "activation_dropout": 0.0,
    "activation_fn": "gelu",
    "attention_dropout": 0.1,
    "codebook_negatives": 0,
    "conv_bias": false,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
    "conv_pos": 128,
    "conv_pos_groups": 16,
    "cross_sample_negatives": 0,
    "dropout": 0.0,
    "dropout_features": 0.1,
    "dropout_input": 0.1,
    "encoder_attention_heads": 16,
    "encoder_embed_dim": 1024,
    "encoder_ffn_embed_dim": 4096,
    "encoder_layerdrop": 0.2,
    "encoder_layers": 24,
    "extractor_mode": "default",
    "feature_grad_mult": 0.1,
    "final_dim": 768,
    "latent_dim": 0,
    "latent_groups": 2,
    "latent_temp": [2.0, 0.5, 0.999995],
    "latent_vars": 320,
    "layer_norm_first": false,
    "logit_temp": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.65,
    "mask_selection": "static",
    "negatives_from_everywhere": false,
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "num_negatives": 100,
    "quantize_input": false,
    "quantize_targets": true,
    "quantizer_depth": 1,
    "quantizer_factor": 3,
    "same_quantizer": false,
    "target_glu": false
}
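These JSON assets are snapshots of fairseq's model config (generated by the script above); the integration test below feeds them back into fairseq to rebuild the original network. A minimal sketch of that step, assuming fairseq is installed and the JSON file is local:

```
import json

from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Config, Wav2Vec2Model

# Recreate the pretrained (not fine-tuned) fairseq model from the stored config.
with open('libri960_big.json', 'r') as f:
    conf = json.load(f)
model = Wav2Vec2Model(Wav2Vec2Config(**conf))
```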
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_960h.json (new file, mode 100644)

{
    "_name": "wav2vec_ctc",
    "activation_dropout": 0.1,
    "apply_mask": true,
    "attention_dropout": 0.0,
    "blank_mode": "add",
    "blank_weight": 0.0,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
    "dropout": 0.0,
    "dropout_input": 0.0,
    "encoder_embed_dim": 512,
    "feature_grad_mult": 0.0,
    "final_dropout": 0.0,
    "freeze_finetune_updates": 10000,
    "layerdrop": 0.2,
    "mask_channel_before": false,
    "mask_channel_length": 64,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.1,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.5,
    "mask_selection": "static",
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "no_pretrained_weights": false,
    "normalize": false,
    "w2v_args": {
        "model": {
            "_name": "wav2vec2",
            "activation_dropout": 0.0,
            "activation_fn": "gelu",
            "attention_dropout": 0.1,
            "codebook_negatives": 0,
            "conv_bias": false,
            "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
            "conv_pos": 128,
            "conv_pos_groups": 16,
            "cross_sample_negatives": 0,
            "dropout": 0.0,
            "dropout_features": 0.1,
            "dropout_input": 0.1,
            "encoder_attention_heads": 16,
            "encoder_embed_dim": 1024,
            "encoder_ffn_embed_dim": 4096,
            "encoder_layerdrop": 0.2,
            "encoder_layers": 24,
            "extractor_mode": "default",
            "feature_grad_mult": 0.1,
            "final_dim": 768,
            "latent_dim": 0,
            "latent_groups": 2,
            "latent_temp": [2.0, 0.5, 0.999995],
            "latent_vars": 320,
            "layer_norm_first": false,
            "logit_temp": 0.1,
            "mask_channel_before": false,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_length": 10,
            "mask_min_space": 1,
            "mask_other": 0.0,
            "mask_prob": 0.65,
            "mask_selection": "static",
            "negatives_from_everywhere": false,
            "no_mask_channel_overlap": false,
            "no_mask_overlap": false,
            "num_negatives": 100,
            "quantize_input": false,
            "quantize_targets": true,
            "quantizer_depth": 1,
            "quantizer_factor": 3,
            "same_quantizer": false,
            "target_glu": false
        },
        "task": {
            "_name": "audio_pretraining",
            "autoregressive": false,
            "binarized_dataset": false,
            "enable_padding": false,
            "eval_wer": false,
            "eval_wer_config": {
                "beam": 5,
                "constraints": null,
                "decoding_format": null,
                "diverse_beam_groups": -1,
                "diverse_beam_strength": 0.5,
                "diversity_rate": -1.0,
                "iter_decode_eos_penalty": 0.0,
                "iter_decode_force_max_iter": false,
                "iter_decode_max_iter": 10,
                "iter_decode_with_beam": 1,
                "iter_decode_with_external_reranker": false,
                "lenpen": 1.0,
                "lm_path": null,
                "lm_weight": 0.0,
                "match_source_len": false,
                "max_len_a": 0.0,
                "max_len_b": 200,
                "min_len": 1,
                "nbest": 1,
                "no_beamable_mm": false,
                "no_early_stop": false,
                "no_repeat_ngram_size": 0,
                "no_seed_provided": false,
                "prefix_size": 0,
                "print_alignment": null,
                "print_step": false,
                "replace_unk": null,
                "retain_dropout": false,
                "retain_dropout_modules": null,
                "retain_iter_history": false,
                "sacrebleu": false,
                "sampling": false,
                "sampling_topk": -1,
                "sampling_topp": -1.0,
                "score_reference": false,
                "temperature": 1.0,
                "unkpen": 0.0,
                "unnormalized": false
            },
            "eval_wer_post_process": "letter",
            "eval_wer_tokenizer": null,
            "inferred_w2v_config": null,
            "labels": null,
            "max_sample_size": 320000,
            "min_sample_size": 32000,
            "normalize": false,
            "num_batch_buckets": 0,
            "precompute_mask_indices": false,
            "sample_rate": 16000,
            "tpu": true
        }
    },
    "w2v_path": "???"
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_960h.json (new file, mode 100644)

{
    "_name": "wav2vec_ctc",
    "activation_dropout": 0.1,
    "apply_mask": true,
    "attention_dropout": 0.0,
    "blank_mode": "add",
    "blank_weight": 0.0,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
    "dropout": 0.0,
    "dropout_input": 0.0,
    "encoder_embed_dim": 512,
    "feature_grad_mult": 0.0,
    "final_dropout": 0.0,
    "freeze_finetune_updates": 10000,
    "layerdrop": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 64,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.25,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.5,
    "mask_selection": "static",
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "no_pretrained_weights": false,
    "normalize": true,
    "w2v_args": {
        "model": {
            "_name": "wav2vec2",
            "activation_dropout": 0.0,
            "activation_fn": "gelu",
            "attention_dropout": 0.1,
            "codebook_negatives": 0,
            "conv_bias": true,
            "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
            "conv_pos": 128,
            "conv_pos_groups": 16,
            "cross_sample_negatives": 0,
            "dropout": 0.0,
            "dropout_features": 0.1,
            "dropout_input": 0.1,
            "encoder_attention_heads": 16,
            "encoder_embed_dim": 1024,
            "encoder_ffn_embed_dim": 4096,
            "encoder_layerdrop": 0.0,
            "encoder_layers": 24,
            "extractor_mode": "layer_norm",
            "feature_grad_mult": 1.0,
            "final_dim": 768,
            "latent_dim": 0,
            "latent_groups": 2,
            "latent_temp": [2.0, 0.1, 0.999995],
            "latent_vars": 320,
            "layer_norm_first": true,
            "logit_temp": 0.1,
            "mask_channel_before": false,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_length": 10,
            "mask_min_space": 1,
            "mask_other": 0.0,
            "mask_prob": 0.65,
            "mask_selection": "static",
            "negatives_from_everywhere": false,
            "no_mask_channel_overlap": false,
            "no_mask_overlap": false,
            "num_negatives": 100,
            "quantize_input": false,
            "quantize_targets": true,
            "quantizer_depth": 1,
            "quantizer_factor": 3,
            "same_quantizer": false,
            "target_glu": false
        },
        "task": {
            "_name": "audio_pretraining",
            "autoregressive": false,
            "binarized_dataset": false,
            "enable_padding": false,
            "eval_wer": false,
            "eval_wer_config": {
                "beam": 5,
                "constraints": null,
                "decoding_format": null,
                "diverse_beam_groups": -1,
                "diverse_beam_strength": 0.5,
                "diversity_rate": -1.0,
                "iter_decode_eos_penalty": 0.0,
                "iter_decode_force_max_iter": false,
                "iter_decode_max_iter": 10,
                "iter_decode_with_beam": 1,
                "iter_decode_with_external_reranker": false,
                "lenpen": 1.0,
                "lm_path": null,
                "lm_weight": 0.0,
                "match_source_len": false,
                "max_len_a": 0.0,
                "max_len_b": 200,
                "min_len": 1,
                "nbest": 1,
                "no_beamable_mm": false,
                "no_early_stop": false,
                "no_repeat_ngram_size": 0,
                "no_seed_provided": false,
                "prefix_size": 0,
                "print_alignment": null,
                "print_step": false,
                "replace_unk": null,
                "retain_dropout": false,
                "retain_dropout_modules": null,
                "retain_iter_history": false,
                "sacrebleu": false,
                "sampling": false,
                "sampling_topk": -1,
                "sampling_topp": -1.0,
                "score_reference": false,
                "temperature": 1.0,
                "unkpen": 0.0,
                "unnormalized": false
            },
            "eval_wer_post_process": "letter",
            "eval_wer_tokenizer": null,
            "inferred_w2v_config": null,
            "labels": null,
            "max_sample_size": 320000,
            "min_sample_size": 32000,
            "normalize": true,
            "num_batch_buckets": 0,
            "precompute_mask_indices": false,
            "sample_rate": 16000,
            "tpu": true
        }
    },
    "w2v_path": "???"
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_large_lv60k_self_960h.json (new file, mode 100644)

{
    "_name": "wav2vec_ctc",
    "activation_dropout": 0.1,
    "apply_mask": true,
    "attention_dropout": 0.0,
    "blank_mode": "add",
    "blank_weight": 0.0,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
    "dropout": 0.0,
    "dropout_input": 0.0,
    "encoder_embed_dim": 768,
    "feature_grad_mult": 0.0,
    "final_dropout": 0.0,
    "freeze_finetune_updates": 10000,
    "layerdrop": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 64,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.1,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.1,
    "mask_selection": "static",
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "no_pretrained_weights": false,
    "normalize": true,
    "w2v_args": {
        "model": {
            "_name": "wav2vec2",
            "activation_dropout": 0.0,
            "activation_fn": "gelu",
            "attention_dropout": 0.1,
            "codebook_negatives": 0,
            "conv_bias": true,
            "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
            "conv_pos": 128,
            "conv_pos_groups": 16,
            "cross_sample_negatives": 0,
            "dropout": 0.0,
            "dropout_features": 0.1,
            "dropout_input": 0.1,
            "encoder_attention_heads": 16,
            "encoder_embed_dim": 1024,
            "encoder_ffn_embed_dim": 4096,
            "encoder_layerdrop": 0.0,
            "encoder_layers": 24,
            "extractor_mode": "layer_norm",
            "feature_grad_mult": 1.0,
            "final_dim": 768,
            "latent_dim": 0,
            "latent_groups": 2,
            "latent_temp": [2.0, 0.1, 0.999995],
            "latent_vars": 320,
            "layer_norm_first": true,
            "logit_temp": 0.1,
            "mask_channel_before": false,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_length": 10,
            "mask_min_space": 1,
            "mask_other": 0.0,
            "mask_prob": 0.65,
            "mask_selection": "static",
            "negatives_from_everywhere": false,
            "no_mask_channel_overlap": false,
            "no_mask_overlap": false,
            "num_negatives": 100,
            "quantize_input": false,
            "quantize_targets": true,
            "quantizer_depth": 1,
            "quantizer_factor": 3,
            "same_quantizer": false,
            "target_glu": false
        },
        "task": {
            "_name": "audio_pretraining",
            "autoregressive": false,
            "binarized_dataset": false,
            "enable_padding": false,
            "eval_wer": false,
            "eval_wer_config": {
                "beam": 5,
                "constraints": null,
                "decoding_format": null,
                "diverse_beam_groups": -1,
                "diverse_beam_strength": 0.5,
                "diversity_rate": -1.0,
                "iter_decode_eos_penalty": 0.0,
                "iter_decode_force_max_iter": false,
                "iter_decode_max_iter": 10,
                "iter_decode_with_beam": 1,
                "iter_decode_with_external_reranker": false,
                "lenpen": 1.0,
                "lm_path": null,
                "lm_weight": 0.0,
                "match_source_len": false,
                "max_len_a": 0.0,
                "max_len_b": 200,
                "min_len": 1,
                "nbest": 1,
                "no_beamable_mm": false,
                "no_early_stop": false,
                "no_repeat_ngram_size": 0,
                "no_seed_provided": false,
                "prefix_size": 0,
                "print_alignment": null,
                "print_step": false,
                "replace_unk": null,
                "retain_dropout": false,
                "retain_dropout_modules": null,
                "retain_iter_history": false,
                "sacrebleu": false,
                "sampling": false,
                "sampling_topk": -1,
                "sampling_topp": -1.0,
                "score_reference": false,
                "temperature": 1.0,
                "unkpen": 0.0,
                "unnormalized": false
            },
            "eval_wer_post_process": "letter",
            "eval_wer_tokenizer": null,
            "inferred_w2v_config": null,
            "labels": null,
            "max_sample_size": 320000,
            "min_sample_size": 32000,
            "normalize": true,
            "num_batch_buckets": 0,
            "precompute_mask_indices": false,
            "sample_rate": 16000,
            "tpu": true
        }
    },
    "w2v_path": "/private/home/abaevski/models/wav2vec2/wav2vec_vox_new.pt"
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small.json (new file, mode 100644)

{
    "_name": "wav2vec2",
    "activation_dropout": 0.0,
    "activation_fn": "gelu",
    "attention_dropout": 0.1,
    "codebook_negatives": 0,
    "conv_bias": false,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
    "conv_pos": 128,
    "conv_pos_groups": 16,
    "cross_sample_negatives": 0,
    "dropout": 0.1,
    "dropout_features": 0.1,
    "dropout_input": 0.1,
    "encoder_attention_heads": 12,
    "encoder_embed_dim": 768,
    "encoder_ffn_embed_dim": 3072,
    "encoder_layerdrop": 0.05,
    "encoder_layers": 12,
    "extractor_mode": "default",
    "feature_grad_mult": 0.1,
    "final_dim": 256,
    "latent_dim": 0,
    "latent_groups": 2,
    "latent_temp": [2.0, 0.5, 0.999995],
    "latent_vars": 320,
    "layer_norm_first": false,
    "logit_temp": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.65,
    "mask_selection": "static",
    "negatives_from_everywhere": false,
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "num_negatives": 100,
    "quantize_input": false,
    "quantize_targets": true,
    "quantizer_depth": 1,
    "quantizer_factor": 3,
    "same_quantizer": false,
    "target_glu": false
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_small_960h.json (new file, mode 100644)

{
    "_name": "wav2vec_ctc",
    "activation_dropout": 0.1,
    "apply_mask": true,
    "attention_dropout": 0.0,
    "blank_mode": "add",
    "blank_weight": 0.0,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
    "dropout": 0.0,
    "dropout_input": 0.0,
    "encoder_embed_dim": 512,
    "feature_grad_mult": 0.0,
    "final_dropout": 0.0,
    "freeze_finetune_updates": 0,
    "layerdrop": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 64,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.1,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.5,
    "mask_selection": "static",
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "no_pretrained_weights": false,
    "normalize": false,
    "w2v_args": {
        "model": {
            "_name": "wav2vec2",
            "activation_dropout": 0.0,
            "activation_fn": "gelu",
            "attention_dropout": 0.1,
            "codebook_negatives": 0,
            "conv_bias": false,
            "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
            "conv_pos": 128,
            "conv_pos_groups": 16,
            "cross_sample_negatives": 0,
            "dropout": 0.1,
            "dropout_features": 0.1,
            "dropout_input": 0.1,
            "encoder_attention_heads": 12,
            "encoder_embed_dim": 768,
            "encoder_ffn_embed_dim": 3072,
            "encoder_layerdrop": 0.05,
            "encoder_layers": 12,
            "extractor_mode": "default",
            "feature_grad_mult": 0.1,
            "final_dim": 256,
            "latent_dim": 0,
            "latent_groups": 2,
            "latent_temp": [2, 0.5, 0.999995],
            "latent_vars": 320,
            "layer_norm_first": false,
            "logit_temp": 0.1,
            "mask_channel_before": false,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_length": 10,
            "mask_min_space": 1,
            "mask_other": 0.0,
            "mask_prob": 0.65,
            "mask_selection": "static",
            "negatives_from_everywhere": false,
            "no_mask_channel_overlap": false,
            "no_mask_overlap": false,
            "num_negatives": 100,
            "quantize_input": false,
            "quantize_targets": true,
            "quantizer_depth": 1,
            "quantizer_factor": 3,
            "same_quantizer": false,
            "target_glu": false
        },
        "task": {
            "_name": "audio_pretraining",
            "autoregressive": false,
            "binarized_dataset": false,
            "enable_padding": false,
            "eval_wer": false,
            "eval_wer_config": {
                "beam": 5,
                "constraints": null,
                "decoding_format": null,
                "diverse_beam_groups": -1,
                "diverse_beam_strength": 0.5,
                "diversity_rate": -1.0,
                "iter_decode_eos_penalty": 0.0,
                "iter_decode_force_max_iter": false,
                "iter_decode_max_iter": 10,
                "iter_decode_with_beam": 1,
                "iter_decode_with_external_reranker": false,
                "lenpen": 1.0,
                "lm_path": null,
                "lm_weight": 0.0,
                "match_source_len": false,
                "max_len_a": 0.0,
                "max_len_b": 200,
                "min_len": 1,
                "nbest": 1,
                "no_beamable_mm": false,
                "no_early_stop": false,
                "no_repeat_ngram_size": 0,
                "no_seed_provided": false,
                "prefix_size": 0,
                "print_alignment": null,
                "print_step": false,
                "replace_unk": null,
                "retain_dropout": false,
                "retain_dropout_modules": null,
                "retain_iter_history": false,
                "sacrebleu": false,
                "sampling": false,
                "sampling_topk": -1,
                "sampling_topp": -1.0,
                "score_reference": false,
                "temperature": 1.0,
                "unkpen": 0.0,
                "unnormalized": false
            },
            "eval_wer_post_process": "letter",
            "eval_wer_tokenizer": null,
            "inferred_w2v_config": null,
            "labels": null,
            "max_sample_size": 250000,
            "min_sample_size": 32000,
            "normalize": false,
            "num_batch_buckets": 0,
            "precompute_mask_indices": false,
            "sample_rate": 16000,
            "tpu": true
        }
    },
    "w2v_path": "???"
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/wav2vec_vox_new.json (new file, mode 100644)

{
    "_name": "wav2vec2",
    "activation_dropout": 0.0,
    "activation_fn": "gelu",
    "attention_dropout": 0.1,
    "codebook_negatives": 0,
    "conv_bias": true,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
    "conv_pos": 128,
    "conv_pos_groups": 16,
    "cross_sample_negatives": 0,
    "dropout": 0.0,
    "dropout_features": 0.1,
    "dropout_input": 0.1,
    "encoder_attention_heads": 16,
    "encoder_embed_dim": 1024,
    "encoder_ffn_embed_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 24,
    "extractor_mode": "layer_norm",
    "feature_grad_mult": 1.0,
    "final_dim": 768,
    "latent_dim": 0,
    "latent_groups": 2,
    "latent_temp": [2.0, 0.1, 0.999995],
    "latent_vars": 320,
    "layer_norm_first": true,
    "logit_temp": 0.1,
    "mask_channel_before": false,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.65,
    "mask_selection": "static",
    "negatives_from_everywhere": false,
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "num_negatives": 100,
    "quantize_input": false,
    "quantize_targets": true,
    "quantizer_depth": 1,
    "quantizer_factor": 3,
    "same_quantizer": false,
    "target_glu": false
}
test/torchaudio_unittest/assets/wav2vec2/fairseq/xlsr_53_56k.json (new file, mode 100644)

{
    "_name": "wav2vec2",
    "activation_dropout": 0.0,
    "activation_fn": "gelu",
    "attention_dropout": 0.0,
    "codebook_negatives": 0,
    "conv_bias": true,
    "conv_feature_layers": "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2",
    "conv_pos": 128,
    "conv_pos_groups": 16,
    "cross_sample_negatives": 0,
    "dropout": 0.0,
    "dropout_features": 0.0,
    "dropout_input": 0.0,
    "encoder_attention_heads": 16,
    "encoder_embed_dim": 1024,
    "encoder_ffn_embed_dim": 4096,
    "encoder_layerdrop": 0.0,
    "encoder_layers": 24,
    "extractor_mode": "layer_norm",
    "feature_grad_mult": 1.0,
    "final_dim": 768,
    "latent_dim": 0,
    "latent_groups": 2,
    "latent_temp": [2.0, 0.1, 0.999995],
    "latent_vars": 320,
    "layer_norm_first": true,
    "logit_temp": 0.1,
    "mask_channel_length": 10,
    "mask_channel_min_space": 1,
    "mask_channel_other": 0.0,
    "mask_channel_prob": 0.0,
    "mask_channel_selection": "static",
    "mask_length": 10,
    "mask_min_space": 1,
    "mask_other": 0.0,
    "mask_prob": 0.65,
    "mask_selection": "static",
    "negatives_from_everywhere": false,
    "no_mask_channel_overlap": false,
    "no_mask_overlap": false,
    "num_negatives": 100,
    "quantize_input": false,
    "quantize_targets": true,
    "same_quantizer": false,
    "target_glu": false
}
test/torchaudio_unittest/models/wav2vec2/fairseq_integration_test.py (new file, mode 100644)

import json

import torch
from torchaudio.models.wav2vec2 import (
    wav2vec2_base,
    wav2vec2_large,
    wav2vec2_large_lv60k,
)
from torchaudio.models.wav2vec2.utils import (
    import_fairseq_model,
)
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
    get_asset_path,
    skipIfNoModule,
    TorchaudioTestCase,
)


def _load_config(*paths):
    with open(f'{get_asset_path("wav2vec2", "fairseq", *paths)}.json', 'r') as file_:
        return json.load(file_)


# Pretrained (not fine-tuned) models
BASE = _load_config('wav2vec_small')
LARGE = _load_config('libri960_big')
LARGE_LV60K = _load_config('wav2vec_vox_new')
XLSR_53_56K = _load_config('xlsr_53_56k')

# Fine-tuned models
BASE_960H = _load_config('wav2vec_small_960h')
LARGE_960H = _load_config('wav2vec_large_960h')
LARGE_LV60K_960H = _load_config('wav2vec_large_lv60k_960h')
LARGE_LV60K_SELF_960H = _load_config('wav2vec_large_lv60k_self_960h')

# Config and corresponding factory functions
PRETRAINED_CONFIGS = [
    (BASE, wav2vec2_base),
    (LARGE, wav2vec2_large),
    (LARGE_LV60K, wav2vec2_large_lv60k),
    (XLSR_53_56K, wav2vec2_large_lv60k),
]
FINETUNED_CONFIGS = [
    (BASE_960H, wav2vec2_base),
    (LARGE_960H, wav2vec2_large),
    (LARGE_LV60K_960H, wav2vec2_large_lv60k),
    (LARGE_LV60K_SELF_960H, wav2vec2_large_lv60k),
]


@skipIfNoModule('fairseq')
class TestFairseqIntegration(TorchaudioTestCase):
    """Test the process of importing the models from fairseq.

    Test methods in this test suite check the following things:
    1. Models loaded with fairseq can be imported.
    2. The same model can be recreated without fairseq.
    """
    def _get_model(self, config, num_out):
        import copy
        from omegaconf import OmegaConf
        from fairseq.models.wav2vec.wav2vec2 import (
            Wav2Vec2Config,
            Wav2Vec2Model,
        )
        from fairseq.models.wav2vec.wav2vec2_asr import (
            Wav2VecEncoder,
            Wav2Vec2CtcConfig,
        )

        if config['_name'] == 'wav2vec_ctc':
            config = copy.deepcopy(config)
            config['w2v_args'] = OmegaConf.create(config['w2v_args'])
            return Wav2VecEncoder(Wav2Vec2CtcConfig(**config), num_out)
        if config['_name'] == 'wav2vec2':
            return Wav2Vec2Model(Wav2Vec2Config(**config))

    @parameterized.expand([conf[:1] for conf in PRETRAINED_CONFIGS])
    def test_import_pretrained_model(self, config):
        """Pretrained wav2vec2 models from fairseq can be imported and yield the same results"""
        num_out = 28
        batch_size, num_frames = 3, 1024

        original = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(original, 28).eval()

        x = torch.randn(batch_size, num_frames)
        ref = original.feature_extractor(x).transpose(1, 2)
        hyp, _ = imported.extract_features(x)
        self.assertEqual(ref, hyp)

    @parameterized.expand(PRETRAINED_CONFIGS)
    def test_recreate_pretrained_model(self, config, factory_func):
        """Imported pretrained models can be recreated via a factory function without fairseq."""
        num_out = 28
        batch_size, num_frames = 3, 1024

        original = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(original, 28).eval()

        reloaded = factory_func(num_out=num_out)
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        x = torch.randn(batch_size, num_frames)
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])

        # Without mask
        ref, _ = imported(x)
        hyp, _ = reloaded(x)
        self.assertEqual(ref, hyp)

        # With mask
        ref, ref_lengths = imported(x, lengths)
        hyp, hyp_lengths = reloaded(x, lengths)
        self.assertEqual(ref, hyp)
        self.assertEqual(ref_lengths, hyp_lengths)

    @parameterized.expand([conf[:1] for conf in FINETUNED_CONFIGS])
    def test_import_finetuned_model(self, config):
        """Fine-tuned wav2vec2 models from fairseq can be imported and yield the same results"""
        num_out = 28
        batch_size, num_frames = 3, 1024

        original = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(original).eval()

        # Without mask
        x = torch.randn(batch_size, num_frames)
        ref = original(x, torch.zeros_like(x))['encoder_out'].transpose(0, 1)
        hyp, _ = imported(x)
        self.assertEqual(ref, hyp)

        # With mask
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
        mask = torch.arange(num_frames).expand(batch_size, num_frames) >= lengths[:, None]
        ref = original(x, mask)['encoder_out'].transpose(0, 1)
        hyp, output_lengths = imported(x, lengths)
        for i, l in enumerate(output_lengths):
            self.assertEqual(ref[i, :l, ...], hyp[i, :l, ...])

    @parameterized.expand(FINETUNED_CONFIGS)
    def test_recreate_finetuned_model(self, config, factory_func):
        """Imported finetuned models can be recreated via a factory function without fairseq."""
        num_out = 28
        batch_size, num_frames = 3, 1024

        original = self._get_model(config, num_out).eval()
        imported = import_fairseq_model(original).eval()

        reloaded = factory_func(num_out=num_out)
        reloaded.load_state_dict(imported.state_dict())
        reloaded.eval()

        # Without mask
        torch.manual_seed(0)
        x = torch.randn(batch_size, num_frames)
        ref, _ = imported(x)
        hyp, _ = reloaded(x)
        self.assertEqual(ref, hyp)

        # With mask
        lengths = torch.randint(low=0, high=num_frames, size=[batch_size, ])
        ref, ref_lengths = imported(x, lengths)
        hyp, hyp_lengths = reloaded(x, lengths)
        self.assertEqual(ref, hyp)
        self.assertEqual(ref_lengths, hyp_lengths)
torchaudio/models/wav2vec2/utils/__init__.py

from .import_huggingface import import_huggingface_model
from .import_fairseq import import_fairseq_model

__all__ = [
    'import_huggingface_model',
    'import_fairseq_model',
]
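For reference, a quick sketch of the public import path this `__init__.py` exposes:

```
# Both importers are available from the same utils namespace.
from torchaudio.models.wav2vec2.utils import import_huggingface_model, import_fairseq_model
```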
torchaudio/models/wav2vec2/utils/import_fairseq.py (new file, mode 100644)

"""Import fairseq's wav2vec2.0 pretrained weights to torchaudio's format.

For this module to work, you need `fairseq`.
"""
import re
from typing import Optional

from torch.nn import Module

from ..model import Wav2Vec2Model, _get_model


def _parse_config(w2v_model, num_out):
    encoder = w2v_model.encoder
    conv_layers = w2v_model.feature_extractor.conv_layers

    extractor_mode = 'layer_norm'
    if 'GroupNorm' in conv_layers[0][2].__class__.__name__:
        extractor_mode = 'group_norm'
    else:
        extractor_mode = 'layer_norm'

    conv_layer_config = [(l[0].out_channels, l[0].kernel_size[0], l[0].stride[0]) for l in conv_layers]

    if all(l[0].bias is None for l in conv_layers):
        conv_bias = False
    elif all(l[0].bias is not None for l in conv_layers):
        conv_bias = True
    else:
        raise ValueError(
            'Either all the convolutions layers have bias term or none of them should.')

    config = {
        'extractor_mode': extractor_mode,
        'extractor_conv_layer_config': conv_layer_config,
        'extractor_conv_bias': conv_bias,
        'encoder_embed_dim': w2v_model.post_extract_proj.out_features,
        'encoder_projection_dropout': w2v_model.dropout_input.p,
        'encoder_pos_conv_kernel': encoder.pos_conv[0].kernel_size[0],
        'encoder_pos_conv_groups': encoder.pos_conv[0].groups,
        'encoder_num_layers': len(encoder.layers),
        'encoder_num_heads': encoder.layers[0].self_attn.num_heads,
        'encoder_attention_dropout': encoder.layers[0].self_attn.dropout_module.p,
        'encoder_ff_interm_features': encoder.layers[0].fc1.out_features,
        'encoder_ff_interm_dropout': encoder.layers[0].dropout2.p,
        'encoder_dropout': encoder.layers[0].dropout3.p,
        'encoder_layer_norm_first': encoder.layer_norm_first,
        'encoder_layer_drop': encoder.layerdrop,
        'encoder_num_out': num_out,
    }
    return config


def _map_key(key):
    key_ = key
    if key.startswith('w2v_model.'):
        key = key.replace('w2v_model.', '')
    if re.match(r'(mask_emb|quantizer|project_q|final_proj|mask_emb)', key):
        return None
    # Feature Extractor
    # Group norm when "extractor_mode" is "default".
    # (Only the first layer)
    # "conv_layers.0.2.weight" -> "conv_layers.0.layer_norm.weight"
    # "conv_layers.0.2.bias" -> "conv_layers.0.layer_norm.bias"
    match = re.match(r'feature_extractor\.conv_layers\.0\.2\.(weight|bias)', key)
    if match:
        return f"feature_extractor.conv_layers.0.layer_norm.{match.group(1)}"
    # Convolutions
    # "conv_layers.X.0.weight" -> "conv_layers.X.conv.weight"
    # "conv_layers.X.0.bias" -> "conv_layers.X.conv.bias"
    match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.0\.(weight|bias)', key)
    if match:
        return f"feature_extractor.conv_layers.{match.group(1)}.conv.{match.group(2)}"
    # Layer norm when "extractor_mode" is "layer_norm".
    # "conv_layers.X.2.1.weight" -> "conv_layers.X.layer_norm.weight"
    # "conv_layers.X.2.1.bias" -> "conv_layers.X.layer_norm.bias"
    match = re.match(r'feature_extractor\.conv_layers\.(\d+)\.2\.1\.(weight|bias)', key)
    if match:
        return f"feature_extractor.conv_layers.{match.group(1)}.layer_norm.{match.group(2)}"
    match = re.match(r"post_extract_proj\.(weight|bias)", key)
    # Encoder - Feature projection
    if match:
        return f"encoder.feature_projection.projection.{match.group(1)}"
    match = re.match(r"layer_norm\.(weight|bias)", key)
    if match:
        return f"encoder.feature_projection.layer_norm.{match.group(1)}"
    # Encoder - Transformer - Convolutional positional embedding
    match = re.match(r"encoder\.pos_conv\.0\.(bias|weight_g|weight_v)", key)
    if match:
        return f"encoder.transformer.pos_conv_embed.conv.{match.group(1)}"
    match = re.match(r"encoder\.layer_norm\.(weight|bias)", key)
    if match:
        return f"encoder.transformer.layer_norm.{match.group(1)}"
    # Encoder - Transformer - Self attention layers
    match = re.match(r"encoder\.layers\.(\d+)\.self_attn\.((k_|v_|q_|out_)proj\.(weight|bias))", key)
    if match:
        return f"encoder.transformer.layers.{match.group(1)}.attention.{match.group(2)}"
    match = re.match(r"encoder\.layers\.(\d+)\.self_attn_layer_norm\.(weight|bias)", key)
    if match:
        return f"encoder.transformer.layers.{match.group(1)}.layer_norm.{match.group(2)}"
    match = re.match(r"encoder\.layers\.(\d+)\.fc1\.(weight|bias)", key)
    if match:
        return f"encoder.transformer.layers.{match.group(1)}.feed_forward.intermediate_dense.{match.group(2)}"
    match = re.match(r"encoder\.layers\.(\d+)\.fc2\.(weight|bias)", key)
    if match:
        return f"encoder.transformer.layers.{match.group(1)}.feed_forward.output_dense.{match.group(2)}"
    match = re.match(r"encoder\.layers\.(\d+)\.final_layer_norm\.(weight|bias)", key)
    if match:
        return f"encoder.transformer.layers.{match.group(1)}.final_layer_norm.{match.group(2)}"
    match = re.match(r"proj\.(weight|bias)", key)
    # Encoder - Readout layer
    if match:
        return f"encoder.readout.{match.group(1)}"
    raise ValueError(f'Unexpected key: {key_}')


def _convert_state_dict(state_dict):
    converted = {}
    for k, v in state_dict.items():
        k = _map_key(k)
        if k is not None:
            converted[k] = v
    return converted


def import_fairseq_model(original: Module, num_out: Optional[int] = None) -> Wav2Vec2Model:
    """Build Wav2Vec2Model from pretrained parameters published by `fairseq`_.

    Args:
        original (torch.nn.Module):
            An instance of fairseq's Wav2Vec2.0 model class.
            Either ``fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder`` or
            ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.
        num_out (int, optional):
            The number of output labels. Required only when the original model is
            an instance of ``fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model``.

    Returns:
        Wav2Vec2Model: Imported model.

    Example - Loading pretrain-only model
        >>> # Load model using fairseq
        >>> model_file = 'wav2vec_small.pt'
        >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
        >>> original = model[0]
        >>> imported = import_fairseq_model(original, num_out=28)
        >>>
        >>> # Perform feature extraction
        >>> waveform, _ = torchaudio.load('audio.wav')
        >>> features, _ = imported.extract_features(waveform)
        >>>
        >>> # Compare result with the original model from fairseq
        >>> reference = original.feature_extractor(waveform).transpose(1, 2)
        >>> torch.testing.assert_allclose(features, reference)

    Example - Fine-tuned model
        >>> # Load model using fairseq
        >>> model_file = 'wav2vec_small_960h.pt'
        >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file])
        >>> original = model[0]
        >>> imported = import_fairseq_model(original.w2v_encoder)
        >>>
        >>> # Perform encoding
        >>> waveform, _ = torchaudio.load('audio.wav')
        >>> emission, _ = imported(waveform)
        >>>
        >>> # Compare result with the original model from fairseq
        >>> mask = torch.zeros_like(waveform)
        >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
        >>> torch.testing.assert_allclose(emission, reference)

    .. _fairseq: https://github.com/pytorch/fairseq
    """
    class_ = original.__class__.__name__
    if class_ == 'Wav2Vec2Model':
        if num_out is None:
            raise ValueError(
                'When importing a pretrained model without readout layer, '
                '`num_out` argument must be given.')
        return _import_pretrained(original, num_out)
    if class_ == 'Wav2VecEncoder':
        return _import_finetuned(original)
    raise ValueError(
        f'Expected an instance of `Wav2Vec2Model` or `Wav2VecEncoder`. Found: {class_}')


def _import_finetuned(original: Module) -> Wav2Vec2Model:
    config = _parse_config(original.w2v_model, original.proj.out_features)
    model = _get_model(**config)
    model.load_state_dict(_convert_state_dict(original.state_dict()))
    return model


def _import_pretrained(original: Module, num_out: int) -> Wav2Vec2Model:
    config = _parse_config(original, num_out)
    model = _get_model(**config)
    model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
    return model
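As the integration test above exercises, the imported model's state dict can be reloaded through the factory functions (`wav2vec2_base`, `wav2vec2_large`, `wav2vec2_large_lv60k`), so fairseq is only needed once for the conversion. A hedged sketch of that round trip; the checkpoint name, the saved filename, and `num_out=28` are illustrative assumptions:

```
import torch
import fairseq

from torchaudio.models.wav2vec2 import wav2vec2_base
from torchaudio.models.wav2vec2.utils import import_fairseq_model

# Convert the fine-tuned fairseq checkpoint once (fairseq is required only here).
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(['wav2vec_small_960h.pt'])
imported = import_fairseq_model(model[0].w2v_encoder).eval()
torch.save(imported.state_dict(), 'wav2vec_small_960h_torchaudio.pt')

# Later, recreate the model without fairseq.
# num_out must match the checkpoint's number of output labels (28 assumed here).
reloaded = wav2vec2_base(num_out=28)
reloaded.load_state_dict(torch.load('wav2vec_small_960h_torchaudio.pt'))
reloaded.eval()
```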