"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dfc76b25426d75d5dce489bd18cfd6a51fb01b97"
Unverified Commit 1498eb98 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

add FlaxAutoModelForImageClassification in main init (#12298)

parent 2affeb29
...@@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction ...@@ -266,3 +266,10 @@ FlaxAutoModelForNextSentencePrediction
.. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction .. autoclass:: transformers.FlaxAutoModelForNextSentencePrediction
:members: :members:
FlaxAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxAutoModelForImageClassification
:members:
...@@ -1509,6 +1509,7 @@ if is_flax_available(): ...@@ -1509,6 +1509,7 @@ if is_flax_available():
_import_structure["models.auto"].extend( _import_structure["models.auto"].extend(
[ [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -1520,6 +1521,7 @@ if is_flax_available(): ...@@ -1520,6 +1521,7 @@ if is_flax_available():
"FLAX_MODEL_MAPPING", "FLAX_MODEL_MAPPING",
"FlaxAutoModel", "FlaxAutoModel",
"FlaxAutoModelForCausalLM", "FlaxAutoModelForCausalLM",
"FlaxAutoModelForImageClassification",
"FlaxAutoModelForMaskedLM", "FlaxAutoModelForMaskedLM",
"FlaxAutoModelForMultipleChoice", "FlaxAutoModelForMultipleChoice",
"FlaxAutoModelForNextSentencePrediction", "FlaxAutoModelForNextSentencePrediction",
...@@ -2848,6 +2850,7 @@ if TYPE_CHECKING: ...@@ -2848,6 +2850,7 @@ if TYPE_CHECKING:
from .modeling_flax_utils import FlaxPreTrainedModel from .modeling_flax_utils import FlaxPreTrainedModel
from .models.auto import ( from .models.auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -2859,6 +2862,7 @@ if TYPE_CHECKING: ...@@ -2859,6 +2862,7 @@ if TYPE_CHECKING:
FLAX_MODEL_MAPPING, FLAX_MODEL_MAPPING,
FlaxAutoModel, FlaxAutoModel,
FlaxAutoModelForCausalLM, FlaxAutoModelForCausalLM,
FlaxAutoModelForImageClassification,
FlaxAutoModelForMaskedLM, FlaxAutoModelForMaskedLM,
FlaxAutoModelForMultipleChoice, FlaxAutoModelForMultipleChoice,
FlaxAutoModelForNextSentencePrediction, FlaxAutoModelForNextSentencePrediction,
......
...@@ -87,6 +87,7 @@ if is_tf_available(): ...@@ -87,6 +87,7 @@ if is_tf_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_auto"] = [ _import_structure["modeling_flax_auto"] = [
"FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",
"FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"FLAX_MODEL_FOR_MASKED_LM_MAPPING", "FLAX_MODEL_FOR_MASKED_LM_MAPPING",
"FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -175,6 +176,7 @@ if TYPE_CHECKING: ...@@ -175,6 +176,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_auto import ( from .modeling_flax_auto import (
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
......
...@@ -115,7 +115,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict( ...@@ -115,7 +115,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
] ]
) )
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING = OrderedDict( FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
[ [
# Model for Image-classsification # Model for Image-classsification
(ViTConfig, FlaxViTForImageClassification), (ViTConfig, FlaxViTForImageClassification),
...@@ -188,7 +188,7 @@ FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING) ...@@ -188,7 +188,7 @@ FlaxAutoModel = auto_class_factory("FlaxAutoModel", FLAX_MODEL_MAPPING)
FlaxAutoModelForImageClassification = auto_class_factory( FlaxAutoModelForImageClassification = auto_class_factory(
"FlaxAutoModelForImageClassification", "FlaxAutoModelForImageClassification",
FLAX_MODEL_FOR_IMAGECLASSIFICATION_MAPPING, FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
head_doc="image classification modeling", head_doc="image classification modeling",
) )
......
...@@ -79,6 +79,9 @@ class FlaxPreTrainedModel: ...@@ -79,6 +79,9 @@ class FlaxPreTrainedModel:
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = None
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
FLAX_MODEL_FOR_MASKED_LM_MAPPING = None FLAX_MODEL_FOR_MASKED_LM_MAPPING = None
...@@ -124,6 +127,15 @@ class FlaxAutoModelForCausalLM: ...@@ -124,6 +127,15 @@ class FlaxAutoModelForCausalLM:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxAutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxAutoModelForMaskedLM: class FlaxAutoModelForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment