Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
fd7fcf93
Unverified
Commit
fd7fcf93
authored
Oct 07, 2021
by
moto
Committed by
GitHub
Oct 07, 2021
Browse files
Add customization support to wav2vec2 labels (#1834)
parent
21a0d29e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
15 deletions
+34
-15
docs/source/models.rst
docs/source/models.rst
+1
-1
test/integration_tests/wav2vec2_model_test.py
test/integration_tests/wav2vec2_model_test.py
+1
-1
torchaudio/models/wav2vec2/pretrained.py
torchaudio/models/wav2vec2/pretrained.py
+32
-13
No files found.
docs/source/models.rst
View file @
fd7fcf93
...
@@ -118,7 +118,7 @@ Pre-trained Models
...
@@ -118,7 +118,7 @@ Pre-trained Models
.. automethod:: get_model
.. automethod:: get_model
.. auto
property::
labels
.. auto
method:: get_
labels
WAV2VEC2_BASE
WAV2VEC2_BASE
...
...
test/integration_tests/wav2vec2_model_test.py
View file @
fd7fcf93
...
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
...
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
model
=
bundle
.
get_model
().
eval
()
model
=
bundle
.
get_model
().
eval
()
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech_16000_en
)
waveform
,
sample_rate
=
torchaudio
.
load
(
sample_speech_16000_en
)
emission
,
_
=
model
(
waveform
)
emission
,
_
=
model
(
waveform
)
decoder
=
ctc_decoder
(
bundle
.
labels
)
decoder
=
ctc_decoder
(
bundle
.
get_
labels
()
)
result
=
decoder
(
emission
[
0
])
result
=
decoder
(
emission
[
0
])
assert
result
==
expected
assert
result
==
expected
torchaudio/models/wav2vec2/pretrained.py
View file @
fd7fcf93
...
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
...
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
Downloading:
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>> # Check the corresponding labels of the output.
>>> # Check the corresponding labels of the output.
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.labels
>>> labels = torchaudio.models.HUBERT_ASR_LARGE.
get_
labels
()
>>> print(labels)
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> # Infer the label probability distribution
>>> # Infer the label probability distribution
...
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
...
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
model
.
load_state_dict
(
state_dict
)
model
.
load_state_dict
(
state_dict
)
return
model
return
model
@
property
def
get_labels
(
def
labels
(
self
)
->
Optional
[
Tuple
[
str
]]:
self
,
"""The optional output class labels (only applicable to ASR bundles)
*
,
bos
:
str
=
'<s>'
,
pad
:
str
=
'<pad>'
,
eos
:
str
=
'</s>'
,
unk
:
str
=
'<unk>'
,
)
->
Tuple
[
str
]:
"""The output class labels (only applicable to fine-tuned bundles)
The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
Args:
bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
pad (str, optional): Padding token. (default: ``'<pad>'``)
eos (str, optional): End of sentence token. (default: ``'</s>'``)
unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
Returns:
Returns:
Tuple of strings or None:
Tuple of strings:
For fine-tuned ASR models, returns the tuple of strings representing
For models fine-tuned on ASR, returns the tuple of strings representing
the output class labels. For non-ASR models, the value is ``None``.
the output class labels.
"""
return
self
.
_labels
Example
>>> import torchaudio
>>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> torchaudio.models.HUBERT_LARGE.get_labels()
ValueError: Pre-trained models do not have labels.
"""
# noqa: E501
if
self
.
_labels
is
None
:
raise
ValueError
(
'Pre-trained models do not have labels.'
)
return
(
bos
,
pad
,
eos
,
unk
,
*
self
.
_labels
)
def
_get_labels
():
def
_get_labels
():
return
(
return
(
'<s>'
,
'<pad>'
,
'</s>'
,
'<unk>'
,
'|'
,
'|'
,
'E'
,
'E'
,
'T'
,
'T'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment