Unverified Commit fd7fcf93 authored by moto, committed by GitHub

Add customization support to wav2vec2 labels (#1834)

parent 21a0d29e
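
This commit replaces the read-only `labels` property on the wav2vec2/HuBERT bundles with a `get_labels()` method whose four leading special tokens can be overridden. A minimal usage sketch, not part of the commit, assuming a torchaudio build with this patch applied (the `HUBERT_ASR_LARGE` bundle is taken from the docstring example further down):

```python
import torchaudio

bundle = torchaudio.models.HUBERT_ASR_LARGE

# Before this change: labels were exposed as a read-only property.
#   labels = bundle.labels
# After this change: labels come from a method, and the first four special
# tokens are keyword-only arguments defaulting to '<s>', '<pad>', '</s>', '<unk>'.
labels = bundle.get_labels()
custom = bundle.get_labels(bos='<bos>', pad='_')  # illustrative overrides

print(labels[:5])   # ('<s>', '<pad>', '</s>', '<unk>', '|')
print(custom[:5])   # ('<bos>', '_', '</s>', '<unk>', '|')
```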
@@ -118,7 +118,7 @@ Pre-trained Models
.. automethod:: get_model
-.. autoproperty:: labels
+.. automethod:: get_labels
WAV2VEC2_BASE
......
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
model = bundle.get_model().eval()
waveform, sample_rate = torchaudio.load(sample_speech_16000_en)
emission, _ = model(waveform)
-decoder = ctc_decoder(bundle.labels)
+decoder = ctc_decoder(bundle.get_labels())
result = decoder(emission[0])
assert result == expected
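
The `ctc_decoder` helper used by this test is defined elsewhere and is not shown in the diff. As a rough, hypothetical stand-in illustrating what such a helper does with the label tuple, here is a greedy CTC decoding sketch; treating index 0 as the CTC blank and `'|'` as the word delimiter are assumptions, not taken from the commit:

```python
import torch


def ctc_decoder_sketch(labels):
    """Return a callable mirroring the test usage above: decoder(emission) -> str."""
    blank = 0  # assumption: the first label acts as the CTC blank

    def decode(emission: torch.Tensor) -> str:
        # emission: (num_frames, num_labels), e.g. ``emission[0]`` in the test.
        indices = torch.argmax(emission, dim=-1)     # best label per frame
        indices = torch.unique_consecutive(indices)  # collapse repeated frames
        chars = [labels[i] for i in indices.tolist() if i != blank]
        return ''.join(chars).replace('|', ' ').strip()  # '|' marks word boundaries

    return decode
```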
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
Downloading:
100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
>>> # Check the corresponding labels of the output.
->>> labels = torchaudio.models.HUBERT_ASR_LARGE.labels
+>>> labels = torchaudio.models.HUBERT_ASR_LARGE.get_labels()
>>> print(labels)
('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
>>> # Infer the label probability distribution
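
Continuing the docstring example ("Infer the label probability distribution"): each column of the emission corresponds to one entry of the label tuple. A short sketch under that assumption; the audio path is a placeholder, and the `(emission, lengths)` return shape follows the test above:

```python
import torch
import torchaudio

bundle = torchaudio.models.HUBERT_ASR_LARGE
model = bundle.get_model().eval()
labels = bundle.get_labels()

waveform, sample_rate = torchaudio.load('speech.wav')  # placeholder 16 kHz mono file
with torch.no_grad():
    emission, _ = model(waveform)

# emission: (batch, num_frames, num_labels); column i scores labels[i]
assert emission.shape[-1] == len(labels)
```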
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
        model.load_state_dict(state_dict)
        return model

-    @property
-    def labels(self) -> Optional[Tuple[str]]:
-        """The optional output class labels (only applicable to ASR bundles)
+    def get_labels(
+            self,
+            *,
+            bos: str = '<s>',
+            pad: str = '<pad>',
+            eos: str = '</s>',
+            unk: str = '<unk>',
+    ) -> Tuple[str]:
+        """The output class labels (only applicable to fine-tuned bundles)
+
+        The first four tokens are the BOS, padding, EOS, and UNK tokens, and they can be customized.
+
+        Args:
+            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
+            pad (str, optional): Padding token. (default: ``'<pad>'``)
+            eos (str, optional): End of sentence token. (default: ``'</s>'``)
+            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)

         Returns:
-            Tuple of strings or None:
-            For fine-tuned ASR models, returns the tuple of strings representing
-            the output class labels. For non-ASR models, the value is ``None``.
-        """
-        return self._labels
+            Tuple of strings:
+            For models fine-tuned on ASR, returns the tuple of strings representing
+            the output class labels.
+
+        Example
+            >>> import torchaudio
+            >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
+            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            >>> torchaudio.models.HUBERT_LARGE.get_labels()
+            ValueError: Pre-trained models do not have labels.
+        """  # noqa: E501
+        if self._labels is None:
+            raise ValueError('Pre-trained models do not have labels.')
+        return (bos, pad, eos, unk, *self._labels)


def _get_labels():
    return (
        '<s>',
        '<pad>',
        '</s>',
        '<unk>',
        '|',
        'E',
        'T',
......
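
The keyword-only arguments let callers rename the special-token strings (for example, to match the symbols an external CTC decoder or lexicon expects) without changing the index layout of the tuple, so the alignment with the model's output columns is preserved. An illustrative sketch, not from the commit; the replacement token strings are arbitrary:

```python
import torchaudio

bundle = torchaudio.models.HUBERT_ASR_LARGE

# Rename only the display strings of the special tokens; positions 0-3 are
# still BOS, PAD, EOS, UNK, so emission column i still scores labels[i].
labels = bundle.get_labels(bos='<s>', pad='-', eos='</s>', unk='<unk>')
token_to_index = {token: i for i, token in enumerate(labels)}
```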