    Make the core wav2vec2 factory function public (#1829) · 0582e73c
    moto authored
    This commit makes the following changes:
    1. Make the fully customizable factory function public,
        i.e. `_get_model(...) -> wav2vec2_model(...)`.
    2. Change the other architecture-specific factory functions so that they also accept parameters unrelated to the model architecture (such as dropout),
        e.g. `wav2vec2_base() -> wav2vec2_base(encoder_projection_dropout, encoder_attention_dropout, encoder_ff_interm_dropout, ...)`.
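    The two changes can be pictured with a minimal, framework-free sketch. It assumes a hypothetical `Wav2Vec2Config` dataclass standing in for the real `torch.nn.Module`, and the architecture parameter names (`encoder_embed_dim`, `encoder_num_layers`, `encoder_num_heads`) and the 768/12/12 "base" values are illustrative assumptions; only the three dropout names come from this commit message, and the real factory takes many more parameters (elided as `...` above). The point is the shape: one public factory exposing every knob, and an architecture-specific wrapper that pins the architecture but forwards the non-architecture parameters.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Wav2Vec2Config:
    """Hypothetical stand-in for a built model: it only records its settings."""

    # Architecture parameters (assumed, simplified subset).
    encoder_embed_dim: int
    encoder_num_layers: int
    encoder_num_heads: int
    # Non-architecture parameters (regularization), named as in the commit message.
    encoder_projection_dropout: float
    encoder_attention_dropout: float
    encoder_ff_interm_dropout: float


def wav2vec2_model(
    encoder_embed_dim: int,
    encoder_num_layers: int,
    encoder_num_heads: int,
    encoder_projection_dropout: float,
    encoder_attention_dropout: float,
    encoder_ff_interm_dropout: float,
) -> Wav2Vec2Config:
    """Core (public) factory: every knob, architectural or not, is a parameter."""
    return Wav2Vec2Config(
        encoder_embed_dim=encoder_embed_dim,
        encoder_num_layers=encoder_num_layers,
        encoder_num_heads=encoder_num_heads,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
    )


def wav2vec2_base(
    encoder_projection_dropout: float = 0.1,
    encoder_attention_dropout: float = 0.1,
    encoder_ff_interm_dropout: float = 0.1,
) -> Wav2Vec2Config:
    """Syntactic sugar: pins the (assumed) "base" shape, forwards the rest."""
    return wav2vec2_model(
        encoder_embed_dim=768,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_projection_dropout=encoder_projection_dropout,
        encoder_attention_dropout=encoder_attention_dropout,
        encoder_ff_interm_dropout=encoder_ff_interm_dropout,
    )


# Same architecture, different regularization, without reaching for a private helper.
model = wav2vec2_base(encoder_attention_dropout=0.0)
```

    The default dropout values here are placeholders for the sketch, not a statement of the library's defaults.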
    
    ### Why?
    
    While adding pre-trained weight support, I realized that keeping the model-construction API separate from the pre-trained weight API keeps the code organization simple, thanks to a clean separation of concerns. As mentioned in #1821, in this framework:
      1. the model implementation is responsible for the computation logic,
      2. factory functions are responsible for customizability and model construction,
      3. and the pre-trained weight API is responsible for constructing a model and loading the pre-trained weights, along with complementary information (such as pre-processing and class labels).
    
    (note: for simple models, combining 1 and 2 is also okay.)
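    The three responsibilities above can be illustrated with a toy, framework-free sketch. Everything in it is hypothetical: `TinyModel` stands in for the model implementation, `tiny_model` for a public factory, and `PretrainedBundle` for the pre-trained weight API; none of these names are torchaudio's API, and the "weights" are plain Python lists rather than tensors. The property being shown is that the bundle builds its model only through the public factory, then loads weights and carries the complementary information.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence


class TinyModel:
    """Layer 1 (model implementation): computation logic only."""

    def __init__(self, num_features: int, dropout: float) -> None:
        self.dropout = dropout  # stored for completeness; this toy forward ignores it
        self.weights: List[float] = [0.0] * num_features

    def load_state_dict(self, state: Dict[str, List[float]]) -> None:
        # Refuse weights that do not match the constructed architecture.
        if len(state["weights"]) != len(self.weights):
            raise ValueError("weight shape does not match the model")
        self.weights = list(state["weights"])

    def forward(self, x: Sequence[float]) -> float:
        # A dot product stands in for the real computation.
        return sum(w * v for w, v in zip(self.weights, x))


def tiny_model(num_features: int, dropout: float = 0.1) -> TinyModel:
    """Layer 2 (public factory): customizability and construction."""
    return TinyModel(num_features=num_features, dropout=dropout)


@dataclass
class PretrainedBundle:
    """Layer 3 (pre-trained weight API): factory args, weights, and extras."""

    factory_kwargs: Dict[str, Any]
    state_dict: Dict[str, List[float]]
    labels: List[str]
    preprocess: Callable[[Sequence[float]], List[float]]

    def get_model(self) -> TinyModel:
        # Construct through the public factory, not a private helper.
        model = tiny_model(**self.factory_kwargs)
        model.load_state_dict(self.state_dict)
        return model


BUNDLE = PretrainedBundle(
    factory_kwargs={"num_features": 3, "dropout": 0.0},
    state_dict={"weights": [1.0, 2.0, 3.0]},
    labels=["negative", "positive"],
    preprocess=lambda xs: [v / 10 for v in xs],
)
model = BUNDLE.get_model()
score = model.forward(BUNDLE.preprocess([10.0, 10.0, 10.0]))  # 6.0
```

    Because the bundle only needs what the public factory accepts, the factory must expose every parameter the bundle might want to set, which is the motivation for this change.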
    
    This means that the factory functions have to support all the customizability required by the pre-trained weight API. The current implementation instead imports an internal function, as in `from .model import Wav2Vec2Model, _get_model`, which is a bit strange.
    
    This PR rectifies this by making the core factory function public.
    It also clarifies the purpose of keeping the other factory functions as public API: they are just syntactic sugar for constructing an untrained model with a specific architecture. So this commit also adds the supplemental (non-architecture) parameters to them.
models.rst 3.33 KB