"tests/vscode:/vscode.git/clone" did not exist on "fe46dac2c2ea1a988929fba05e9d3d3c9b11dfd7"
Unverified Commit 5600bd25 authored by moto's avatar moto Committed by GitHub
Browse files

Add sample rate to Wav2Vec2 bundle (#1878)

parent 6c074666
...@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning ...@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning
---------------------------------------------- ----------------------------------------------
.. autoclass:: Wav2Vec2Bundle .. autoclass:: Wav2Vec2Bundle
:members: sample_rate
.. automethod:: get_model .. automethod:: get_model
...@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR ...@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
------------------------------------- -------------------------------------
.. autoclass:: Wav2Vec2ASRBundle .. autoclass:: Wav2Vec2ASRBundle
:members: sample_rate
.. automethod:: get_model .. automethod:: get_model
......
...@@ -27,16 +27,30 @@ class Wav2Vec2Bundle: ...@@ -27,16 +27,30 @@ class Wav2Vec2Bundle:
Example - Feature Extraction Example - Feature Extraction
>>> import torchaudio >>> import torchaudio
>>> >>>
>>> bundle = torchaudio.pipelines.HUBERT_BASE
>>>
>>> # Build the model and load pretrained weight. >>> # Build the model and load pretrained weight.
>>> model = torchaudio.models.HUBERT_BASE.get_model() >>> model = bundle.get_model()
Downloading: Downloading:
100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s] 100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Extract acoustic features >>> # Extract acoustic features
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> features, _ = model.extract_features(waveform) >>> features, _ = model.extract_features(waveform)
""" # noqa: E501 """ # noqa: E501
_path: str _path: str
_params: Dict[str, Any] _params: Dict[str, Any]
_sample_rate: float
@property
def sample_rate(self) -> float:
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return self._sample_rate
def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model: def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model """get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
...@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
Example - ASR Example - ASR
>>> import torchaudio >>> import torchaudio
>>> >>>
>>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
>>>
>>> # Build the model and load pretrained weight. >>> # Build the model and load pretrained weight.
>>> model = torchaudio.models.HUBERT_ASR_LARGE.get_model() >>> model = bundle.get_model()
Downloading: Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s] 100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>>
>>> # Check the corresponding labels of the output. >>> # Check the corresponding labels of the output.
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.get_labels() >>> labels = bundle.get_labels()
>>> print(labels) >>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Infer the label probability distribution >>> # Infer the label probability distribution
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> emissions, _ = model(waveform) >>> emissions, _ = model(waveform)
>>>
>>> # Pass emission to decoder >>> # Pass emission to decoder
>>> # `ctc_decode` is for illustration purpose only >>> # `ctc_decode` is for illustration purpose only
>>> transcripts = ctc_decode(emissions, labels) >>> transcripts = ctc_decode(emissions, labels)
...@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle): ...@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> import torchaudio >>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels() >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z') ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
""" # noqa: E501 """ # noqa: E501
if self._labels is None: if self._labels is None:
raise ValueError('Pre-trained models do not have labels.') raise ValueError('Pre-trained models do not have labels.')
...@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle( ...@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.05, 'encoder_layer_drop': 0.05,
"aux_num_out": None, "aux_num_out": None,
}, },
_sample_rate=16000,
) )
WAV2VEC2_BASE.__doc__ = """wav2vec 2.0 model with "Base" configuration. WAV2VEC2_BASE.__doc__ = """wav2vec 2.0 model with "Base" configuration.
...@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle( ...@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module
...@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle( ...@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_BASE_100H.__doc__ = """Build "base" wav2vec2 model with an extra linear module WAV2VEC2_ASR_BASE_100H.__doc__ = """Build "base" wav2vec2 model with an extra linear module
...@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle( ...@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module
...@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle( ...@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
"encoder_layer_drop": 0.2, "encoder_layer_drop": 0.2,
"aux_num_out": None, "aux_num_out": None,
}, },
_sample_rate=16000,
) )
WAV2VEC2_LARGE.__doc__ = """Build "large" wav2vec2 model. WAV2VEC2_LARGE.__doc__ = """Build "large" wav2vec2 model.
...@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle( ...@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle( ...@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle( ...@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
...@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle( ...@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
"encoder_layer_drop": 0.0, "encoder_layer_drop": 0.0,
"aux_num_out": None, "aux_num_out": None,
}, },
_sample_rate=16000,
) )
WAV2VEC2_LARGE_LV60K.__doc__ = """Build "large-lv60k" wav2vec2 model. WAV2VEC2_LARGE_LV60K.__doc__ = """Build "large-lv60k" wav2vec2 model.
...@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle( ...@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle( ...@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle( ...@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32, "aux_num_out": 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
...@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle( ...@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
"encoder_layer_drop": 0.0, "encoder_layer_drop": 0.0,
"aux_num_out": None, "aux_num_out": None,
}, },
_sample_rate=16000,
) )
WAV2VEC2_XLSR53.__doc__ = """wav2vec 2.0 model with "Base" configuration. WAV2VEC2_XLSR53.__doc__ = """wav2vec 2.0 model with "Base" configuration.
...@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle( ...@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.05, 'encoder_layer_drop': 0.05,
'aux_num_out': None, 'aux_num_out': None,
}, },
_sample_rate=16000,
) )
HUBERT_BASE.__doc__ = """HuBERT model with "Base" configuration. HUBERT_BASE.__doc__ = """HuBERT model with "Base" configuration.
...@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle( ...@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.0, 'encoder_layer_drop': 0.0,
'aux_num_out': None, 'aux_num_out': None,
}, },
_sample_rate=16000,
) )
HUBERT_LARGE.__doc__ = """HuBERT model with "Large" configuration. HUBERT_LARGE.__doc__ = """HuBERT model with "Large" configuration.
...@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle( ...@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.0, 'encoder_layer_drop': 0.0,
'aux_num_out': None, 'aux_num_out': None,
}, },
_sample_rate=16000,
) )
HUBERT_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration. HUBERT_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration.
...@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle( ...@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'aux_num_out': 32, 'aux_num_out': 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration. HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration.
...@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle( ...@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'aux_num_out': 32, 'aux_num_out': 32,
}, },
_labels=_get_labels(), _labels=_get_labels(),
_sample_rate=16000,
) )
HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration. HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment