Unverified Commit 47ca3aa9 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor wav2vec2 pipeline util (#1925)

parent 19d8f1c2
from ._wav2vec2 import ( from ._wav2vec2.impl import (
Wav2Vec2Bundle, Wav2Vec2Bundle,
Wav2Vec2ASRBundle, Wav2Vec2ASRBundle,
WAV2VEC2_BASE, WAV2VEC2_BASE,
......
...@@ -5,6 +5,8 @@ import torch ...@@ -5,6 +5,8 @@ import torch
from torch.hub import load_state_dict_from_url from torch.hub import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model from torchaudio.models import wav2vec2_model, Wav2Vec2Model
from . import utils
__all__ = [] __all__ = []
...@@ -150,39 +152,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -150,39 +152,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
return (blank, *self._labels) return (blank, *self._labels)
def _get_labels():
return (
'|',
'E',
'T',
'A',
'O',
'N',
'I',
'H',
'S',
'R',
'D',
'L',
'U',
'M',
'W',
'C',
'F',
'G',
'Y',
'P',
'B',
'V',
'K',
"'",
'X',
'J',
'Q',
'Z',
)
WAV2VEC2_BASE = Wav2Vec2Bundle( WAV2VEC2_BASE = Wav2Vec2Bundle(
_path='wav2vec2_fairseq_base_ls960.pth', _path='wav2vec2_fairseq_base_ls960.pth',
_params={ _params={
...@@ -255,7 +224,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle( ...@@ -255,7 +224,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
'encoder_layer_drop': 0.05, 'encoder_layer_drop': 0.05,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module
...@@ -301,7 +270,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle( ...@@ -301,7 +270,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
'encoder_layer_drop': 0.05, 'encoder_layer_drop': 0.05,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
...@@ -347,7 +316,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle( ...@@ -347,7 +316,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.05, "encoder_layer_drop": 0.05,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module
...@@ -436,7 +405,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle( ...@@ -436,7 +405,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.2, "encoder_layer_drop": 0.2,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -482,7 +451,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle( ...@@ -482,7 +451,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.2, "encoder_layer_drop": 0.2,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -528,7 +497,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle( ...@@ -528,7 +497,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.2, "encoder_layer_drop": 0.2,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -617,7 +586,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle( ...@@ -617,7 +586,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.0, "encoder_layer_drop": 0.0,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -663,7 +632,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle( ...@@ -663,7 +632,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.0, "encoder_layer_drop": 0.0,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -709,7 +678,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle( ...@@ -709,7 +678,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
"encoder_layer_drop": 0.0, "encoder_layer_drop": 0.0,
"aux_num_out": 29, "aux_num_out": 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -935,7 +904,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle( ...@@ -935,7 +904,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'encoder_layer_drop': 0.1, 'encoder_layer_drop': 0.1,
'aux_num_out': 29, 'aux_num_out': 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration. HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration.
...@@ -982,7 +951,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle( ...@@ -982,7 +951,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'encoder_layer_drop': 0.1, 'encoder_layer_drop': 0.1,
'aux_num_out': 29, 'aux_num_out': 29,
}, },
_labels=_get_labels(), _labels=utils._get_en_labels(),
_sample_rate=16000, _sample_rate=16000,
) )
HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration. HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration.
......
def _get_en_labels():
return (
'|',
'E',
'T',
'A',
'O',
'N',
'I',
'H',
'S',
'R',
'D',
'L',
'U',
'M',
'W',
'C',
'F',
'G',
'Y',
'P',
'B',
'V',
'K',
"'",
'X',
'J',
'Q',
'Z',
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment