Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
5600bd25
Unverified
Commit
5600bd25
authored
Oct 15, 2021
by
moto
Committed by
GitHub
Oct 15, 2021
Browse files
Add sample rate to Wav2Vec2 bundle (#1878)
parent
6c074666
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
7 deletions
+46
-7
docs/source/pipelines.rst
docs/source/pipelines.rst
+2
-0
torchaudio/pipelines/_wav2vec2.py
torchaudio/pipelines/_wav2vec2.py
+44
-7
No files found.
docs/source/pipelines.rst
View file @
5600bd25
...
...
@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning
----------------------------------------------
.. autoclass:: Wav2Vec2Bundle
:members: sample_rate
.. automethod:: get_model
...
...
@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
.. autoclass:: Wav2Vec2ASRBundle
:members: sample_rate
.. automethod:: get_model
...
...
torchaudio/pipelines/_wav2vec2.py
View file @
5600bd25
...
...
@@ -27,16 +27,30 @@ class Wav2Vec2Bundle:
Example - Feature Extraction
>>> import torchaudio
>>>
>>> bundle = torchaudio.pipelines.HUBERT_BASE
>>>
>>> # Build the model and load pretrained weight.
>>> model =
torchaudio.models.HUBERT_BASE
.get_model()
>>> model =
bundle
.get_model()
Downloading:
100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Extract acoustic features
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> features, _ = model.extract_features(waveform)
"""
# noqa: E501
_path
:
str
_params
:
Dict
[
str
,
Any
]
_sample_rate
:
float
@
property
def
sample_rate
(
self
)
->
float
:
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return
self
.
_sample_rate
def
get_model
(
self
,
*
,
dl_kwargs
=
None
)
->
Wav2Vec2Model
:
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
...
...
@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
Example - ASR
>>> import torchaudio
>>>
>>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
>>>
>>> # Build the model and load pretrained weight.
>>> model =
torchaudio.models.HUBERT_ASR_LARGE
.get_model()
>>> model =
bundle
.get_model()
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>>
>>> # Check the corresponding labels of the output.
>>> labels =
torchaudio.models.HUBERT_ASR_LARGE
.get_labels()
>>> labels =
bundle
.get_labels()
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Infer the label probability distribution
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> emissions, _ = model(waveform)
>>>
>>> # Pass emission to decoder
>>> # `ctc_decode` is for illustration purpose only
>>> transcripts = ctc_decode(emissions, labels)
...
...
@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
"""
# noqa: E501
if
self
.
_labels
is
None
:
raise
ValueError
(
'Pre-trained models do not have labels.'
)
...
...
@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.05
,
"aux_num_out"
:
None
,
},
_sample_rate
=
16000
,
)
WAV2VEC2_BASE
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
...
...
@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_BASE_10M
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
...
@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_BASE_100H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
...
@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_BASE_960H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
...
@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.2
,
"aux_num_out"
:
None
,
},
_sample_rate
=
16000
,
)
WAV2VEC2_LARGE
.
__doc__
=
"""Build "large" wav2vec2 model.
...
...
@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_10M
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
...
@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_100H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
...
@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_960H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
...
@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.0
,
"aux_num_out"
:
None
,
},
_sample_rate
=
16000
,
)
WAV2VEC2_LARGE_LV60K
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model.
...
...
@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_LV60K_10M
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
...
@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_LV60K_100H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
...
@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
WAV2VEC2_ASR_LARGE_LV60K_960H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
...
@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.0
,
"aux_num_out"
:
None
,
},
_sample_rate
=
16000
,
)
WAV2VEC2_XLSR53
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
...
...
@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.05
,
'aux_num_out'
:
None
,
},
_sample_rate
=
16000
,
)
HUBERT_BASE
.
__doc__
=
"""HuBERT model with "Base" configuration.
...
...
@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.0
,
'aux_num_out'
:
None
,
},
_sample_rate
=
16000
,
)
HUBERT_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
...
...
@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.0
,
'aux_num_out'
:
None
,
},
_sample_rate
=
16000
,
)
HUBERT_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
...
...
@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'aux_num_out'
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
HUBERT_ASR_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
...
...
@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'aux_num_out'
:
32
,
},
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
HUBERT_ASR_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment