Unverified commit 6c25f522, authored by Sylvain Gugger and committed by GitHub
Browse files

Refactor AutoModel classes and add Flax Auto classes (#11027)



* Refactor AutoModel classes and add Flax Auto classes

* Add new objects to the init

* Fix hubconf and sort models

* Fix TF tests

* Missing comma

* Update src/transformers/models/auto/auto_factory.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix init

* Fix dummies

* Other init to fix
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent eb3479e7
...@@ -189,3 +189,52 @@ FlaxAutoModel ...@@ -189,3 +189,52 @@ FlaxAutoModel
.. autoclass:: transformers.FlaxAutoModel .. autoclass:: transformers.FlaxAutoModel
:members: :members:
FlaxAutoModelForPreTraining
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForPreTraining
:members:
FlaxAutoModelForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForMaskedLM
:members:
FlaxAutoModelForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForSequenceClassification
:members:
FlaxAutoModelForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForQuestionAnswering
:members:
FlaxAutoModelForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForTokenClassification
:members:
FlaxAutoModelForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForMultipleChoice
:members:
FlaxAutoModelForNextSentencePrediction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
:members:
...@@ -22,9 +22,10 @@ sys.path.append(SRC_DIR) ...@@ -22,9 +22,10 @@ sys.path.append(SRC_DIR)
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModel, AutoModel,
AutoModelForCausalLM,
AutoModelForMaskedLM,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoModelWithLMHead,
AutoTokenizer, AutoTokenizer,
add_start_docstrings, add_start_docstrings,
) )
...@@ -86,22 +87,41 @@ def model(*args, **kwargs): ...@@ -86,22 +87,41 @@ def model(*args, **kwargs):
return AutoModel.from_pretrained(*args, **kwargs) return AutoModel.from_pretrained(*args, **kwargs)
@add_start_docstrings(AutoModelWithLMHead.__doc__) @add_start_docstrings(AutoModelForCausalLM.__doc__)
def modelWithLMHead(*args, **kwargs): def modelForCausalLM(*args, **kwargs):
r""" r"""
# Using torch.hub ! # Using torch.hub !
import torch import torch
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache. model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2') # Download model and configuration from huggingface.co and cache.
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')` model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True) # Update configuration during loading model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', 'gpt2', output_attentions=True) # Update configuration during loading
assert model.config.output_attentions == True assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower) # Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json') config = AutoConfig.from_pretrained('./tf_model/gpt_tf_model_config.json')
model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config) model = torch.hub.load('huggingface/transformers', 'modelForCausalLM', './tf_model/gpt_tf_checkpoint.ckpt.index', from_tf=True, config=config)
""" """
return AutoModelWithLMHead.from_pretrained(*args, **kwargs) return AutoModelForCausalLM.from_pretrained(*args, **kwargs)
@add_start_docstrings(AutoModelForMaskedLM.__doc__)
def modelForMaskedLM(*args, **kwargs):
r"""
# Using torch.hub !
import torch
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased') # Download model and configuration from huggingface.co and cache.
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', 'bert-base-uncased', output_attentions=True) # Update configuration during loading
assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
config = AutoConfig.from_pretrained('./tf_model/bert_tf_model_config.json')
model = torch.hub.load('huggingface/transformers', 'modelForMaskedLM', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
"""
return AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
@add_start_docstrings(AutoModelForSequenceClassification.__doc__) @add_start_docstrings(AutoModelForSequenceClassification.__doc__)
......
...@@ -1300,7 +1300,26 @@ else: ...@@ -1300,7 +1300,26 @@ else:
# FLAX-backed objects # FLAX-backed objects
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
_import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"]) _import_structure["models.auto"].extend(
[
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
)
_import_structure["models.bert"].extend( _import_structure["models.bert"].extend(
[ [
"FlaxBertForMaskedLM", "FlaxBertForMaskedLM",
...@@ -2410,7 +2429,24 @@ if TYPE_CHECKING: ...@@ -2410,7 +2429,24 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_utils import FlaxPreTrainedModel from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel from .models.auto import (
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
from .models.bert import ( from .models.bert import (
FlaxBertForMaskedLM, FlaxBertForMaskedLM,
FlaxBertForMultipleChoice, FlaxBertForMultipleChoice,
......
...@@ -82,7 +82,24 @@ if is_tf_available(): ...@@ -82,7 +82,24 @@ if is_tf_available():
] ]
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_auto"] = ["FLAX_MODEL_MAPPING", "FlaxAutoModel"] _import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
"FLAX_MODEL_FOR_PRETRAINING_MAPPING",
"FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"FLAX_MODEL_MAPPING",
"FlaxAutoModel",
"FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction",
"FlaxAutoModelForPreTraining",
"FlaxAutoModelForQuestionAnswering",
"FlaxAutoModelForSequenceClassification",
"FlaxAutoModelForTokenClassification",
]
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -145,7 +162,24 @@ if TYPE_CHECKING: ...@@ -145,7 +162,24 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_auto import FLAX_MODEL_MAPPING, FlaxAutoModel from .modeling_flax_auto import (
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
FLAX_MODEL_FOR_PRETRAINING_MAPPING,
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModel,
FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction,
FlaxAutoModelForPreTraining,
FlaxAutoModelForQuestionAnswering,
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
else: else:
import importlib import importlib
......
This diff is collapsed.
...@@ -256,8 +256,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): ...@@ -256,8 +256,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config in config_to_class if config in config_to_class
} }
lines = [ lines = [
f"{indent}- **{model_type}** -- :class:`~transformers.{cls_name}` ({MODEL_NAMES_MAPPING[model_type]} model)" f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type, cls_name in model_type_to_name.items() for model_type in sorted(model_type_to_name.keys())
] ]
else: else:
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()} config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()}
...@@ -265,8 +265,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): ...@@ -265,8 +265,8 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items() config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
} }
lines = [ lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{cls_name}` ({config_to_model_name[config_name]} model)" f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)"
for config_name, cls_name in config_to_name.items() for config_name in sorted(config_to_name.keys())
] ]
return "\n".join(lines) return "\n".join(lines)
......
...@@ -17,11 +17,20 @@ ...@@ -17,11 +17,20 @@
from collections import OrderedDict from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ..bert.modeling_flax_bert import FlaxBertModel from ..bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..roberta.modeling_flax_roberta import FlaxRobertaModel from ..roberta.modeling_flax_roberta import FlaxRobertaModel
from .configuration_auto import AutoConfig, BertConfig, RobertaConfig from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, RobertaConfig
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -29,140 +38,90 @@ logger = logging.get_logger(__name__) ...@@ -29,140 +38,90 @@ logger = logging.get_logger(__name__)
FLAX_MODEL_MAPPING = OrderedDict( FLAX_MODEL_MAPPING = OrderedDict(
[ [
# Base model mapping
(RobertaConfig, FlaxRobertaModel), (RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel), (BertConfig, FlaxBertModel),
] ]
) )
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
[
# Model for pre-training mapping
(BertConfig, FlaxBertForPreTraining),
]
)
class FlaxAutoModel(object): FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
r""" [
:class:`~transformers.FlaxAutoModel` is a generic model class that will be instantiated as one of the base model # Model for Masked LM mapping
classes of the library when created with the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or the (BertConfig, FlaxBertForMaskedLM),
`FlaxAutoModel.from_config(config)` class methods. ]
)
This class cannot be instantiated using `__init__()` (throws an error).
""" FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
[
def __init__(self): # Model for Sequence Classification mapping
raise EnvironmentError( (BertConfig, FlaxBertForSequenceClassification),
"FlaxAutoModel is designed to be instantiated " ]
"using the `FlaxAutoModel.from_pretrained(pretrained_model_name_or_path)` or " )
"`FlaxAutoModel.from_config(config)` methods."
)
@classmethod FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
def from_config(cls, config): [
r""" # Model for Question Answering mapping
Instantiates one of the base model classes of the library from a configuration. (BertConfig, FlaxBertForQuestionAnswering),
]
Args: )
config (:class:`~transformers.PretrainedConfig`):
The model class to instantiate is selected based on the configuration class: FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
[
- isInstance of `roberta` configuration class: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) # Model for Token Classification mapping
- isInstance of `bert` configuration class: :class:`~transformers.FlaxBertModel` (Bert model (BertConfig, FlaxBertForTokenClassification),
]
Examples:: )
config = BertConfig.from_pretrained('bert-base-uncased') FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
# Download configuration from huggingface.co and cache. [
model = FlaxAutoModel.from_config(config) # Model for Multiple Choice mapping
# E.g. model was saved using `save_pretrained('./test/saved_model/')` (BertConfig, FlaxBertForMultipleChoice),
""" ]
for config_class, model_class in FLAX_MODEL_MAPPING.items(): )
if isinstance(config, config_class):
return model_class(config) FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
raise ValueError( [
f"Unrecognized configuration class {config.__class__} " (BertConfig, FlaxBertForNextSentencePrediction),
f"for this kind of FlaxAutoModel: {cls.__name__}.\n" ]
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}." )
)
FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): FlaxAutoModelForPreTraining = auto_class_factory(
r""" "FlaxAutoModelForPreTraining", FLAX_MODEL_FOR_PRETRAINING_MAPPING, head_doc="pretraining"
Instantiates one of the base model classes of the library from a pre-trained model configuration. )
The `from_pretrained()` method takes care of returning the correct model class instance based on the FlaxAutoModelForMaskedLM = auto_class_factory(
`model_type` property of the config object, or when it's missing, falling back to using pattern matching on the "FlaxAutoModelForMaskedLM", FLAX_MODEL_FOR_MASKED_LM_MAPPING, head_doc="masked language modeling"
`pretrained_model_name_or_path` string. )
The base model class to instantiate is selected as the first pattern matching in the FlaxAutoModelForSequenceClassification = auto_class_factory(
`pretrained_model_name_or_path` string (in the following order): "FlaxAutoModelForSequenceClassification",
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
- contains `roberta`: :class:`~transformers.FlaxRobertaModel` (RoBERTa model) head_doc="sequence classification",
- contains `bert`: :class:`~transformers.FlaxBertModel` (Bert model) )
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated) To FlaxAutoModelForQuestionAnswering = auto_class_factory(
train the model, you should first set it back in training mode with `model.train()` "FlaxAutoModelForQuestionAnswering", FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, head_doc="question answering"
)
Args:
pretrained_model_name_or_path: either: FlaxAutoModelForTokenClassification = auto_class_factory(
"FlaxAutoModelForTokenClassification", FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, head_doc="token classification"
- a string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. Valid )
model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under a user or
organization name, like ``dbmdz/bert-base-german-cased``. FlaxAutoModelForMultipleChoice = auto_class_factory(
- a path to a `directory` containing model weights saved using "FlaxAutoModelForMultipleChoice", FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, head_doc="multiple choice"
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. )
- a path or url to a `pytorch index checkpoint file` (e.g. `./pt_model/pytorch_model.bin`). In this
case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` FlaxAutoModelForNextSentencePrediction = auto_class_factory(
argument. "FlaxAutoModelForNextSentencePrediction",
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
model_args: (`optional`) Sequence of positional arguments: head_doc="next sentence prediction",
All remaining positional arguments will be passed to the underlying model's ``__init__`` method )
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
pretrained model), or
- the model was saved using :func:`~transformers.FlaxPreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
configuration JSON file named `config.json` is found in the directory.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model configuration should be cached if the
standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if
they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments:
These arguments will be passed to the configuration and the model.
Examples::
model = FlaxAutoModel.from_pretrained('bert-base-uncased') # Download model and configuration from huggingface.co and cache.
model = FlaxAutoModel.from_pretrained('./test/bert_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
assert model.config.output_attention == True
"""
config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
for config_class, model_class in FLAX_MODEL_MAPPING.items():
if isinstance(config, config_class):
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, _from_auto=True, **kwargs
)
raise ValueError(
f"Unrecognized configuration class {config.__class__} "
f"for this kind of FlaxAutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in FLAX_MODEL_MAPPING.keys())}"
)
...@@ -11,6 +11,27 @@ class FlaxPreTrainedModel: ...@@ -11,6 +11,27 @@ class FlaxPreTrainedModel:
requires_flax(self) requires_flax(self)
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = None
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = None
FLAX_MODEL_FOR_PRETRAINING_MAPPING = None
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None
FLAX_MODEL_MAPPING = None FLAX_MODEL_MAPPING = None
...@@ -23,6 +44,69 @@ class FlaxAutoModel: ...@@ -23,6 +44,69 @@ class FlaxAutoModel:
requires_flax(self) requires_flax(self)
class FlaxAutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForNextSentencePrediction:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForPreTraining:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxAutoModelForTokenClassification:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
class FlaxBertForMaskedLM: class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_flax(self) requires_flax(self)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment