Make the core wav2vec2 factory function public (#1829)
This commit makes the following changes
1. Make the factory function with full customizability public.
i.e. `_get_model(...) -> wav2vec2_model(...)`.
2. Change the other architecture-specific factory functions so that they accept parameters not related to the model architecture (such as dropout).
i.e. `wav2vec2_base() -> wav2vec2_base(encoder_projection_dropout, encoder_attention_dropout, encoder_ff_interm_dropout, ...)`
### Why?
While adding the pre-trained weight support, I realized that separating API for model construction and pre-trained support achieves simple code organization because of the good separation of concern. As mentioned in #1821, in this framework,
1. Model implementation is responsible for computation logic,
2. factory functions are responsible for customizability and model construction,
3. and pre-trained weight API is responsible for constructing a model and loading pre-trained weights along with the complementary information (such as pre-processing and class labels).
(note: for simple models, combining 1 and 2 is also okay.)
This means that factory functions has to support all the customizability required by pre-trained weight API. The current implementation uses the internal function like `from .model import Wav2Vec2Model, _get_model`, which is a bit strange.
This PR rectifies it by making the mother factory function public.
This also clarifies the purpose of having the other factory functions as public API, which is just a syntax sugar for constructing un-trained model with specific architecture. So this commit also adds supplemental parameters to them.
Showing
This diff is collapsed.
Please register or sign in to comment