Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b1838cfc
Commit
b1838cfc
authored
Oct 07, 2021
by
moto
Browse files
Add customization support to wav2vec2 labels (#1834)
parent
01764dee
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
15 deletions
+34
-15
docs/source/models.rst
docs/source/models.rst
+1
-1
test/integration_tests/wav2vec2_model_test.py
test/integration_tests/wav2vec2_model_test.py
+1
-1
torchaudio/models/wav2vec2/pretrained.py
torchaudio/models/wav2vec2/pretrained.py
+32
-13
No files found.
docs/source/models.rst
View file @
b1838cfc
...
@@ -118,7 +118,7 @@ Pre-trained Models
...
@@ -118,7 +118,7 @@ Pre-trained Models
.. automethod:: get_model
.. automethod:: get_model
.. auto
property::
labels
.. auto
method:: get_
labels
WAV2VEC2_BASE
WAV2VEC2_BASE
...
...
test/integration_tests/wav2vec2_model_test.py
View file @
b1838cfc
...
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
...
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
model
=
bundle
.
get_model
().
eval
()
model
=
bundle
.
get_model
().
eval
()
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech_16000_en
)
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech_16000_en
)
emission
,
_
=
model
(
waveform
)
emission
,
_
=
model
(
waveform
)
decoder
=
ctc_decoder
(
bundle
.
labels
)
decoder
=
ctc_decoder
(
bundle
.
get_
labels
()
)
result
=
decoder
(
emission
[
0
])
result
=
decoder
(
emission
[
0
])
assert
result
==
expected
assert
result
==
expected
torchaudio/models/wav2vec2/pretrained.py
View file @
b1838cfc
...
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
...
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
Downloading:
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>> # Check the corresponding labels of the output.
>>> # Check the corresponding labels of the output.
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.labels
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.
get_
labels
()
>>> print(labels)
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> # Infer the label probability distribution
>>> # Infer the label probability distribution
...
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
...
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
model
.
load_state_dict
(
state_dict
)
model
.
load_state_dict
(
state_dict
)
return
model
return
model
@
property
def
get_labels
(
def
labels
(
self
)
->
Optional
[
Tuple
[
str
]]:
self
,
"""The optional output class labels (only applicable to ASR bundles)
*
,
bos
:
str
=
'<s>'
,
pad
:
str
=
'<pad>'
,
eos
:
str
=
'</s>'
,
unk
:
str
=
'<unk>'
,
)
->
Tuple
[
str
]:
"""The output class labels (only applicable to fine-tuned bundles)
The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
Args:
bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
pad (str, optional): Padding token. (default: ``'<pad>'``)
eos (str, optional): End of sentence token. (default: ``'</s>'``)
unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
Returns:
Returns:
Tuple of strings or None:
Tuple of strings:
For fine-tuned ASR models, returns the tuple of strings representing
For models fine-tuned on ASR, returns the tuple of strings representing
the output class labels. For non-ASR models, the value is ``None``.
the output class labels.
"""
return
self
.
_labels
Example
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
"""
# noqa: E501
if
self
.
_labels
is
None
:
raise
ValueError
(
'Pre-trained models do not have labels.'
)
return
(
bos
,
pad
,
eos
,
unk
,
*
self
.
_labels
)
def
_get_labels
():
def
_get_labels
():
return
(
return
(
'<s>'
,
'<pad>'
,
'</s>'
,
'<unk>'
,
'|'
,
'|'
,
'E'
,
'E'
,
'T'
,
'T'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment