Commit 0582e73c authored by moto

Make the core wav2vec2 factory function public (#1829)

This commit makes the following changes (illustrated in the sketch below):
1. Make the factory function with full customizability public,
    i.e. `_get_model(...) -> wav2vec2_model(...)`.
2. Change the other architecture-specific factory functions so that they also accept parameters unrelated to the model architecture (such as dropout rates),
    i.e. `wav2vec2_base() -> wav2vec2_base(encoder_projection_dropout, encoder_attention_dropout, encoder_ff_interm_dropout, ...)`.
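
For illustration, a minimal sketch of the updated API. The dropout keywords are the ones listed above; the remaining keyword names and the BASE-configuration values are assumptions for the sake of the example, not an exhaustive signature:

```python
import torchaudio

# 1. The core factory is now public: construct an untrained model with
#    full control over the architecture and regularization.
#    (Keyword names other than the dropout ones are assumptions here.)
model = torchaudio.models.wav2vec2_model(
    extractor_mode="group_norm",
    extractor_conv_layer_config=None,  # None selects the default conv stack
    extractor_conv_bias=False,
    encoder_embed_dim=768,
    encoder_projection_dropout=0.1,
    encoder_pos_conv_kernel=128,
    encoder_pos_conv_groups=16,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_features=3072,
    encoder_ff_interm_dropout=0.1,
    encoder_dropout=0.1,
    encoder_layer_norm_first=False,
    encoder_layer_drop=0.1,
    aux_num_out=None,  # set to a vocabulary size to attach the aux head
)

# 2. The architecture-specific factories now forward the
#    non-architectural knobs, such as the dropout rates.
model = torchaudio.models.wav2vec2_base(
    encoder_projection_dropout=0.0,
    encoder_attention_dropout=0.0,
    encoder_ff_interm_dropout=0.0,
)
```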

### Why?

While adding the pre-trained weight support, I realized that separating the API for model construction from the pre-trained weight support leads to a simple code organization, thanks to the clean separation of concerns. As mentioned in #1821, in this framework (see the sketch after the list),
  1. the model implementation is responsible for the computation logic,
  2. the factory functions are responsible for customizability and model construction,
  3. and the pre-trained weight API is responsible for constructing a model and loading pre-trained weights, along with the complementary information (such as pre-processing and class labels).

(Note: for simple models, combining 1 and 2 is also okay.)
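
A minimal sketch of how the layers compose. The helper name and its arguments are illustrative only, not actual API; it mirrors the `Wav2Vec2PretrainedModelBundle` change in the diff below:

```python
from torch.hub import load_state_dict_from_url

from torchaudio.models import Wav2Vec2Model, wav2vec2_model


def _get_pretrained_model(factory_params: dict, url: str) -> Wav2Vec2Model:
    # (2) The factory function handles construction and customizability...
    model = wav2vec2_model(**factory_params)
    # (3) ...so the pre-trained weight layer only fetches and loads weights
    #     (and would also carry pre-processing info and class labels).
    state_dict = load_state_dict_from_url(url)
    model.load_state_dict(state_dict)
    return model
```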

This means that the factory functions have to support all the customizability required by the pre-trained weight API. The current implementation imports an internal function instead, as in `from .model import Wav2Vec2Model, _get_model`, which is a bit strange.

This PR rectifies that by making the core factory function public.
This also clarifies the purpose of keeping the other factory functions as public API: they are just syntactic sugar for constructing an untrained model with a specific architecture. So this commit also adds the supplemental parameters to them.
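
Relatedly, the fine-tuned factory functions now take the size of the output head as `aux_num_out` (previously `num_out`), consistent with `wav2vec2_model`. For example:

```python
import torchaudio

# An untrained fine-tuning variant of the BASE architecture
# with a 32-class auxiliary output head.
model = torchaudio.models.wav2vec2_ft_base(aux_num_out=32)
```
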
parent 8f270d09
@@ -75,6 +75,12 @@ Wav2Vec2Model

 Factory Functions
 -----------------

+wav2vec2_model
+^^^^^^^^^^^^^^
+
+.. autofunction:: wav2vec2_model
+
+
 wav2vec2_base
 ^^^^^^^^^^^^^
@@ -226,7 +226,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
         original = self._get_model(config, num_out).eval()
         imported = import_fairseq_model(original).eval()
-        reloaded = factory_func(num_out=num_out)
+        reloaded = factory_func(aux_num_out=num_out)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
@@ -221,7 +221,7 @@ class TestHFIntegration(TorchaudioTestCase):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.aux.out_features)
+        reloaded = factory_func(aux_num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -67,7 +67,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cpu_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cpu'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cpu'), dtype)

     @parameterized.expand([(torch.float32, ), (torch.float64, )])
@@ -75,7 +75,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cuda_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cuda'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cuda'), dtype)

     def _feature_extractor_test(self, model):
@@ -113,7 +113,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_feature_extractor_test(self, factory_func):
         """`extract_features` method does not fail"""
-        self._feature_extractor_test(factory_func(num_out=32))
+        self._feature_extractor_test(factory_func(aux_num_out=32))

     def _test_batch_consistency(self, model):
         model.eval()
@@ -166,7 +166,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_zero_length(self, factory_func):
         """Passing zero length should not fail"""
-        self._test_zero_length(factory_func(num_out=32))
+        self._test_zero_length(factory_func(aux_num_out=32))

     def _test_torchscript(self, model):
         model.eval()
@@ -202,7 +202,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
             self.skipTest(
                 'hubert_ft_xlarge is known to fail on Windows CI. '
                 'See https://github.com/pytorch/pytorch/issues/65776')
-        self._test_torchscript(factory_func(num_out=32))
+        self._test_torchscript(factory_func(aux_num_out=32))

     def _test_quantize_smoke_test(self, model):
         model.eval()
@@ -232,7 +232,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize(self, factory_func):
         """Wav2Vec2Model should support basic quantization"""
-        self._test_quantize_smoke_test(factory_func(num_out=32))
+        self._test_quantize_smoke_test(factory_func(aux_num_out=32))

     def _test_quantize_torchscript(self, model):
         model.eval()
@@ -271,4 +271,4 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize_torchscript(self, factory_func):
         """Quantized Wav2Vec2Model should be scriptable"""
-        self._test_quantize_torchscript(factory_func(num_out=32))
+        self._test_quantize_torchscript(factory_func(aux_num_out=32))
@@ -5,6 +5,7 @@ from .deepspeech import DeepSpeech
 from .tacotron2 import Tacotron2, tacotron2
 from .wav2vec2 import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -46,6 +47,7 @@ __all__ = [
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
 from .model import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -16,6 +17,7 @@ from . import utils

 __all__ = [
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
(This diff is collapsed.)
@@ -3,7 +3,7 @@ from typing import Dict, Tuple, Any, Optional

 from torch.hub import load_state_dict_from_url

-from .model import _get_model, Wav2Vec2Model
+from .model import wav2vec2_model, Wav2Vec2Model


 __all__ = []
@@ -67,7 +67,7 @@ class Wav2Vec2PretrainedModelBundle:
         Args:
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
         """
-        model = _get_model(**self._params)
+        model = wav2vec2_model(**self._params)
         url = f'https://download.pytorch.org/models/audio/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
@@ -6,7 +6,7 @@ import re

 from torch.nn import Module

-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model


 def _parse_config(w2v_model):
@@ -190,27 +190,27 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:

 def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()))
     return model


 def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model


 def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model


 def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model
@@ -4,7 +4,7 @@ import logging

 from torch.nn import Module

-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model


 _LG = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def _build(config, original):
             '"lm_head" module is not imported.')
         aux_num_out = None
         wav2vec2 = original
-    imported = _get_model(**config, aux_num_out=aux_num_out)
+    imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
     imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())