Unverified commit fd7fcf93 authored by moto, committed by GitHub

Add customization support to wav2vec2 labels (#1834)

parent 21a0d29e
@@ -118,7 +118,7 @@ Pre-trained Models
    .. automethod:: get_model
-   .. autoproperty:: labels
+   .. automethod:: get_labels

 WAV2VEC2_BASE
...
@@ -65,6 +65,6 @@ def test_finetune_asr_model(
     model = bundle.get_model().eval()
     waveform, sample_rate = torchaudio.load(sample_speech_16000_en)
     emission, _ = model(waveform)
-    decoder = ctc_decoder(bundle.labels)
+    decoder = ctc_decoder(bundle.get_labels())
     result = decoder(emission[0])
     assert result == expected
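The `ctc_decoder` used by this test is a helper defined elsewhere in the test file and is not part of this hunk. Purely as an illustration of what such a helper consumes, here is a minimal greedy stand-in (an assumption, not the test's actual decoder) built around the label tuple that `get_labels()` now returns; filtering out the four special tokens only approximates proper CTC blank handling.

    # Sketch only: a greedy stand-in for the test's ctc_decoder helper.
    # Bundle/model calls follow this diff; everything else is illustrative.
    import torch
    import torchaudio

    def greedy_ctc_decoder(labels):
        specials = set(labels[:4])  # ('<s>', '<pad>', '</s>', '<unk>')

        def decode(emission: torch.Tensor) -> str:
            indices = torch.argmax(emission, dim=-1)     # best class per frame
            indices = torch.unique_consecutive(indices)  # collapse repeats
            chars = [labels[i] for i in indices.tolist() if labels[i] not in specials]
            return ''.join(chars).replace('|', ' ').strip()  # '|' marks word boundaries

        return decode

    bundle = torchaudio.models.HUBERT_ASR_LARGE
    model = bundle.get_model().eval()
    waveform, sample_rate = torchaudio.load('speech.wav')  # hypothetical input file
    with torch.no_grad():
        emission, _ = model(waveform)
    decoder = greedy_ctc_decoder(bundle.get_labels())
    print(decoder(emission[0]))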
@@ -43,7 +43,7 @@ class Wav2Vec2PretrainedModelBundle:
         Downloading:
         100%|███████████████████████████████| 1.18G/1.18G [00:17<00:00, 73.8MB/s]
         >>> # Check the corresponding labels of the output.
-        >>> labels = torchaudio.models.HUBERT_ASR_LARGE.labels
+        >>> labels = torchaudio.models.HUBERT_ASR_LARGE.get_labels()
         >>> print(labels)
         ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
         >>> # Infer the label probability distribution
@@ -74,24 +74,43 @@ class Wav2Vec2PretrainedModelBundle:
         model.load_state_dict(state_dict)
         return model
 
-    @property
-    def labels(self) -> Optional[Tuple[str]]:
-        """The optional output class labels (only applicable to ASR bundles)
-
-        Returns:
-            Tuple of strings or None:
-            For fine-tuned ASR models, returns the tuple of strings representing
-            the output class labels. For non-ASR models, the value is ``None``.
-        """
-        return self._labels
+    def get_labels(
+            self,
+            *,
+            bos: str = '<s>',
+            pad: str = '<pad>',
+            eos: str = '</s>',
+            unk: str = '<unk>',
+    ) -> Tuple[str]:
+        """The output class labels (only applicable to fine-tuned bundles)
+
+        The first four tokens are BOS, padding, EOS and UNK tokens and they can be customized.
+
+        Args:
+            bos (str, optional): Beginning of sentence token. (default: ``'<s>'``)
+            pad (str, optional): Padding token. (default: ``'<pad>'``)
+            eos (str, optional): End of sentence token. (default: ``'</s>'``)
+            unk (str, optional): Token for unknown class. (default: ``'<unk>'``)
+
+        Returns:
+            Tuple of strings:
+            For models fine-tuned on ASR, returns the tuple of strings representing
+            the output class labels.
+
+        Example
+            >>> import torchaudio
+            >>> torchaudio.models.HUBERT_ASR_LARGE.get_labels()
+            ('<s>', '<pad>', '</s>', '<unk>', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
+            >>> torchaudio.models.HUBERT_LARGE.get_labels()
+            ValueError: Pre-trained models do not have labels.
+        """  # noqa: E501
+        if self._labels is None:
+            raise ValueError('Pre-trained models do not have labels.')
+        return (bos, pad, eos, unk, *self._labels)
 
     def _get_labels():
         return (
-            '<s>',
-            '<pad>',
-            '</s>',
-            '<unk>',
             '|',
             'E',
             'T',
...
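Taken together, `get_labels()` now raises for pre-trained (not fine-tuned) bundles and otherwise prepends the four special tokens, with whatever names the caller chooses, to the stored character labels. A minimal usage sketch under those assumptions (the replacement token names below are arbitrary):

    # Sketch assuming the bundle constants and signature shown in this diff.
    import torchaudio

    bundle = torchaudio.models.HUBERT_ASR_LARGE
    default = bundle.get_labels()
    # Rename the special tokens without touching the character labels themselves.
    custom = bundle.get_labels(bos='<bos>', pad='<blank>', eos='<eos>', unk='<oov>')

    assert custom[:4] == ('<bos>', '<blank>', '<eos>', '<oov>')
    assert custom[4:] == default[4:]

    # Bundles without fine-tuned labels raise instead of returning None.
    try:
        torchaudio.models.HUBERT_LARGE.get_labels()
    except ValueError as err:
        print(err)  # Pre-trained models do not have labels.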