chenpangpang / transformers · Commits · 90067748

Unverified commit 90067748, authored Apr 04, 2023 by Shubhamai, committed by GitHub on Apr 04, 2023
parent fc5b7419

Flax Regnet (#21867)

* initial commit
* review changes
* post model PR merge
* updating doc
Showing 17 changed files with 1136 additions and 12 deletions (+1136, -12)
docs/source/de/index.mdx (+1, -1)
docs/source/en/index.mdx (+1, -1)
docs/source/en/model_doc/regnet.mdx (+13, -1)
docs/source/es/index.mdx (+1, -1)
docs/source/fr/index.mdx (+1, -1)
docs/source/it/index.mdx (+1, -1)
docs/source/ja/index.mdx (+1, -1)
docs/source/ko/index.mdx (+1, -1)
docs/source/pt/index.mdx (+1, -1)
docs/source/zh/index.mdx (+1, -1)
src/transformers/__init__.py (+4, -0)
src/transformers/models/auto/modeling_flax_auto.py (+2, -0)
src/transformers/models/regnet/__init__.py (+31, -1)
src/transformers/models/regnet/modeling_flax_regnet.py (+818, -0)
src/transformers/models/resnet/modeling_flax_resnet.py (+1, -1)
src/transformers/utils/dummy_flax_objects.py (+21, -0)
tests/models/regnet/test_modeling_flax_regnet.py (+237, -0)
docs/source/de/index.mdx
@@ -283,7 +283,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/en/index.mdx
@@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/en/model_doc/regnet.mdx
@@ -67,4 +67,16 @@ If you're interested in submitting a resource to be included here, please feel f
 ## TFRegNetForImageClassification

 [[autodoc]] TFRegNetForImageClassification
-    - call
\ No newline at end of file
+    - call
+
+## FlaxRegNetModel
+
+[[autodoc]] FlaxRegNetModel
+    - __call__
+
+## FlaxRegNetForImageClassification
+
+[[autodoc]] FlaxRegNetForImageClassification
+    - __call__
\ No newline at end of file
docs/source/es/index.mdx
@@ -235,7 +235,7 @@ Flax), PyTorch y/o TensorFlow.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/fr/index.mdx
@@ -347,7 +347,7 @@ Le tableau ci-dessous représente la prise en charge actuelle dans la bibliothè
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/it/index.mdx
@@ -252,7 +252,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/ja/index.mdx
@@ -337,7 +337,7 @@ specific language governing permissions and limitations under the License.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/ko/index.mdx
@@ -306,7 +306,7 @@ specific language governing permissions and limitations under the License.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/pt/index.mdx
@@ -250,7 +250,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
docs/source/zh/index.mdx
@@ -336,7 +336,7 @@ Flax), PyTorch, 和/或者 TensorFlow.
 | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
 | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
 | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
-| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
+| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
 | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
 | ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
 | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
src/transformers/__init__.py
@@ -3661,6 +3661,9 @@ else:
             "FlaxPegasusPreTrainedModel",
         ]
     )
+    _import_structure["models.regnet"].extend(
+        ["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"]
+    )
     _import_structure["models.resnet"].extend(
         ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"]
     )
@@ -6739,6 +6742,7 @@ if TYPE_CHECKING:
         from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
         from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
         from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
+        from .models.regnet import FlaxRegNetForImageClassification, FlaxRegNetModel, FlaxRegNetPreTrainedModel
         from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel
         from .models.roberta import (
             FlaxRobertaForCausalLM,
src/transformers/models/auto/modeling_flax_auto.py
@@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
         ("mt5", "FlaxMT5Model"),
         ("opt", "FlaxOPTModel"),
         ("pegasus", "FlaxPegasusModel"),
+        ("regnet", "FlaxRegNetModel"),
         ("resnet", "FlaxResNetModel"),
         ("roberta", "FlaxRobertaModel"),
         ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
@@ -120,6 +121,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
     [
         # Model for Image-classsification
         ("beit", "FlaxBeitForImageClassification"),
+        ("regnet", "FlaxRegNetForImageClassification"),
         ("resnet", "FlaxResNetForImageClassification"),
         ("vit", "FlaxViTForImageClassification"),
     ]
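The two hunks above register `"regnet"` in the name-based auto mappings, so the auto classes can resolve a config's `model_type` string to the right Flax class. As a minimal illustrative sketch (not the transformers implementation; the stand-in classes and the `model_class_for` helper are hypothetical), the lookup amounts to a two-step dictionary resolution:

```python
# Illustrative sketch of a model-type -> class-name registry like
# FLAX_MODEL_MAPPING_NAMES, and how an auto-class could resolve it.
# The classes and helper below are stand-ins, not transformers code.
from collections import OrderedDict

FLAX_MODEL_MAPPING_NAMES = OrderedDict(
    [
        ("regnet", "FlaxRegNetModel"),
        ("resnet", "FlaxResNetModel"),
    ]
)


class FlaxRegNetModel:  # stand-in for the real model class
    pass


class FlaxResNetModel:  # stand-in for the real model class
    pass


# second step: class name -> actual class object
_CLASSES = {"FlaxRegNetModel": FlaxRegNetModel, "FlaxResNetModel": FlaxResNetModel}


def model_class_for(model_type: str):
    """Resolve a config's `model_type` string to its Flax model class."""
    try:
        return _CLASSES[FLAX_MODEL_MAPPING_NAMES[model_type]]
    except KeyError:
        raise ValueError(f"Unrecognized model type: {model_type!r}")


print(model_class_for("regnet").__name__)  # -> FlaxRegNetModel
```

Keeping the mapping as strings (rather than the classes themselves) lets the registry be imported without pulling in every model module.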
src/transformers/models/regnet/__init__.py
@@ -13,7 +13,13 @@
 # limitations under the License.
 from typing import TYPE_CHECKING

-from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+)

 _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
@@ -44,6 +50,18 @@ else:
         "TFRegNetPreTrainedModel",
     ]

+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["modeling_flax_regnet"] = [
+        "FlaxRegNetForImageClassification",
+        "FlaxRegNetModel",
+        "FlaxRegNetPreTrainedModel",
+    ]
+
 if TYPE_CHECKING:
     from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
@@ -74,6 +92,18 @@ if TYPE_CHECKING:
         TFRegNetPreTrainedModel,
     )

+    try:
+        if not is_flax_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .modeling_flax_regnet import (
+            FlaxRegNetForImageClassification,
+            FlaxRegNetModel,
+            FlaxRegNetPreTrainedModel,
+        )
+
 else:
     import sys
src/transformers/models/regnet/modeling_flax_regnet.py (new file, mode 100644)
# coding=utf-8
# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

from transformers import RegNetConfig
from transformers.modeling_flax_outputs import (
    FlaxBaseModelOutputWithNoAttention,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndNoAttention,
    FlaxImageClassifierOutputWithNoAttention,
)
from transformers.modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
)


REGNET_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""

REGNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`RegNetImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.resnet.modeling_flax_resnet.Identity
class Identity(nn.Module):
    """Identity function."""

    @nn.compact
    def __call__(self, x, **kwargs):
        return x


class FlaxRegNetConvLayer(nn.Module):
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    groups: int = 1
    activation: Optional[str] = "relu"
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(self.kernel_size, self.kernel_size),
            strides=self.stride,
            padding=self.kernel_size // 2,
            feature_group_count=self.groups,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
        self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(hidden_state)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


class FlaxRegNetEmbeddings(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embedder = FlaxRegNetConvLayer(
            self.config.embedding_size,
            kernel_size=3,
            stride=2,
            activation=self.config.hidden_act,
            dtype=self.dtype,
        )

    def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        num_channels = pixel_values.shape[-1]
        if num_channels != self.config.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        hidden_state = self.embedder(pixel_values, deterministic=deterministic)
        return hidden_state


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet
class FlaxRegNetShortCut(nn.Module):
    """
    RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
    downsample the input using `stride=2`.
    """

    out_channels: int
    stride: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.convolution = nn.Conv(
            self.out_channels,
            kernel_size=(1, 1),
            strides=self.stride,
            use_bias=False,
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
        )
        self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = self.convolution(x)
        hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
        return hidden_state


class FlaxRegNetSELayerCollection(nn.Module):
    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_1 = nn.Conv(
            self.reduced_channels,
            kernel_size=(1, 1),
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
            name="0",
        )  # 0 is the name used in corresponding pytorch implementation
        self.conv_2 = nn.Conv(
            self.in_channels,
            kernel_size=(1, 1),
            kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
            dtype=self.dtype,
            name="2",
        )  # 2 is the name used in corresponding pytorch implementation

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        hidden_state = self.conv_1(hidden_state)
        hidden_state = nn.relu(hidden_state)
        hidden_state = self.conv_2(hidden_state)
        attention = nn.sigmoid(hidden_state)
        return attention


class FlaxRegNetSELayer(nn.Module):
    """
    Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
    """

    in_channels: int
    reduced_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0)))
        self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype)

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        pooled = self.pooler(
            hidden_state,
            window_shape=(hidden_state.shape[1], hidden_state.shape[2]),
            strides=(hidden_state.shape[1], hidden_state.shape[2]),
        )
        attention = self.attention(pooled)
        hidden_state = hidden_state * attention
        return hidden_state


class FlaxRegNetXLayerCollection(nn.Module):
    config: RegNetConfig
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        groups = max(1, self.out_channels // self.config.groups_width)

        self.layer = [
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                stride=self.stride,
                groups=groups,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="1",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=None,
                dtype=self.dtype,
                name="2",
            ),
        ]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        for layer in self.layer:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state


class FlaxRegNetXLayer(nn.Module):
    """
    RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )
        self.layer = FlaxRegNetXLayerCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )
        self.activation_func = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        residual = hidden_state
        hidden_state = self.layer(hidden_state)
        residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state += residual
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


class FlaxRegNetYLayerCollection(nn.Module):
    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        groups = max(1, self.out_channels // self.config.groups_width)

        self.layer = [
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="0",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                stride=self.stride,
                groups=groups,
                activation=self.config.hidden_act,
                dtype=self.dtype,
                name="1",
            ),
            FlaxRegNetSELayer(
                self.out_channels,
                reduced_channels=int(round(self.in_channels / 4)),
                dtype=self.dtype,
                name="2",
            ),
            FlaxRegNetConvLayer(
                self.out_channels,
                kernel_size=1,
                activation=None,
                dtype=self.dtype,
                name="3",
            ),
        ]

    def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
        for layer in self.layer:
            hidden_state = layer(hidden_state)
        return hidden_state


class FlaxRegNetYLayer(nn.Module):
    """
    RegNet's Y layer: an X layer with Squeeze and Excitation.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 1
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
        self.shortcut = (
            FlaxRegNetShortCut(
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
            )
            if should_apply_shortcut
            else Identity()
        )
        self.layer = FlaxRegNetYLayerCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            dtype=self.dtype,
        )
        self.activation_func = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        residual = hidden_state
        hidden_state = self.layer(hidden_state)
        residual = self.shortcut(residual, deterministic=deterministic)
        hidden_state += residual
        hidden_state = self.activation_func(hidden_state)
        return hidden_state


class FlaxRegNetStageLayersCollection(nn.Module):
    """
    A RegNet stage composed by stacked layers.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer

        layers = [
            # downsampling is done in the first layer with stride of 2
            layer(
                self.config,
                self.in_channels,
                self.out_channels,
                stride=self.stride,
                dtype=self.dtype,
                name="0",
            )
        ]

        for i in range(self.depth - 1):
            layers.append(
                layer(
                    self.config,
                    self.out_channels,
                    self.out_channels,
                    dtype=self.dtype,
                    name=str(i + 1),
                )
            )

        self.layers = layers

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        hidden_state = x
        for layer in self.layers:
            hidden_state = layer(hidden_state, deterministic=deterministic)
        return hidden_state


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet
class FlaxRegNetStage(nn.Module):
    """
    A RegNet stage composed by stacked layers.
    """

    config: RegNetConfig
    in_channels: int
    out_channels: int
    stride: int = 2
    depth: int = 2
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = FlaxRegNetStageLayersCollection(
            self.config,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            stride=self.stride,
            depth=self.depth,
            dtype=self.dtype,
        )

    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        return self.layers(x, deterministic=deterministic)


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet
class FlaxRegNetStageCollection(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
        stages = [
            FlaxRegNetStage(
                self.config,
                self.config.embedding_size,
                self.config.hidden_sizes[0],
                stride=2 if self.config.downsample_in_first_stage else 1,
                depth=self.config.depths[0],
                dtype=self.dtype,
                name="0",
            )
        ]

        for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
            stages.append(
                FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
            )

        self.stages = stages

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        hidden_states = () if output_hidden_states else None

        for stage_module in self.stages:
            if output_hidden_states:
                hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

            hidden_state = stage_module(hidden_state, deterministic=deterministic)

        return hidden_state, hidden_states


# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet
class FlaxRegNetEncoder(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_state: jnp.ndarray,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ) -> FlaxBaseModelOutputWithNoAttention:
        hidden_state, hidden_states = self.stages(
            hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
        )

        if output_hidden_states:
            hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)

        if not return_dict:
            return tuple(v for v in [hidden_state, hidden_states] if v is not None)

        return FlaxBaseModelOutputWithNoAttention(
            last_hidden_state=hidden_state,
            hidden_states=hidden_states,
        )
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET
class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = RegNetConfig
    base_model_prefix = "regnet"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: RegNetConfig,
        input_shape=(1, 224, 224, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        rngs = {"params": rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
    def __call__(
        self,
        pixel_values,
        params: dict = None,
        train: bool = False,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}

        return self.module.apply(
            {
                "params": params["params"] if params is not None else self.params["params"],
                "batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"],
            },
            jnp.array(pixel_values, dtype=jnp.float32),
            not train,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=["batch_stats"] if train else False,  # Returning tuple with batch_stats only when train is True
        )
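The `__call__` above converts inputs from the PyTorch-style NCHW layout to the NHWC layout that Flax convolutions expect, via `jnp.transpose(pixel_values, (0, 2, 3, 1))`. A minimal NumPy sketch of that same permutation (the shapes here are illustrative, not taken from the model):

```python
import numpy as np

# A batch of 2 RGB images, 224x224, in NCHW order (PyTorch convention).
pixel_values = np.zeros((2, 3, 224, 224), dtype=np.float32)

# Same permutation as jnp.transpose(pixel_values, (0, 2, 3, 1)):
# batch stays first, channels move last -> NHWC, the Flax/TF convention.
nhwc = np.transpose(pixel_values, (0, 2, 3, 1))

print(nhwc.shape)  # (2, 224, 224, 3)
```

The transpose is a view-level permutation of axes, so it is cheap; the actual data layout conversion only happens when the array is consumed by a kernel.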
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet
class FlaxRegNetModule(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype)

        # Adaptive average pooling used in resnet
        self.pooler = partial(
            nn.avg_pool,
            padding=((0, 0), (0, 0)),
        )

    def __call__(
        self,
        pixel_values,
        deterministic: bool = True,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        embedding_output = self.embedder(pixel_values, deterministic=deterministic)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        last_hidden_state = encoder_outputs[0]

        pooled_output = self.pooler(
            last_hidden_state,
            window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
            strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
        ).transpose(0, 3, 1, 2)

        last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
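The `pooler` above is `nn.avg_pool` called with `window_shape` and `strides` both equal to the full spatial extent of the feature map, which reduces to global average pooling: one averaged value per channel. A NumPy sketch of the equivalent reduction (the shapes are illustrative assumptions):

```python
import numpy as np

# NHWC feature map: batch 2, 7x7 spatial grid, 64 channels.
features = np.arange(2 * 7 * 7 * 64, dtype=np.float32).reshape(2, 7, 7, 64)

# avg_pool with window_shape == strides == (H, W) covers every spatial
# position exactly once, leaving a 1x1 grid per channel: global average pooling.
pooled = features.mean(axis=(1, 2), keepdims=True)  # shape (2, 1, 1, 64)

# The module then transposes back to NCHW to match the PyTorch reference output.
pooled_nchw = pooled.transpose(0, 3, 1, 2)  # shape (2, 64, 1, 1)
print(pooled_nchw.shape)
```

Computing the window from `last_hidden_state.shape` at call time is what makes the pooling "adaptive": it works for any input resolution, not just 224x224.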
@add_start_docstrings(
    "The bare RegNet model outputting raw features without any specific head on top.",
    REGNET_START_DOCSTRING,
)
class FlaxRegNetModel(FlaxRegNetPreTrainedModel):
    module_class = FlaxRegNetModule


FLAX_VISION_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxRegNetModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
    >>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
    FlaxRegNetModel,
    output_type=FlaxBaseModelOutputWithPooling,
    config_class=RegNetConfig,
)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet
class FlaxRegNetClassifierCollection(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.classifier(x)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET
class FlaxRegNetForImageClassificationModule(nn.Module):
    config: RegNetConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype)

        if self.config.num_labels > 0:
            self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype)
        else:
            self.classifier = Identity()

    def __call__(
        self,
        pixel_values=None,
        deterministic: bool = True,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.regnet(
            pixel_values,
            deterministic=deterministic,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output[:, :, 0, 0])

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)
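Since the pooled features come back in NCHW order with a 1x1 spatial grid, the `pooled_output[:, :, 0, 0]` indexing above simply selects that single spatial position, flattening the tensor to `(batch, channels)` before the dense head. A NumPy sketch with hypothetical sizes:

```python
import numpy as np

# Globally pooled NCHW features: batch 2, 64 channels, 1x1 spatial grid.
pooled_output = np.random.rand(2, 64, 1, 1).astype(np.float32)

# Select the only spatial position -> (batch, channels), ready for nn.Dense.
flat = pooled_output[:, :, 0, 0]
print(flat.shape)  # (2, 64)

# Equivalent formulation: squeeze the trailing unit dimensions.
assert np.array_equal(flat, pooled_output.squeeze(axis=(2, 3)))
```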
@add_start_docstrings(
    """
    RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    REGNET_START_DOCSTRING,
)
class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel):
    module_class = FlaxRegNetForImageClassificationModule


FLAX_VISION_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
    >>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
"""

overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxRegNetForImageClassification,
    output_type=FlaxImageClassifierOutputWithNoAttention,
    config_class=RegNetConfig,
)
src/transformers/models/resnet/modeling_flax_resnet.py

@@ -89,7 +89,7 @@ class Identity(nn.Module):
     """Identity function."""

     @nn.compact
-    def __call__(self, x):
+    def __call__(self, x, **kwargs):
         return x
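This change lets `Identity` silently absorb keyword arguments (e.g. `deterministic=...`) that callers pass uniformly to every layer, so it can stand in for a classifier head when `num_labels == 0`. A minimal sketch of the pattern, stripped of Flax:

```python
class Identity:
    """Pass-through layer that tolerates layer-style keyword arguments."""

    def __call__(self, x, **kwargs):
        # Extra kwargs such as deterministic=True are accepted and ignored,
        # so Identity can be called exactly like any other layer.
        return x


layer = Identity()
print(layer(42, deterministic=True))  # 42
```

Without `**kwargs`, a uniform call site like `self.classifier(pooled, deterministic=deterministic)` would raise a `TypeError` whenever the classifier is an `Identity`.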
src/transformers/utils/dummy_flax_objects.py

@@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["flax"])


+class FlaxRegNetForImageClassification(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxRegNetModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
+class FlaxRegNetPreTrainedModel(metaclass=DummyObject):
+    _backends = ["flax"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+
 class FlaxResNetForImageClassification(metaclass=DummyObject):
     _backends = ["flax"]
tests/models/regnet/test_modeling_flax_regnet.py (new file, mode 100644)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

from transformers import RegNetConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers.utils import cached_property, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor


if is_flax_available():
    import jax
    import jax.numpy as jnp

    from transformers.models.regnet.modeling_flax_regnet import FlaxRegNetForImageClassification, FlaxRegNetModel

if is_vision_available():
    from PIL import Image

    from transformers import AutoFeatureExtractor


class FlaxRegNetModelTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=3,
        image_size=32,
        num_channels=3,
        embeddings_size=10,
        hidden_sizes=[10, 20, 30, 40],
        depths=[1, 1, 2, 1],
        is_training=True,
        use_labels=True,
        hidden_act="relu",
        num_labels=3,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.embeddings_size = embeddings_size
        self.hidden_sizes = hidden_sizes
        self.depths = depths
        self.is_training = is_training
        self.use_labels = use_labels
        self.hidden_act = hidden_act
        self.num_labels = num_labels
        self.scope = scope
        self.num_stages = len(hidden_sizes)

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()
        return config, pixel_values

    def get_config(self):
        return RegNetConfig(
            num_channels=self.num_channels,
            embeddings_size=self.embeddings_size,
            hidden_sizes=self.hidden_sizes,
            depths=self.depths,
            hidden_act=self.hidden_act,
            num_labels=self.num_labels,
            image_size=self.image_size,
        )

    def create_and_check_model(self, config, pixel_values):
        model = FlaxRegNetModel(config=config)
        result = model(pixel_values)

        # Output shape (b, c, h, w)
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
        )

    def create_and_check_for_image_classification(self, config, pixel_values):
        config.num_labels = self.num_labels
        model = FlaxRegNetForImageClassification(config=config)
        result = model(pixel_values)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_flax
class FlaxRegNetModelTest(FlaxModelTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxRegNetModel, FlaxRegNetForImageClassification) if is_flax_available() else ()

    is_encoder_decoder = False
    test_head_masking = False
    has_attentions = False

    def setUp(self) -> None:
        self.model_tester = FlaxRegNetModelTester(self)
        self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False)

    def test_config(self):
        self.create_and_test_config_common_properties()
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def create_and_test_config_common_properties(self):
        return

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @unittest.skip(reason="RegNet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="RegNet does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)

            outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_stages = self.model_tester.num_stages
            self.assertEqual(len(hidden_states), expected_num_stages + 1)

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(pixel_values, **kwargs):
                    return model(pixel_values=pixel_values, **kwargs)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_flax
class FlaxRegNetModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_feature_extractor(self):
        return AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040") if is_vision_available() else None

    @slow
    def test_inference_image_classification_head(self):
        model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(images=image, return_tensors="np")

        outputs = model(**inputs)

        # verify the logits
        expected_shape = (1, 1000)
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = jnp.array([-0.4180, -1.5051, -3.4836])

        self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))