Unverified Commit 31a69c36 authored by moto, committed by GitHub

Make the core wav2vec2 factory function public (#1829)

This commit makes the following changes:
1. Makes the fully customizable factory function public,
    i.e. `_get_model(...) -> wav2vec2_model(...)`.
2. Changes the other architecture-specific factory functions so that they accept parameters unrelated to the model architecture (such as dropout rates), as sketched below,
    i.e. `wav2vec2_base() -> wav2vec2_base(encoder_projection_dropout, encoder_attention_dropout, encoder_ff_interm_dropout, ...)`.
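
For illustration, a rough sketch of the two entry points after this change (hedged: the exact `wav2vec2_model` signature and defaults should be taken from its docstring; the values below merely mirror the "base" architecture):

```python
from torchaudio.models import wav2vec2_model, wav2vec2_base

# Fully customizable construction: every architectural knob is explicit.
model = wav2vec2_model(
    extractor_mode="group_norm",
    extractor_conv_layer_config=None,  # None selects the default conv stack
    extractor_conv_bias=False,
    encoder_embed_dim=768,
    encoder_projection_dropout=0.1,
    encoder_pos_conv_kernel=128,
    encoder_pos_conv_groups=16,
    encoder_num_layers=12,
    encoder_num_heads=12,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_features=3072,
    encoder_ff_interm_dropout=0.1,
    encoder_dropout=0.1,
    encoder_layer_norm_first=False,
    encoder_layer_drop=0.1,
    aux_num_out=None,  # no fine-tuning head attached
)

# Architecture-specific shorthand: the architecture is fixed, but the
# non-architectural knobs (dropout rates) are now exposed.
model = wav2vec2_base(
    encoder_projection_dropout=0.1,
    encoder_attention_dropout=0.1,
    encoder_ff_interm_dropout=0.1,
)
```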

### Why?

While adding pre-trained weight support, I realized that separating the API for model construction from the API for pre-trained weights keeps the code organization simple, thanks to a good separation of concerns. As mentioned in #1821, in this framework,
  1. the model implementation is responsible for the computation logic,
  2. the factory functions are responsible for customizability and model construction,
  3. and the pre-trained weight API is responsible for constructing a model and loading pre-trained weights, along with complementary information (such as pre-processing and class labels).

(Note: for simple models, combining 1 and 2 is also fine. A minimal sketch of this layering follows.)
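
To make the layering concrete, here is a minimal, hypothetical sketch of the pattern; `TinyModel`, `tiny_model`, and `tiny_pretrained` are made-up names for illustration, not torchaudio APIs:

```python
import torch
from torch import nn


# 1. Model implementation: computation logic only.
class TinyModel(nn.Module):
    def __init__(self, feature_dim: int, dropout: float):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(1, feature_dim), nn.Dropout(dropout))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# 2. Factory function: exposes every knob and handles construction.
def tiny_model(feature_dim: int = 768, dropout: float = 0.1) -> TinyModel:
    return TinyModel(feature_dim, dropout)


# 3. Pre-trained weight API: builds the model via the *public* factory,
#    then loads the published weights (complementary information such as
#    pre-processing and labels would also live at this layer).
def tiny_pretrained(url: str, **dl_kwargs) -> TinyModel:
    model = tiny_model()  # configuration must match the published weights
    state_dict = torch.hub.load_state_dict_from_url(url, **dl_kwargs)
    model.load_state_dict(state_dict)
    return model
```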

This means that the factory functions have to support all the customizability required by the pre-trained weight API. The current implementation reaches into an internal function, as in `from .model import Wav2Vec2Model, _get_model`, which is a bit strange.

This PR rectifies that by making the core factory function public.
This also clarifies the purpose of keeping the other factory functions as public API: they are simply syntactic sugar for constructing an un-trained model with a specific architecture. Accordingly, this commit also adds supplemental parameters to them.
parent 9a34e7c0
@@ -75,6 +75,12 @@ Wav2Vec2Model
 Factory Functions
 -----------------
+wav2vec2_model
+^^^^^^^^^^^^^^
+
+.. autofunction:: wav2vec2_model
+
 wav2vec2_base
 ^^^^^^^^^^^^^
......
@@ -226,7 +226,7 @@ class TestFairseqIntegration(TorchaudioTestCase):
         original = self._get_model(config, num_out).eval()
         imported = import_fairseq_model(original).eval()
-        reloaded = factory_func(num_out=num_out)
+        reloaded = factory_func(aux_num_out=num_out)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
......
@@ -221,7 +221,7 @@ class TestHFIntegration(TorchaudioTestCase):
     def test_recreate_finetune(self, config, factory_func):
         """Imported models can be recreated via a factory function without Hugging Face transformers."""
         imported = import_huggingface_model(self._get_model(config)).eval()
-        reloaded = factory_func(num_out=imported.aux.out_features)
+        reloaded = factory_func(aux_num_out=imported.aux.out_features)
         reloaded.load_state_dict(imported.state_dict())
         reloaded.eval()
         self._test_recreate(imported, reloaded, config)
@@ -74,7 +74,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cpu_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cpu'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cpu'), dtype)

     @parameterized.expand([(torch.float32, ), (torch.float64, )])
@@ -82,7 +82,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     def test_cuda_smoke_test(self, dtype):
         model = wav2vec2_base()
         self._smoke_test(model, torch.device('cuda'), dtype)
-        model = wav2vec2_ft_base(num_out=32)
+        model = wav2vec2_ft_base(aux_num_out=32)
         self._smoke_test(model, torch.device('cuda'), dtype)

     def _feature_extractor_test(self, model):
@@ -120,7 +120,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_feature_extractor_test(self, factory_func):
         """`extract_features` method does not fail"""
-        self._feature_extractor_test(factory_func(num_out=32))
+        self._feature_extractor_test(factory_func(aux_num_out=32))

     def _test_batch_consistency(self, model):
         model.eval()
@@ -173,7 +173,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @finetune_factory_funcs
     def test_finetune_zero_length(self, factory_func):
         """Passing zero length should not fail"""
-        self._test_zero_length(factory_func(num_out=32))
+        self._test_zero_length(factory_func(aux_num_out=32))

     def _test_torchscript(self, model):
         model.eval()
@@ -209,7 +209,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
             self.skipTest(
                 'hubert_ft_xlarge is known to fail on Windows CI. '
                 'See https://github.com/pytorch/pytorch/issues/65776')
-        self._test_torchscript(factory_func(num_out=32))
+        self._test_torchscript(factory_func(aux_num_out=32))

     def _test_quantize_smoke_test(self, model):
         model.eval()
@@ -239,7 +239,7 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize(self, factory_func):
         """Wav2Vec2Model should support basic quantization"""
-        self._test_quantize_smoke_test(factory_func(num_out=32))
+        self._test_quantize_smoke_test(factory_func(aux_num_out=32))

     def _test_quantize_torchscript(self, model):
         model.eval()
@@ -278,4 +278,4 @@ class TestWav2Vec2Model(TorchaudioTestCase):
     @skipIfNoQengine
     def test_finetune_quantize_torchscript(self, factory_func):
         """Quantized Wav2Vec2Model should be scriptable"""
-        self._test_quantize_torchscript(factory_func(num_out=32))
+        self._test_quantize_torchscript(factory_func(aux_num_out=32))
@@ -5,6 +5,7 @@ from .deepspeech import DeepSpeech
 from .tacotron2 import Tacotron2, tacotron2
 from .wav2vec2 import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -46,6 +47,7 @@ __all__ = [
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
......
 from .model import (
     Wav2Vec2Model,
+    wav2vec2_model,
     wav2vec2_ft_base,
     wav2vec2_ft_large,
     wav2vec2_ft_large_lv60k,
@@ -16,6 +17,7 @@ from . import utils
 __all__ = [
     'Wav2Vec2Model',
+    'wav2vec2_model',
     'wav2vec2_ft_base',
     'wav2vec2_ft_large',
     'wav2vec2_ft_large_lv60k',
......
This diff is collapsed.
@@ -3,7 +3,7 @@ from typing import Dict, Tuple, Any, Optional
 from torch.hub import load_state_dict_from_url

-from .model import _get_model, Wav2Vec2Model
+from .model import wav2vec2_model, Wav2Vec2Model

 __all__ = []
@@ -67,7 +67,7 @@ class Wav2Vec2PretrainedModelBundle:
         Args:
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
         """
-        model = _get_model(**self._params)
+        model = wav2vec2_model(**self._params)
         url = f'https://download.pytorch.org/models/audio/{self._path}'
         dl_kwargs = {} if dl_kwargs is None else dl_kwargs
         state_dict = load_state_dict_from_url(url, **dl_kwargs)
......
@@ -6,7 +6,7 @@ import re
 from torch.nn import Module

-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model


 def _parse_config(w2v_model):
@@ -190,27 +190,27 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
 def _import_wav2vec2_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()))
     return model


 def _import_wav2vec2_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model


 def _import_hubert_finetuning(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original.w2v_model)
-    model = _get_model(**config, aux_num_out=original.proj.out_features)
+    model = wav2vec2_model(**config, aux_num_out=original.proj.out_features)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model


 def _import_hubert_pretraining(original: Module) -> Wav2Vec2Model:
     config = _parse_config(original)
-    model = _get_model(**config, aux_num_out=None)
+    model = wav2vec2_model(**config, aux_num_out=None)
     model.load_state_dict(_convert_state_dict(original.state_dict()), strict=False)
     return model
@@ -4,7 +4,7 @@ import logging
 from torch.nn import Module

-from ..model import Wav2Vec2Model, _get_model
+from ..model import Wav2Vec2Model, wav2vec2_model

 _LG = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ def _build(config, original):
             '"lm_head" module is not imported.')
         aux_num_out = None
         wav2vec2 = original
-    imported = _get_model(**config, aux_num_out=aux_num_out)
+    imported = wav2vec2_model(**config, aux_num_out=aux_num_out)
     imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict())
     imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict())
     imported.encoder.transformer.load_state_dict(wav2vec2.encoder.state_dict())
......