Unverified Commit 5600bd25 authored by moto's avatar moto Committed by GitHub
Browse files

Add sample rate to Wav2Vec2 bundle (#1878)

parent 6c074666
......@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning
----------------------------------------------
.. autoclass:: Wav2Vec2Bundle
:members: sample_rate
.. automethod:: get_model
......@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
.. autoclass:: Wav2Vec2ASRBundle
:members: sample_rate
.. automethod:: get_model
......
......@@ -27,16 +27,30 @@ class Wav2Vec2Bundle:
Example - Feature Extraction
>>> import torchaudio
>>>
>>> bundle = torchaudio.pipelines.HUBERT_BASE
>>>
>>> # Build the model and load pretrained weight.
>>> model = torchaudio.models.HUBERT_BASE.get_model()
>>> model = bundle.get_model()
Downloading:
100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Extract acoustic features
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> features, _ = model.extract_features(waveform)
""" # noqa: E501
_path: str
_params: Dict[str, Any]
_sample_rate: float
@property
def sample_rate(self) -> float:
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return self._sample_rate
def get_model(self, *, dl_kwargs=None) -> Wav2Vec2Model:
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
......@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
Example - ASR
>>> import torchaudio
>>>
>>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
>>>
>>> # Build the model and load pretrained weight.
>>> model = torchaudio.models.HUBERT_ASR_LARGE.get_model()
>>> model = bundle.get_model()
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>>
>>> # Check the corresponding labels of the output.
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.get_labels()
>>> labels = bundle.get_labels()
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Infer the label probability distribution
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> emissions, _ = model(waveform)
>>>
>>> # Pass emission to decoder
>>> # `ctc_decode` is for illustration purpose only
>>> transcripts = ctc_decode(emissions, labels)
......@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
""" # noqa: E501
if self._labels is None:
raise ValueError('Pre-trained models do not have labels.')
......@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.05,
"aux_num_out": None,
},
_sample_rate=16000,
)
WAV2VEC2_BASE.__doc__ = """wav2vec 2.0 model with "Base" configuration.
......@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_BASE_10M.__doc__ = """Build "base" wav2vec2 model with an extra linear module
......@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_BASE_100H.__doc__ = """Build "base" wav2vec2 model with an extra linear module
......@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_BASE_960H.__doc__ = """Build "base" wav2vec2 model with an extra linear module
......@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
"encoder_layer_drop": 0.2,
"aux_num_out": None,
},
_sample_rate=16000,
)
WAV2VEC2_LARGE.__doc__ = """Build "large" wav2vec2 model.
......@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_10M.__doc__ = """Build "large" wav2vec2 model with an extra linear module
......@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_100H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
......@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_960H.__doc__ = """Build "large" wav2vec2 model with an extra linear module
......@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
"encoder_layer_drop": 0.0,
"aux_num_out": None,
},
_sample_rate=16000,
)
WAV2VEC2_LARGE_LV60K.__doc__ = """Build "large-lv60k" wav2vec2 model.
......@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_LV60K_10M.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
......@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_LV60K_100H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
......@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
"aux_num_out": 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
WAV2VEC2_ASR_LARGE_LV60K_960H.__doc__ = """Build "large-lv60k" wav2vec2 model with an extra linear module
......@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
"encoder_layer_drop": 0.0,
"aux_num_out": None,
},
_sample_rate=16000,
)
WAV2VEC2_XLSR53.__doc__ = """wav2vec 2.0 model with "Base" configuration.
......@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.05,
'aux_num_out': None,
},
_sample_rate=16000,
)
HUBERT_BASE.__doc__ = """HuBERT model with "Base" configuration.
......@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.0,
'aux_num_out': None,
},
_sample_rate=16000,
)
HUBERT_LARGE.__doc__ = """HuBERT model with "Large" configuration.
......@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
'encoder_layer_drop': 0.0,
'aux_num_out': None,
},
_sample_rate=16000,
)
HUBERT_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration.
......@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'aux_num_out': 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
HUBERT_ASR_LARGE.__doc__ = """HuBERT model with "Large" configuration.
......@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'aux_num_out': 32,
},
_labels=_get_labels(),
_sample_rate=16000,
)
HUBERT_ASR_XLARGE.__doc__ = """HuBERT model with "Extra Large" configuration.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment