Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
5600bd25
"tests/vscode:/vscode.git/clone" did not exist on "fe46dac2c2ea1a988929fba05e9d3d3c9b11dfd7"
Unverified
Commit
5600bd25
authored
Oct 15, 2021
by
moto
Committed by
GitHub
Oct 15, 2021
Browse files
Add sample rate to Wav2Vec2 bundle (#1878)
parent
6c074666
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
7 deletions
+46
-7
docs/source/pipelines.rst
docs/source/pipelines.rst
+2
-0
torchaudio/pipelines/_wav2vec2.py
torchaudio/pipelines/_wav2vec2.py
+44
-7
No files found.
docs/source/pipelines.rst
View file @
5600bd25
...
@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning
...
@@ -9,6 +9,7 @@ wav2vec 2.0 / HuBERT - Representation Learning
----------------------------------------------
----------------------------------------------
.. autoclass:: Wav2Vec2Bundle
.. autoclass:: Wav2Vec2Bundle
:members: sample_rate
.. automethod:: get_model
.. automethod:: get_model
...
@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
...
@@ -73,6 +74,7 @@ wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
-------------------------------------
.. autoclass:: Wav2Vec2ASRBundle
.. autoclass:: Wav2Vec2ASRBundle
:members: sample_rate
.. automethod:: get_model
.. automethod:: get_model
...
...
torchaudio/pipelines/_wav2vec2.py
View file @
5600bd25
...
@@ -27,16 +27,30 @@ class Wav2Vec2Bundle:
...
@@ -27,16 +27,30 @@ class Wav2Vec2Bundle:
Example - Feature Extraction
Example - Feature Extraction
>>> import torchaudio
>>> import torchaudio
>>>
>>>
>>> bundle = torchaudio.pipelines.HUBERT_BASE
>>>
>>> # Build the model and load pretrained weight.
>>> # Build the model and load pretrained weight.
>>> model =
torchaudio.models.HUBERT_BASE
.get_model()
>>> model =
bundle
.get_model()
Downloading:
Downloading:
100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
100%|███████████████████████████████| 360M/360M [00:06<00:00, 60.6MB/s]
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Extract acoustic features
>>> # Extract acoustic features
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> features, _ = model.extract_features(waveform)
>>> features, _ = model.extract_features(waveform)
"""
# noqa: E501
"""
# noqa: E501
_path
:
str
_path
:
str
_params
:
Dict
[
str
,
Any
]
_params
:
Dict
[
str
,
Any
]
_sample_rate
:
float
@
property
def
sample_rate
(
self
)
->
float
:
"""Sample rate of the audio that the model is trained on.
:type: float
"""
return
self
.
_sample_rate
def
get_model
(
self
,
*
,
dl_kwargs
=
None
)
->
Wav2Vec2Model
:
def
get_model
(
self
,
*
,
dl_kwargs
=
None
)
->
Wav2Vec2Model
:
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
"""get_model(self, *, dl_kwargs=None) -> torchaudio.models.Wav2Vec2Model
...
@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
...
@@ -77,17 +91,24 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
Example - ASR
Example - ASR
>>> import torchaudio
>>> import torchaudio
>>>
>>>
>>> bundle = torchaudio.pipelines.HUBERT_ASR_LARGE
>>>
>>> # Build the model and load pretrained weight.
>>> # Build the model and load pretrained weight.
>>> model =
torchaudio.models.HUBERT_ASR_LARGE
.get_model()
>>> model =
bundle
.get_model()
Downloading:
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>>
>>> # Check the corresponding labels of the output.
>>> # Check the corresponding labels of the output.
>>> labels =
torchaudio.models.HUBERT_ASR_LARGE
.get_labels()
>>> labels =
bundle
.get_labels()
>>> print(labels)
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>>
>>> # Resample audio to the expected sampling rate
>>> waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)
>>>
>>> # Infer the label probability distribution
>>> # Infer the label probability distribution
>>> waveform, sample_rate = torchaudio.load('my_speech.mp3')
>>> emissions, _ = model(waveform)
>>> emissions, _ = model(waveform)
>>>
>>> # Pass emission to decoder
>>> # Pass emission to decoder
>>> # `ctc_decode` is for illustration purpose only
>>> # `ctc_decode` is for illustration purpose only
>>> transcripts = ctc_decode(emissions, labels)
>>> transcripts = ctc_decode(emissions, labels)
...
@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
...
@@ -121,8 +142,6 @@ class Wav2Vec2ASRBundle(Wav2Vec2Bundle):
>>> import torchaudio
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
"""
# noqa: E501
"""
# noqa: E501
if
self
.
_labels
is
None
:
if
self
.
_labels
is
None
:
raise
ValueError
(
'Pre-trained models do not have labels.'
)
raise
ValueError
(
'Pre-trained models do not have labels.'
)
...
@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
...
@@ -190,6 +209,7 @@ WAV2VEC2_BASE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.05
,
'encoder_layer_drop'
:
0.05
,
"aux_num_out"
:
None
,
"aux_num_out"
:
None
,
},
},
_sample_rate
=
16000
,
)
)
WAV2VEC2_BASE
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
WAV2VEC2_BASE
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
...
@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
...
@@ -234,6 +254,7 @@ WAV2VEC2_ASR_BASE_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_BASE_10M
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_BASE_10M
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
...
@@ -279,6 +300,7 @@ WAV2VEC2_ASR_BASE_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_BASE_100H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_BASE_100H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
...
@@ -324,6 +346,7 @@ WAV2VEC2_ASR_BASE_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_BASE_960H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_BASE_960H
.
__doc__
=
"""Build "base" wav2vec2 model with an extra linear module
...
@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
...
@@ -367,6 +390,7 @@ WAV2VEC2_LARGE = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.2
,
"encoder_layer_drop"
:
0.2
,
"aux_num_out"
:
None
,
"aux_num_out"
:
None
,
},
},
_sample_rate
=
16000
,
)
)
WAV2VEC2_LARGE
.
__doc__
=
"""Build "large" wav2vec2 model.
WAV2VEC2_LARGE
.
__doc__
=
"""Build "large" wav2vec2 model.
...
@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
...
@@ -411,6 +435,7 @@ WAV2VEC2_ASR_LARGE_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_10M
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_10M
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
...
@@ -456,6 +481,7 @@ WAV2VEC2_ASR_LARGE_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_100H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_100H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
...
@@ -501,6 +527,7 @@ WAV2VEC2_ASR_LARGE_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_960H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_960H
.
__doc__
=
"""Build "large" wav2vec2 model with an extra linear module
...
@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
...
@@ -544,6 +571,7 @@ WAV2VEC2_LARGE_LV60K = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.0
,
"encoder_layer_drop"
:
0.0
,
"aux_num_out"
:
None
,
"aux_num_out"
:
None
,
},
},
_sample_rate
=
16000
,
)
)
WAV2VEC2_LARGE_LV60K
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model.
WAV2VEC2_LARGE_LV60K
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model.
...
@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
...
@@ -588,6 +616,7 @@ WAV2VEC2_ASR_LARGE_LV60K_10M = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_LV60K_10M
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_LV60K_10M
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
...
@@ -633,6 +662,7 @@ WAV2VEC2_ASR_LARGE_LV60K_100H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_LV60K_100H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_LV60K_100H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
...
@@ -678,6 +708,7 @@ WAV2VEC2_ASR_LARGE_LV60K_960H = Wav2Vec2ASRBundle(
"aux_num_out"
:
32
,
"aux_num_out"
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
WAV2VEC2_ASR_LARGE_LV60K_960H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
WAV2VEC2_ASR_LARGE_LV60K_960H
.
__doc__
=
"""Build "large-lv60k" wav2vec2 model with an extra linear module
...
@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
...
@@ -723,6 +754,7 @@ WAV2VEC2_XLSR53 = Wav2Vec2Bundle(
"encoder_layer_drop"
:
0.0
,
"encoder_layer_drop"
:
0.0
,
"aux_num_out"
:
None
,
"aux_num_out"
:
None
,
},
},
_sample_rate
=
16000
,
)
)
WAV2VEC2_XLSR53
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
WAV2VEC2_XLSR53
.
__doc__
=
"""wav2vec 2.0 model with "Base" configuration.
...
@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
...
@@ -769,6 +801,7 @@ HUBERT_BASE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.05
,
'encoder_layer_drop'
:
0.05
,
'aux_num_out'
:
None
,
'aux_num_out'
:
None
,
},
},
_sample_rate
=
16000
,
)
)
HUBERT_BASE
.
__doc__
=
"""HuBERT model with "Base" configuration.
HUBERT_BASE
.
__doc__
=
"""HuBERT model with "Base" configuration.
...
@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
...
@@ -812,6 +845,7 @@ HUBERT_LARGE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.0
,
'encoder_layer_drop'
:
0.0
,
'aux_num_out'
:
None
,
'aux_num_out'
:
None
,
},
},
_sample_rate
=
16000
,
)
)
HUBERT_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
HUBERT_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
...
@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
...
@@ -855,6 +889,7 @@ HUBERT_XLARGE = Wav2Vec2Bundle(
'encoder_layer_drop'
:
0.0
,
'encoder_layer_drop'
:
0.0
,
'aux_num_out'
:
None
,
'aux_num_out'
:
None
,
},
},
_sample_rate
=
16000
,
)
)
HUBERT_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
HUBERT_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
...
@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
...
@@ -899,6 +934,7 @@ HUBERT_ASR_LARGE = Wav2Vec2ASRBundle(
'aux_num_out'
:
32
,
'aux_num_out'
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
HUBERT_ASR_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
HUBERT_ASR_LARGE
.
__doc__
=
"""HuBERT model with "Large" configuration.
...
@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
...
@@ -945,6 +981,7 @@ HUBERT_ASR_XLARGE = Wav2Vec2ASRBundle(
'aux_num_out'
:
32
,
'aux_num_out'
:
32
,
},
},
_labels
=
_get_labels
(),
_labels
=
_get_labels
(),
_sample_rate
=
16000
,
)
)
HUBERT_ASR_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
HUBERT_ASR_XLARGE
.
__doc__
=
"""HuBERT model with "Extra Large" configuration.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment