Commit dacd3fd4 authored by moto

Rename factory functions `wav2vec2_asr_ARCH` to `wav2vec2_ft_ARCH` (#1804)

* Rename factory functions `wav2vec2_asr_ARCH` to `wav2vec2_ft_ARCH`

In #1783, we split the wav2vec2 factory functions into ones for pretraining models
and ones for fine-tuning models (a pretraining model plus an extra Linear module).

I picked the naming scheme `wav2vec2_asr_ARCH` for the fine-tuning factory functions,
but it did not feel right, because the architecture code is more generic.
Even though the resulting model architecture was used for ASR fine-tuning in the paper,
it does not have to be used for ASR.
This became more evident as we added support for pre-trained parameters, such as in #1799.
Which task and which dataset a model was trained on matter for the weight files;
for the factory functions, the ASR task is not relevant.

Therefore, this PR renames the functions, replacing `_asr_` with `_ft_` (for fine-tuning).

Note: Since the new functions have not been released yet, this PR itself is not BC-breaking.
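For reference, the renames are a drop-in substitution; a minimal sketch of the mapping and the call pattern, taken from the updated tests below (`num_out=32` comes from the smoke tests; nothing else about the API changes in this PR):

```python
# Old name (unreleased)     -> New name
# wav2vec2_asr_base         -> wav2vec2_ft_base
# wav2vec2_asr_large        -> wav2vec2_ft_large
# wav2vec2_asr_large_lv60k  -> wav2vec2_ft_large_lv60k
# hubert_asr_large          -> hubert_ft_large
# hubert_asr_xlarge         -> hubert_ft_xlarge
from torchaudio.models import wav2vec2_ft_base

# num_out sizes the extra Linear module appended to the pretraining model.
model = wav2vec2_ft_base(num_out=32)
```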
parent a4974c4c
Docs (Sphinx `autofunction` directives):

```diff
@@ -90,20 +90,20 @@ wav2vec2_large_lv60k
 .. autofunction:: wav2vec2_large_lv60k

-wav2vec2_asr_base
+wav2vec2_ft_base
 ^^^^^^^^^^^^^^^^^

-.. autofunction:: wav2vec2_asr_base
+.. autofunction:: wav2vec2_ft_base

-wav2vec2_asr_large
+wav2vec2_ft_large
 ^^^^^^^^^^^^^^^^^^

-.. autofunction:: wav2vec2_asr_large
+.. autofunction:: wav2vec2_ft_large

-wav2vec2_asr_large_lv60k
+wav2vec2_ft_large_lv60k
 ^^^^^^^^^^^^^^^^^^^^^^^^

-.. autofunction:: wav2vec2_asr_large_lv60k
+.. autofunction:: wav2vec2_ft_large_lv60k

 hubert_base
 ^^^^^^^^^^^
@@ -120,15 +120,15 @@ hubert_xlarge
 .. autofunction:: hubert_xlarge

-hubert_asr_large
+hubert_ft_large
 ^^^^^^^^^^^^^^^^

-.. autofunction:: hubert_asr_large
+.. autofunction:: hubert_ft_large

-hubert_asr_xlarge
+hubert_ft_xlarge
 ^^^^^^^^^^^^^^^^^

-.. autofunction:: hubert_asr_xlarge
+.. autofunction:: hubert_ft_xlarge

 .. currentmodule:: torchaudio.models.wav2vec2.utils
```
Fairseq import tests:

```diff
@@ -3,17 +3,17 @@ import sys
 import torch
 from torchaudio.models.wav2vec2 import (
-    wav2vec2_asr_base,
-    wav2vec2_asr_large,
-    wav2vec2_asr_large_lv60k,
+    wav2vec2_ft_base,
+    wav2vec2_ft_large,
+    wav2vec2_ft_large_lv60k,
     wav2vec2_base,
     wav2vec2_large,
     wav2vec2_large_lv60k,
     hubert_base,
     hubert_large,
     hubert_xlarge,
-    hubert_asr_large,
-    hubert_asr_xlarge,
+    hubert_ft_large,
+    hubert_ft_xlarge,
 )
 from torchaudio.models.wav2vec2.utils import (
     import_fairseq_model,
@@ -75,12 +75,12 @@ ALL_PRETRAINING_CONFIGS = parameterized.expand([
     (HUBERT_XLARGE_LL60K, hubert_xlarge),
 ], name_func=_name_func)

 FINETUNING_CONFIGS = parameterized.expand([
-    (WAV2VEC2_BASE_960H, wav2vec2_asr_base),
-    (WAV2VEC2_LARGE_960H, wav2vec2_asr_large),
-    (WAV2VEC2_LARGE_LV60K_960H, wav2vec2_asr_large_lv60k),
-    (WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_asr_large_lv60k),
-    (HUBERT_LARGE, hubert_asr_large),
-    (HUBERT_XLARGE, hubert_asr_xlarge),
+    (WAV2VEC2_BASE_960H, wav2vec2_ft_base),
+    (WAV2VEC2_LARGE_960H, wav2vec2_ft_large),
+    (WAV2VEC2_LARGE_LV60K_960H, wav2vec2_ft_large_lv60k),
+    (WAV2VEC2_LARGE_LV60K_SELF_960H, wav2vec2_ft_large_lv60k),
+    (HUBERT_LARGE, hubert_ft_large),
+    (HUBERT_XLARGE, hubert_ft_xlarge),
 ], name_func=_name_func)
```
Hugging Face import tests:

```diff
@@ -2,9 +2,9 @@ import json
 import torch
 from torchaudio.models.wav2vec2 import (
-    wav2vec2_asr_base,
-    wav2vec2_asr_large,
-    wav2vec2_asr_large_lv60k,
+    wav2vec2_ft_base,
+    wav2vec2_ft_large,
+    wav2vec2_ft_large_lv60k,
     wav2vec2_base,
     wav2vec2_large,
     wav2vec2_large_lv60k,
@@ -50,11 +50,11 @@ PRETRAIN_CONFIGS = parameterized.expand([
     (HF_BASE_10K_VOXPOPULI, wav2vec2_base),
 ], name_func=_name_func)

 FINETUNE_CONFIGS = parameterized.expand([
-    (HF_BASE_960H, wav2vec2_asr_base),
-    (HF_LARGE_960H, wav2vec2_asr_large),
-    (HF_LARGE_LV60_960H, wav2vec2_asr_large_lv60k),
-    (HF_LARGE_LV60_SELF_960H, wav2vec2_asr_large_lv60k),
-    (HF_LARGE_XLSR_DE, wav2vec2_asr_large_lv60k),
+    (HF_BASE_960H, wav2vec2_ft_base),
+    (HF_LARGE_960H, wav2vec2_ft_large),
+    (HF_LARGE_LV60_960H, wav2vec2_ft_large_lv60k),
+    (HF_LARGE_LV60_SELF_960H, wav2vec2_ft_large_lv60k),
+    (HF_LARGE_XLSR_DE, wav2vec2_ft_large_lv60k),
 ], name_func=_name_func)
```
Model tests:

```diff
@@ -4,17 +4,17 @@ import torch
 import torch.nn.functional as F
 from torchaudio.models.wav2vec2 import (
-    wav2vec2_asr_base,
-    wav2vec2_asr_large,
-    wav2vec2_asr_large_lv60k,
+    wav2vec2_ft_base,
+    wav2vec2_ft_large,
+    wav2vec2_ft_large_lv60k,
     wav2vec2_base,
     wav2vec2_large,
     wav2vec2_large_lv60k,
     hubert_base,
     hubert_large,
     hubert_xlarge,
-    hubert_asr_large,
-    hubert_asr_xlarge,
+    hubert_ft_large,
+    hubert_ft_xlarge,
 )
 from torchaudio_unittest.common_utils import (
     TorchaudioTestCase,
@@ -40,11 +40,11 @@ pretrain_factory_funcs = parameterized.expand([
 finetune_factory_funcs = parameterized.expand([
-    (wav2vec2_asr_base, ),
-    (wav2vec2_asr_large, ),
-    (wav2vec2_asr_large_lv60k, ),
-    (hubert_asr_large, ),
-    (hubert_asr_xlarge, ),
+    (wav2vec2_ft_base, ),
+    (wav2vec2_ft_large, ),
+    (wav2vec2_ft_large_lv60k, ),
+    (hubert_ft_large, ),
+    (hubert_ft_xlarge, ),
 ], name_func=_name_func)
@@ -67,7 +67,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cpu_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cpu'), dtype)
-        model = wav2vec2_asr_base(num_out=32)
+        model = wav2vec2_ft_base(num_out=32)
         self._smoke_test(model, torch.device('cpu'), dtype)

     @parameterized.expand([(torch.float32, ), (torch.float64, )])
@@ -75,7 +75,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cuda_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cuda'), dtype)
-        model = wav2vec2_asr_base(num_out=32)
+        model = wav2vec2_ft_base(num_out=32)
         self._smoke_test(model, torch.device('cuda'), dtype)

     def _feature_extractor_test(self, model):
@@ -194,7 +194,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_torchscript(self, factory_func):
         """Wav2Vec2Model should be scriptable"""
-        if factory_func.__name__ == 'hubert_asr_xlarge' and os.name == 'nt':
+        if factory_func is hubert_ft_xlarge and os.name == 'nt':
             self.skipTest(
                 'hubert_asr_xlarge is known to fail on Windows CI. '
                 'See https://github.com/pytorch/pytorch/issues/65776')
```
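Note that the last hunk also switches the Windows skip guard from comparing `factory_func.__name__` against a string to comparing the function object itself (`factory_func is hubert_ft_xlarge`), so a future rename cannot silently stop the guard from matching. What `test_finetune_torchscript` verifies, as a minimal sketch (the `num_out` value is arbitrary):

```python
import torch
from torchaudio.models import wav2vec2_ft_base

# Fine-tuning factories should produce TorchScript-compatible models.
model = wav2vec2_ft_base(num_out=32)
scripted = torch.jit.script(model)  # should compile without raising
```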
`torchaudio.models` exports:

```diff
@@ -5,17 +5,17 @@ from .deepspeech import DeepSpeech
 from .tacotron2 import Tacotron2, tacotron2
 from .wav2vec2 import (
     Wav2Vec2Model,
-    wav2vec2_asr_base,
-    wav2vec2_asr_large,
-    wav2vec2_asr_large_lv60k,
+    wav2vec2_ft_base,
+    wav2vec2_ft_large,
+    wav2vec2_ft_large_lv60k,
     wav2vec2_base,
     wav2vec2_large,
     wav2vec2_large_lv60k,
     hubert_base,
     hubert_large,
     hubert_xlarge,
-    hubert_asr_large,
-    hubert_asr_xlarge,
+    hubert_ft_large,
+    hubert_ft_xlarge,
 )

 __all__ = [
@@ -25,17 +25,17 @@ __all__ = [
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
-    'wav2vec2_asr_base',
-    'wav2vec2_asr_large',
-    'wav2vec2_asr_large_lv60k',
+    'wav2vec2_ft_base',
+    'wav2vec2_ft_large',
+    'wav2vec2_ft_large_lv60k',
     'wav2vec2_base',
     'wav2vec2_large',
     'wav2vec2_large_lv60k',
     'hubert_base',
     'hubert_large',
     'hubert_xlarge',
-    'hubert_asr_large',
-    'hubert_asr_xlarge',
+    'hubert_ft_large',
+    'hubert_ft_xlarge',
     'Tacotron2',
     'tacotron2',
 ]
```
`torchaudio.models.wav2vec2` exports:

```diff
 from .model import (
     Wav2Vec2Model,
-    wav2vec2_asr_base,
-    wav2vec2_asr_large,
-    wav2vec2_asr_large_lv60k,
+    wav2vec2_ft_base,
+    wav2vec2_ft_large,
+    wav2vec2_ft_large_lv60k,
     wav2vec2_base,
     wav2vec2_large,
     wav2vec2_large_lv60k,
     hubert_base,
     hubert_large,
     hubert_xlarge,
-    hubert_asr_large,
-    hubert_asr_xlarge,
+    hubert_ft_large,
+    hubert_ft_xlarge,
 )
 from . import utils

 __all__ = [
     'Wav2Vec2Model',
-    'wav2vec2_asr_base',
-    'wav2vec2_asr_large',
-    'wav2vec2_asr_large_lv60k',
+    'wav2vec2_ft_base',
+    'wav2vec2_ft_large',
+    'wav2vec2_ft_large_lv60k',
     'wav2vec2_base',
     'wav2vec2_large',
     'wav2vec2_large_lv60k',
     'hubert_base',
     'hubert_large',
     'hubert_xlarge',
-    'hubert_asr_large',
-    'hubert_asr_xlarge',
+    'hubert_ft_large',
+    'hubert_ft_xlarge',
     'utils',
 ]
```
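As the two export hunks above show, the renamed factories are re-exported at both levels, so either import path works:

```python
from torchaudio.models import wav2vec2_ft_base
from torchaudio.models.wav2vec2 import wav2vec2_ft_base  # equivalent
```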
Factory function definitions:

```diff
@@ -173,7 +173,7 @@ def wav2vec2_base() -> Wav2Vec2Model:
     )


-def wav2vec2_asr_base(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_base(num_out: int) -> Wav2Vec2Model:
     """Build "base" wav2vec2 with an extra linear module

     This is one of the model architectures used in *wav2vec 2.0*
@@ -235,7 +235,7 @@ def wav2vec2_large() -> Wav2Vec2Model:
     )


-def wav2vec2_asr_large(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_large(num_out: int) -> Wav2Vec2Model:
     """Build "large" wav2vec2.0 model with an extra linear module

     This is one of the model architectures used in *wav2vec 2.0*
@@ -297,7 +297,7 @@ def wav2vec2_large_lv60k() -> Wav2Vec2Model:
     )


-def wav2vec2_asr_large_lv60k(num_out: int) -> Wav2Vec2Model:
+def wav2vec2_ft_large_lv60k(num_out: int) -> Wav2Vec2Model:
     """Build "Large LV-60k" wav2vec2.0 with an extra linear module

     This is one of the model architectures used in *wav2vec 2.0*
@@ -388,7 +388,7 @@ def hubert_large() -> Wav2Vec2Model:
     )


-def hubert_asr_large(num_out) -> Wav2Vec2Model:
+def hubert_ft_large(num_out) -> Wav2Vec2Model:
     """Build "Large" HuBERT model with an extra linear module
@@ -451,7 +451,7 @@ def hubert_xlarge() -> Wav2Vec2Model:
     )


-def hubert_asr_xlarge(num_out) -> Wav2Vec2Model:
+def hubert_ft_xlarge(num_out) -> Wav2Vec2Model:
     """Build "extra large" HuBERT model with an extra linear module

     This is one of the model architecture used in *HuBERT*
```
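As the docstrings above say, each `_ft_` factory builds the same backbone as its pretraining counterpart plus one extra linear module sized by `num_out`. A hedged sketch of the contrast (the 768-dimensional feature size of the "base" architecture is from the wav2vec 2.0 paper, and `num_out=28` is an arbitrary illustrative output size, not values from this PR):

```python
import torch
from torchaudio.models import wav2vec2_base, wav2vec2_ft_base

pretrain = wav2vec2_base()               # backbone only, no task-specific head
finetune = wav2vec2_ft_base(num_out=28)  # backbone + Linear module with 28 outputs

waveform = torch.randn(1, 16000)
features, _ = pretrain(waveform)  # (1, frames, 768): encoder features
logits, _ = finetune(waveform)    # (1, frames, 28): extra-module outputs
```

Nothing in the extra module ties it to ASR, which is exactly why `_ft_` is a better name than `_asr_`.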