chenpangpang / transformers · Commits

Commit 49ab1623 (unverified)
Authored by Alara Dirik on Feb 20, 2023; committed by GitHub on Feb 20, 2023
Parent: c9a06714

Add EfficientNet (#21563)

* Add EfficientNet to transformers
Showing 14 changed files with 2091 additions and 1 deletion (+2091 -1):

  src/transformers/models/auto/configuration_auto.py                        +3   -0
  src/transformers/models/auto/image_processing_auto.py                     +1   -0
  src/transformers/models/auto/modeling_auto.py                             +3   -0
  src/transformers/models/efficientnet/__init__.py                          +84  -0
  src/transformers/models/efficientnet/configuration_efficientnet.py        +169 -0
  src/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py   +339 -0
  src/transformers/models/efficientnet/image_processing_efficientnet.py     +347 -0
  src/transformers/models/efficientnet/modeling_efficientnet.py             +662 -0
  src/transformers/utils/dummy_pt_objects.py                                +24  -0
  src/transformers/utils/dummy_vision_objects.py                            +7   -0
  tests/models/convnext/test_modeling_convnext.py                           +0   -1
  tests/models/efficientnet/__init__.py                                     +0   -0
  tests/models/efficientnet/test_image_processing_efficientnet.py           +195 -0
  tests/models/efficientnet/test_modeling_efficientnet.py                   +257 -0
src/transformers/models/auto/configuration_auto.py

@@ -74,6 +74,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("dpr", "DPRConfig"),
         ("dpt", "DPTConfig"),
         ("efficientformer", "EfficientFormerConfig"),
+        ("efficientnet", "EfficientNetConfig"),
         ("electra", "ElectraConfig"),
         ("encoder-decoder", "EncoderDecoderConfig"),
         ("ernie", "ErnieConfig"),

@@ -248,6 +249,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
         ("dpr", "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("dpt", "DPT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("efficientformer", "EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
+        ("efficientnet", "EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("electra", "ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("ernie", "ERNIE_PRETRAINED_CONFIG_ARCHIVE_MAP"),
         ("ernie_m", "ERNIE_M_PRETRAINED_CONFIG_ARCHIVE_MAP"),

@@ -417,6 +419,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("dpr", "DPR"),
         ("dpt", "DPT"),
         ("efficientformer", "EfficientFormer"),
+        ("efficientnet", "EfficientNet"),
         ("electra", "ELECTRA"),
         ("encoder-decoder", "Encoder decoder"),
         ("ernie", "ERNIE"),
src/transformers/models/auto/image_processing_auto.py

@@ -57,6 +57,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
         ("donut-swin", "DonutImageProcessor"),
         ("dpt", "DPTImageProcessor"),
         ("efficientformer", "EfficientFormerImageProcessor"),
+        ("efficientnet", "EfficientNetImageProcessor"),
         ("flava", "FlavaImageProcessor"),
         ("git", "CLIPImageProcessor"),
         ("glpn", "GLPNImageProcessor"),
src/transformers/models/auto/modeling_auto.py

@@ -73,6 +73,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("dpr", "DPRQuestionEncoder"),
         ("dpt", "DPTModel"),
         ("efficientformer", "EfficientFormerModel"),
+        ("efficientnet", "EfficientNetModel"),
         ("electra", "ElectraModel"),
         ("ernie", "ErnieModel"),
         ("ernie_m", "ErnieMModel"),

@@ -419,6 +420,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
                 "EfficientFormerForImageClassificationWithTeacher",
             ),
         ),
+        ("efficientnet", "EfficientNetForImageClassification"),
         ("imagegpt", "ImageGPTForImageClassification"),
         ("levit", ("LevitForImageClassification", "LevitForImageClassificationWithTeacher")),
         ("mobilenet_v1", "MobileNetV1ForImageClassification"),

@@ -933,6 +935,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
         ("bit", "BitBackbone"),
         ("convnext", "ConvNextBackbone"),
         ("dinat", "DinatBackbone"),
+        ("efficientnet", "EfficientNetBackbone"),
         ("maskformer-swin", "MaskFormerSwinBackbone"),
         ("nat", "NatBackbone"),
         ("resnet", "ResNetBackbone"),
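With the three registrations above in place, the Auto classes can resolve the new `"efficientnet"` model type. A minimal sketch, assuming the converted `google/efficientnet-b7` checkpoint is available on the Hub:

from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification

# The "efficientnet" key registered in the mappings above routes each
# Auto class to the corresponding EfficientNet implementation.
config = AutoConfig.from_pretrained("google/efficientnet-b7")
processor = AutoImageProcessor.from_pretrained("google/efficientnet-b7")
model = AutoModelForImageClassification.from_pretrained("google/efficientnet-b7")
print(type(config).__name__, type(processor).__name__, type(model).__name__)
# EfficientNetConfig EfficientNetImageProcessor EfficientNetForImageClassification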
src/transformers/models/efficientnet/__init__.py (new file, mode 100644)

# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_efficientnet": [
        "EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "EfficientNetConfig",
        "EfficientNetOnnxConfig",
    ]
}

try:
    if not is_vision_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["image_processing_efficientnet"] = ["EfficientNetImageProcessor"]

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_efficientnet"] = [
        "EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST",
        "EfficientNetForImageClassification",
        "EfficientNetModel",
        "EfficientNetPreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_efficientnet import (
        EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        EfficientNetConfig,
        EfficientNetOnnxConfig,
    )

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .image_processing_efficientnet import EfficientNetImageProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_efficientnet import (
            EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST,
            EfficientNetForImageClassification,
            EfficientNetModel,
            EfficientNetPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
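This follows the library's standard `_LazyModule` pattern: the package module is replaced by a proxy, and `configuration_efficientnet`, `image_processing_efficientnet`, or `modeling_efficientnet` is imported only when one of the names in `_import_structure` is first accessed. A minimal illustration, assuming the package is installed:

import importlib

# Importing the package itself is cheap; no torch/vision code runs yet.
efficientnet = importlib.import_module("transformers.models.efficientnet")

# The first attribute access triggers the real import of configuration_efficientnet.
config_cls = efficientnet.EfficientNetConfig
print(config_cls.model_type)  # "efficientnet"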
src/transformers/models/efficientnet/configuration_efficientnet.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" EfficientNet model configuration"""

from collections import OrderedDict
from typing import List, Mapping

from packaging import version

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging


logger = logging.get_logger(__name__)

EFFICIENTNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "google/efficientnet-b7": "https://huggingface.co/google/efficientnet-b7/resolve/main/config.json",
}


class EfficientNetConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`EfficientNetModel`]. It is used to instantiate
    an EfficientNet model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the EfficientNet
    [google/efficientnet-b7](https://huggingface.co/google/efficientnet-b7) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        image_size (`int`, *optional*, defaults to 600):
            The input image size.
        width_coefficient (`float`, *optional*, defaults to 2.0):
            Scaling coefficient for network width at each stage.
        depth_coefficient (`float`, *optional*, defaults to 3.1):
            Scaling coefficient for network depth at each stage.
        depth_divisor (`int`, *optional*, defaults to 8):
            A unit of network width.
        kernel_sizes (`List[int]`, *optional*, defaults to `[3, 3, 5, 3, 5, 5, 3]`):
            List of kernel sizes to be used in each block.
        in_channels (`List[int]`, *optional*, defaults to `[32, 16, 24, 40, 80, 112, 192]`):
            List of input channel sizes to be used in each block for convolutional layers.
        out_channels (`List[int]`, *optional*, defaults to `[16, 24, 40, 80, 112, 192, 320]`):
            List of output channel sizes to be used in each block for convolutional layers.
        depthwise_padding (`List[int]`, *optional*, defaults to `[]`):
            List of block indices with square padding.
        strides (`List[int]`, *optional*, defaults to `[1, 2, 2, 2, 1, 2, 1]`):
            List of stride sizes to be used in each block for convolutional layers.
        num_block_repeats (`List[int]`, *optional*, defaults to `[1, 2, 2, 3, 3, 4, 1]`):
            List of the number of times each block is to be repeated.
        expand_ratios (`List[int]`, *optional*, defaults to `[1, 6, 6, 6, 6, 6, 6]`):
            List of scaling coefficients for each block.
        squeeze_expansion_ratio (`float`, *optional*, defaults to 0.25):
            Squeeze expansion ratio.
        hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
            The non-linear activation function (function or string) in each block. If string, `"gelu"`, `"relu"`,
            `"selu"`, `"gelu_new"`, `"silu"` and `"mish"` are supported.
        hidden_dim (`int`, *optional*, defaults to 2560):
            The hidden dimension of the layer before the classification head.
        pooling_type (`str` or `function`, *optional*, defaults to `"mean"`):
            Type of final pooling to be applied before the dense classification head. Available options are
            [`"mean"`, `"max"`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        batch_norm_eps (`float`, *optional*, defaults to 1e-3):
            The epsilon used by the batch normalization layers.
        batch_norm_momentum (`float`, *optional*, defaults to 0.99):
            The momentum used by the batch normalization layers.
        dropout_rate (`float`, *optional*, defaults to 0.5):
            The dropout rate to be applied before the final classifier layer.
        drop_connect_rate (`float`, *optional*, defaults to 0.2):
            The drop rate for skip connections.

    Example:
    ```python
    >>> from transformers import EfficientNetConfig, EfficientNetModel

    >>> # Initializing an EfficientNet efficientnet-b7 style configuration
    >>> configuration = EfficientNetConfig()

    >>> # Initializing a model (with random weights) from the efficientnet-b7 style configuration
    >>> model = EfficientNetModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "efficientnet"

    def __init__(
        self,
        num_channels: int = 3,
        image_size: int = 600,
        width_coefficient: float = 2.0,
        depth_coefficient: float = 3.1,
        depth_divisor: int = 8,
        kernel_sizes: List[int] = [3, 3, 5, 3, 5, 5, 3],
        in_channels: List[int] = [32, 16, 24, 40, 80, 112, 192],
        out_channels: List[int] = [16, 24, 40, 80, 112, 192, 320],
        depthwise_padding: List[int] = [],
        strides: List[int] = [1, 2, 2, 2, 1, 2, 1],
        num_block_repeats: List[int] = [1, 2, 2, 3, 3, 4, 1],
        expand_ratios: List[int] = [1, 6, 6, 6, 6, 6, 6],
        squeeze_expansion_ratio: float = 0.25,
        hidden_act: str = "swish",
        hidden_dim: int = 2560,
        pooling_type: str = "mean",
        initializer_range: float = 0.02,
        batch_norm_eps: float = 0.001,
        batch_norm_momentum: float = 0.99,
        dropout_rate: float = 0.5,
        drop_connect_rate: float = 0.2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_channels = num_channels
        self.image_size = image_size
        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.depth_divisor = depth_divisor
        self.kernel_sizes = kernel_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depthwise_padding = depthwise_padding
        self.strides = strides
        self.num_block_repeats = num_block_repeats
        self.expand_ratios = expand_ratios
        self.squeeze_expansion_ratio = squeeze_expansion_ratio
        self.hidden_act = hidden_act
        self.hidden_dim = hidden_dim
        self.pooling_type = pooling_type
        self.initializer_range = initializer_range
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_momentum = batch_norm_momentum
        self.dropout_rate = dropout_rate
        self.drop_connect_rate = drop_connect_rate
        self.num_hidden_layers = sum(num_block_repeats) * 4


class EfficientNetOnnxConfig(OnnxConfig):
    torch_onnx_minimum_version = version.parse("1.11")

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
            ]
        )

    @property
    def atol_for_validation(self) -> float:
        return 1e-5
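The ONNX config only declares the dynamic axes of the single `pixel_values` input. A quick way to inspect it; a sketch, assuming `OnnxConfig` takes the model config as its first constructor argument, as it does elsewhere in the library:

from transformers.models.efficientnet.configuration_efficientnet import (
    EfficientNetConfig,
    EfficientNetOnnxConfig,
)

# Build the export config from a default (b7-style) model config.
onnx_config = EfficientNetOnnxConfig(EfficientNetConfig())
print(onnx_config.inputs)
# OrderedDict([('pixel_values', {0: 'batch', 1: 'num_channels', 2: 'height', 3: 'width'})])
print(onnx_config.atol_for_validation)  # 1e-05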
src/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert EfficientNet checkpoints from the original repository.

URL: https://github.com/keras-team/keras/blob/v2.11.0/keras/applications/efficientnet.py"""

import argparse
import json
import os

import numpy as np
import PIL
import requests
import tensorflow.keras.applications.efficientnet as efficientnet
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.preprocessing import image

from transformers import (
    EfficientNetConfig,
    EfficientNetForImageClassification,
    EfficientNetImageProcessor,
)
from transformers.utils import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

model_classes = {
    "b0": efficientnet.EfficientNetB0,
    "b1": efficientnet.EfficientNetB1,
    "b2": efficientnet.EfficientNetB2,
    "b3": efficientnet.EfficientNetB3,
    "b4": efficientnet.EfficientNetB4,
    "b5": efficientnet.EfficientNetB5,
    "b6": efficientnet.EfficientNetB6,
    "b7": efficientnet.EfficientNetB7,
}

CONFIG_MAP = {
    "b0": {
        "hidden_dim": 1280,
        "width_coef": 1.0,
        "depth_coef": 1.0,
        "image_size": 224,
        "dropout_rate": 0.2,
        "dw_padding": [],
    },
    "b1": {
        "hidden_dim": 1280,
        "width_coef": 1.0,
        "depth_coef": 1.1,
        "image_size": 240,
        "dropout_rate": 0.2,
        "dw_padding": [16],
    },
    "b2": {
        "hidden_dim": 1408,
        "width_coef": 1.1,
        "depth_coef": 1.2,
        "image_size": 260,
        "dropout_rate": 0.3,
        "dw_padding": [5, 8, 16],
    },
    "b3": {
        "hidden_dim": 1536,
        "width_coef": 1.2,
        "depth_coef": 1.4,
        "image_size": 300,
        "dropout_rate": 0.3,
        "dw_padding": [5, 18],
    },
    "b4": {
        "hidden_dim": 1792,
        "width_coef": 1.4,
        "depth_coef": 1.8,
        "image_size": 380,
        "dropout_rate": 0.4,
        "dw_padding": [6],
    },
    "b5": {
        "hidden_dim": 2048,
        "width_coef": 1.6,
        "depth_coef": 2.2,
        "image_size": 456,
        "dropout_rate": 0.4,
        "dw_padding": [13, 27],
    },
    "b6": {
        "hidden_dim": 2304,
        "width_coef": 1.8,
        "depth_coef": 2.6,
        "image_size": 528,
        "dropout_rate": 0.5,
        "dw_padding": [31],
    },
    "b7": {
        "hidden_dim": 2560,
        "width_coef": 2.0,
        "depth_coef": 3.1,
        "image_size": 600,
        "dropout_rate": 0.5,
        "dw_padding": [18],
    },
}


def get_efficientnet_config(model_name):
    config = EfficientNetConfig()
    config.hidden_dim = CONFIG_MAP[model_name]["hidden_dim"]
    config.width_coefficient = CONFIG_MAP[model_name]["width_coef"]
    config.depth_coefficient = CONFIG_MAP[model_name]["depth_coef"]
    config.image_size = CONFIG_MAP[model_name]["image_size"]
    config.dropout_rate = CONFIG_MAP[model_name]["dropout_rate"]
    config.depthwise_padding = CONFIG_MAP[model_name]["dw_padding"]

    repo_id = "huggingface/label-files"
    filename = "imagenet-1k-id2label.json"
    config.num_labels = 1000
    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}

    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    return config


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im


def convert_image_processor(model_name):
    size = CONFIG_MAP[model_name]["image_size"]
    preprocessor = EfficientNetImageProcessor(
        size={"height": size, "width": size},
        image_mean=[0.485, 0.456, 0.406],
        image_std=[0.47853944, 0.4732864, 0.47434163],
        do_center_crop=False,
    )
    return preprocessor


# here we list all keys to be renamed (original name on the left, our name on the right)
def rename_keys(original_param_names):
    block_names = [v.split("_")[0].split("block")[1] for v in original_param_names if v.startswith("block")]
    block_names = sorted(list(set(block_names)))
    num_blocks = len(block_names)
    block_name_mapping = {b: str(i) for b, i in zip(block_names, range(num_blocks))}

    rename_keys = []
    rename_keys.append(("stem_conv/kernel:0", "embeddings.convolution.weight"))
    rename_keys.append(("stem_bn/gamma:0", "embeddings.batchnorm.weight"))
    rename_keys.append(("stem_bn/beta:0", "embeddings.batchnorm.bias"))
    rename_keys.append(("stem_bn/moving_mean:0", "embeddings.batchnorm.running_mean"))
    rename_keys.append(("stem_bn/moving_variance:0", "embeddings.batchnorm.running_var"))

    for b in block_names:
        hf_b = block_name_mapping[b]
        rename_keys.append((f"block{b}_expand_conv/kernel:0", f"encoder.blocks.{hf_b}.expansion.expand_conv.weight"))
        rename_keys.append((f"block{b}_expand_bn/gamma:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.weight"))
        rename_keys.append((f"block{b}_expand_bn/beta:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.bias"))
        rename_keys.append(
            (f"block{b}_expand_bn/moving_mean:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_mean")
        )
        rename_keys.append(
            (f"block{b}_expand_bn/moving_variance:0", f"encoder.blocks.{hf_b}.expansion.expand_bn.running_var")
        )
        rename_keys.append(
            (f"block{b}_dwconv/depthwise_kernel:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_conv.weight")
        )
        rename_keys.append((f"block{b}_bn/gamma:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.weight"))
        rename_keys.append((f"block{b}_bn/beta:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.bias"))
        rename_keys.append(
            (f"block{b}_bn/moving_mean:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_mean")
        )
        rename_keys.append(
            (f"block{b}_bn/moving_variance:0", f"encoder.blocks.{hf_b}.depthwise_conv.depthwise_norm.running_var")
        )
        rename_keys.append((f"block{b}_se_reduce/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.weight"))
        rename_keys.append((f"block{b}_se_reduce/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.reduce.bias"))
        rename_keys.append((f"block{b}_se_expand/kernel:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.weight"))
        rename_keys.append((f"block{b}_se_expand/bias:0", f"encoder.blocks.{hf_b}.squeeze_excite.expand.bias"))
        rename_keys.append(
            (f"block{b}_project_conv/kernel:0", f"encoder.blocks.{hf_b}.projection.project_conv.weight")
        )
        rename_keys.append((f"block{b}_project_bn/gamma:0", f"encoder.blocks.{hf_b}.projection.project_bn.weight"))
        rename_keys.append((f"block{b}_project_bn/beta:0", f"encoder.blocks.{hf_b}.projection.project_bn.bias"))
        rename_keys.append(
            (f"block{b}_project_bn/moving_mean:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_mean")
        )
        rename_keys.append(
            (f"block{b}_project_bn/moving_variance:0", f"encoder.blocks.{hf_b}.projection.project_bn.running_var")
        )

    rename_keys.append(("top_conv/kernel:0", "encoder.top_conv.weight"))
    rename_keys.append(("top_bn/gamma:0", "encoder.top_bn.weight"))
    rename_keys.append(("top_bn/beta:0", "encoder.top_bn.bias"))
    rename_keys.append(("top_bn/moving_mean:0", "encoder.top_bn.running_mean"))
    rename_keys.append(("top_bn/moving_variance:0", "encoder.top_bn.running_var"))

    key_mapping = {}
    for item in rename_keys:
        if item[0] in original_param_names:
            key_mapping[item[0]] = "efficientnet." + item[1]

    key_mapping["predictions/kernel:0"] = "classifier.weight"
    key_mapping["predictions/bias:0"] = "classifier.bias"
    return key_mapping


def replace_params(hf_params, tf_params, key_mapping):
    for key, value in tf_params.items():
        if "normalization" in key:
            continue

        hf_key = key_mapping[key]
        if "_conv" in key and "kernel" in key:
            new_hf_value = torch.from_numpy(value).permute(3, 2, 0, 1)
        elif "depthwise_kernel" in key:
            new_hf_value = torch.from_numpy(value).permute(2, 3, 0, 1)
        elif "kernel" in key:
            new_hf_value = torch.from_numpy(np.transpose(value))
        else:
            new_hf_value = torch.from_numpy(value)

        # Replace HF parameters with original TF model parameters
        assert hf_params[hf_key].shape == new_hf_value.shape
        hf_params[hf_key].copy_(new_hf_value)


@torch.no_grad()
def convert_efficientnet_checkpoint(model_name, pytorch_dump_folder_path, save_model, push_to_hub):
    """
    Copy/paste/tweak model's weights to our EfficientNet structure.
    """
    # Load original model
    original_model = model_classes[model_name](
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        input_shape=None,
        pooling=None,
        classes=1000,
        classifier_activation="softmax",
    )

    tf_params = original_model.trainable_variables
    tf_non_train_params = original_model.non_trainable_variables
    tf_params = {param.name: param.numpy() for param in tf_params}
    for param in tf_non_train_params:
        tf_params[param.name] = param.numpy()
    tf_param_names = [k for k in tf_params.keys()]

    # Load HuggingFace model
    config = get_efficientnet_config(model_name)
    hf_model = EfficientNetForImageClassification(config).eval()
    hf_params = hf_model.state_dict()

    # Create src-to-dst parameter name mapping dictionary
    print("Converting parameters...")
    key_mapping = rename_keys(tf_param_names)
    replace_params(hf_params, tf_params, key_mapping)

    # Initialize preprocessor and preprocess input image
    preprocessor = convert_image_processor(model_name)
    inputs = preprocessor(images=prepare_img(), return_tensors="pt")

    # HF model inference
    hf_model.eval()
    with torch.no_grad():
        outputs = hf_model(**inputs)
    hf_logits = outputs.logits.detach().numpy()

    # Original model inference
    original_model.trainable = False
    image_size = CONFIG_MAP[model_name]["image_size"]
    img = prepare_img().resize((image_size, image_size), resample=PIL.Image.NEAREST)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    original_logits = original_model.predict(x)

    # Check whether original and HF model outputs match -> np.allclose
    assert np.allclose(original_logits, hf_logits, atol=1e-3), "The predicted logits are not the same."
    print("Model outputs match!")

    if save_model:
        # Create folder to save model
        if not os.path.isdir(pytorch_dump_folder_path):
            os.mkdir(pytorch_dump_folder_path)
        # Save converted model and feature extractor
        hf_model.save_pretrained(pytorch_dump_folder_path)
        preprocessor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        # Push model and feature extractor to hub
        print(f"Pushing converted {model_name} to the hub...")
        model_name = f"efficientnet-{model_name}"
        preprocessor.push_to_hub(model_name)
        hf_model.push_to_hub(model_name)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="b0",
        type=str,
        help="Version name of the EfficientNet model you want to convert, select from [b0, b1, b2, b3, b4, b5, b6, b7].",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path",
        default="hf_model",
        type=str,
        help="Path to the output PyTorch model directory.",
    )
    parser.add_argument("--save_model", action="store_true", help="Save model to local")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model and feature extractor to the hub")

    args = parser.parse_args()
    convert_efficientnet_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.save_model, args.push_to_hub)
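Per the argparse defaults above, the script would typically be invoked as `python convert_efficientnet_to_pytorch.py --model_name b0 --save_model`. It can also be driven from Python; a sketch, assuming TensorFlow with the Keras `imagenet` weights is available locally:

from convert_efficientnet_to_pytorch import convert_efficientnet_checkpoint

# Downloads the Keras EfficientNetB0 weights, ports them into the HF
# architecture, checks the logits against the TF model, and writes the
# converted checkpoint to ./hf_model.
convert_efficientnet_checkpoint(
    model_name="b0",
    pytorch_dump_folder_path="hf_model",
    save_model=True,
    push_to_hub=False,
)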
src/transformers/models/efficientnet/image_processing_efficientnet.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for EfficientNet."""

from typing import Dict, List, Optional, Union

import numpy as np

from transformers.utils import is_vision_available
from transformers.utils.generic import TensorType

from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import center_crop, normalize, rescale, resize, to_channel_dimension_format
from ...image_utils import (
    IMAGENET_STANDARD_MEAN,
    IMAGENET_STANDARD_STD,
    ChannelDimension,
    ImageInput,
    PILImageResampling,
    make_list_of_images,
    to_numpy_array,
    valid_images,
)
from ...utils import logging


if is_vision_available():
    import PIL


logger = logging.get_logger(__name__)


class EfficientNetImageProcessor(BaseImageProcessor):
    r"""
    Constructs an EfficientNet image processor.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
            `do_resize` in `preprocess`.
        size (`Dict[str, int]`, *optional*, defaults to `{"height": 346, "width": 346}`):
            Size of the image after `resize`. Can be overridden by `size` in `preprocess`.
        resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.NEAREST`):
            Resampling filter to use if resizing the image. Can be overridden by `resample` in `preprocess`.
        do_center_crop (`bool`, *optional*, defaults to `False`):
            Whether to center crop the image. If the input size is smaller than `crop_size` along any edge, the image
            is padded with 0's and then center cropped. Can be overridden by `do_center_crop` in `preprocess`.
        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 289, "width": 289}`):
            Desired output size when applying center-cropping. Can be overridden by `crop_size` in `preprocess`.
        do_rescale (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
            parameter in the `preprocess` method.
        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
            Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the
            `preprocess` method.
        rescale_offset (`bool`, *optional*, defaults to `False`):
            Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range]. Can be
            overridden by the `rescale_offset` parameter in the `preprocess` method.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
            method.
        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
        include_top (`bool`, *optional*, defaults to `True`):
            Whether to rescale the image again. Should be set to True if the inputs are used for image classification.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PIL.Image.NEAREST,
        do_center_crop: bool = False,
        crop_size: Dict[str, int] = None,
        rescale_factor: Union[int, float] = 1 / 255,
        rescale_offset: bool = False,
        do_rescale: bool = True,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        include_top: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        size = size if size is not None else {"height": 346, "width": 346}
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else {"height": 289, "width": 289}
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        self.do_resize = do_resize
        self.size = size
        self.resample = resample
        self.do_center_crop = do_center_crop
        self.crop_size = crop_size
        self.do_rescale = do_rescale
        self.rescale_factor = rescale_factor
        self.rescale_offset = rescale_offset
        self.do_normalize = do_normalize
        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.include_top = include_top

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PIL.Image.NEAREST,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Resize an image to `(size["height"], size["width"])` using the specified resampling filter.

        Args:
            image (`np.ndarray`):
                Image to resize.
            size (`Dict[str, int]`):
                Size of the output image.
            resample (`PILImageResampling` filter, *optional*, defaults to `PILImageResampling.NEAREST`):
                Resampling filter to use when resizing the image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        return resize(image, size=(size["height"], size["width"]), resample=resample, data_format=data_format, **kwargs)

    def center_crop(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Center crop an image to `(crop_size["height"], crop_size["width"])`. If the input size is smaller than
        `crop_size` along any edge, the image is padded with 0's and then center cropped.

        Args:
            image (`np.ndarray`):
                Image to center crop.
            size (`Dict[str, int]`):
                Size of the output image.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        size = get_size_dict(size)
        if "height" not in size or "width" not in size:
            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
        return center_crop(image, size=(size["height"], size["width"]), data_format=data_format, **kwargs)

    def rescale(
        self,
        image: np.ndarray,
        scale: Union[int, float],
        offset: bool = True,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        """
        Rescale an image by a scale factor: image = image * scale. If `offset` is `True`, the image is first shifted
        by 127.5 so the output lies in both negative and positive ranges.

        Args:
            image (`np.ndarray`):
                Image to rescale.
            scale (`int` or `float`):
                Scale to apply to the image.
            offset (`bool`, *optional*):
                Whether to scale the image in both negative and positive directions.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        if offset:
            rescaled_image = (image - 127.5) * scale
            if data_format is not None:
                rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
            rescaled_image = rescaled_image.astype(np.float32)
        else:
            rescaled_image = rescale(image, scale=scale, data_format=data_format, **kwargs)
        return rescaled_image

    def normalize(
        self,
        image: np.ndarray,
        mean: Union[float, List[float]],
        std: Union[float, List[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Normalize an image: image = (image - mean) / std.

        Args:
            image (`np.ndarray`):
                Image to normalize.
            mean (`float` or `List[float]`):
                Image mean.
            std (`float` or `List[float]`):
                Image standard deviation.
            data_format (`str` or `ChannelDimension`, *optional*):
                The channel dimension format of the image. If not provided, it will be the same as the input image.
        """
        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample=None,
        do_center_crop: bool = None,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        rescale_offset: bool = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        include_top: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        **kwargs,
    ) -> BatchFeature:
        """
        Preprocess an image or batch of images.

        Args:
            images (`ImageInput`):
                Image to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
                Size of the image after `resize`.
            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
                PILImageResampling filter to use if resizing the image. Only has an effect if `do_resize` is set to
                `True`.
            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
                Whether to center crop the image.
            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
                Size of the image after center crop. If one edge of the image is smaller than `crop_size`, it will be
                padded with zeros and then cropped.
            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
                Whether to rescale the image values between [0 - 1].
            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
            rescale_offset (`bool`, *optional*, defaults to `self.rescale_offset`):
                Whether to rescale the image between [-scale_range, scale_range] instead of [0, scale_range].
            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
                Whether to normalize the image.
            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            include_top (`bool`, *optional*, defaults to `self.include_top`):
                Rescales the image again for image classification if set to True.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - `None`: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
                The channel dimension format for the output image. Can be one of:
                - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
                - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        resample = resample if resample is not None else self.resample
        do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
        do_rescale = do_rescale if do_rescale is not None else self.do_rescale
        rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
        rescale_offset = rescale_offset if rescale_offset is not None else self.rescale_offset
        do_normalize = do_normalize if do_normalize is not None else self.do_normalize
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        include_top = include_top if include_top is not None else self.include_top

        size = size if size is not None else self.size
        size = get_size_dict(size)
        crop_size = crop_size if crop_size is not None else self.crop_size
        crop_size = get_size_dict(crop_size, param_name="crop_size")

        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        if do_resize and (size is None or resample is None):
            raise ValueError("Size and resample must be specified if do_resize is True.")

        if do_center_crop and crop_size is None:
            raise ValueError("Crop size must be specified if do_center_crop is True.")

        if do_rescale and rescale_factor is None:
            raise ValueError("Rescale factor must be specified if do_rescale is True.")

        if do_normalize and (image_mean is None or image_std is None):
            raise ValueError("Image mean and std must be specified if do_normalize is True.")

        # All transformations expect numpy arrays.
        images = [to_numpy_array(image) for image in images]

        if do_resize:
            images = [self.resize(image=image, size=size, resample=resample) for image in images]

        if do_center_crop:
            images = [self.center_crop(image=image, size=crop_size) for image in images]

        if do_rescale:
            images = [self.rescale(image=image, scale=rescale_factor, offset=rescale_offset) for image in images]

        if do_normalize:
            images = [self.normalize(image=image, mean=image_mean, std=image_std) for image in images]

        if include_top:
            images = [self.normalize(image=image, mean=[0, 0, 0], std=image_std) for image in images]

        images = [to_channel_dimension_format(image, data_format) for image in images]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
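A minimal usage sketch of the defaults above (resize to 346x346 with nearest-neighbor resampling, no center crop, rescale to [0, 1], then ImageNet-standard normalization plus the extra `include_top` normalization), assuming torch and PIL are installed:

import numpy as np
from PIL import Image

from transformers import EfficientNetImageProcessor

processor = EfficientNetImageProcessor()  # default size={"height": 346, "width": 346}
image = Image.fromarray(np.zeros((600, 600, 3), dtype=np.uint8))

inputs = processor(images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 346, 346])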
src/transformers/models/efficientnet/modeling_efficientnet.py (new file, mode 100644)

[Diff collapsed in the original view: 662 additions, not shown here.]
src/transformers/utils/dummy_pt_objects.py

@@ -2317,6 +2317,30 @@ class EfficientFormerPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class EfficientNetForImageClassification(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class EfficientNetModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class EfficientNetPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
src/transformers/utils/dummy_vision_objects.py

@@ -191,6 +191,13 @@ class EfficientFormerImageProcessor(metaclass=DummyObject):
         requires_backends(self, ["vision"])


+class EfficientNetImageProcessor(metaclass=DummyObject):
+    _backends = ["vision"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["vision"])
+
+
 class FlavaFeatureExtractor(metaclass=DummyObject):
     _backends = ["vision"]
tests/models/convnext/test_modeling_convnext.py

@@ -82,7 +82,6 @@ class ConvNextModelTester:
             labels = ids_tensor([self.batch_size], self.num_labels)

         config = self.get_config()
-
         return config, pixel_values, labels

     def get_config(self):
tests/models/efficientnet/__init__.py (new, empty file, mode 100644)
tests/models/efficientnet/test_image_processing_efficientnet.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available

from ...test_image_processing_common import ImageProcessingSavingTestMixin, prepare_image_inputs


if is_torch_available():
    import torch

if is_vision_available():
    from PIL import Image

    from transformers import EfficientNetImageProcessor


class EfficientNetImageProcessorTester(unittest.TestCase):
    def __init__(
        self,
        parent,
        batch_size=13,
        num_channels=3,
        image_size=18,
        min_resolution=30,
        max_resolution=400,
        do_resize=True,
        size=None,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    ):
        size = size if size is not None else {"height": 18, "width": 18}
        self.parent = parent
        self.batch_size = batch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.min_resolution = min_resolution
        self.max_resolution = max_resolution
        self.do_resize = do_resize
        self.size = size
        self.do_normalize = do_normalize
        self.image_mean = image_mean
        self.image_std = image_std

    def prepare_image_processor_dict(self):
        return {
            "image_mean": self.image_mean,
            "image_std": self.image_std,
            "do_normalize": self.do_normalize,
            "do_resize": self.do_resize,
            "size": self.size,
        }


@require_torch
@require_vision
class EfficientNetImageProcessorTest(ImageProcessingSavingTestMixin, unittest.TestCase):
    image_processing_class = EfficientNetImageProcessor if is_vision_available() else None

    def setUp(self):
        self.image_processor_tester = EfficientNetImageProcessorTester(self)

    @property
    def image_processor_dict(self):
        return self.image_processor_tester.prepare_image_processor_dict()

    def test_image_processor_properties(self):
        image_processing = self.image_processing_class(**self.image_processor_dict)
        self.assertTrue(hasattr(image_processing, "image_mean"))
        self.assertTrue(hasattr(image_processing, "image_std"))
        self.assertTrue(hasattr(image_processing, "do_normalize"))
        self.assertTrue(hasattr(image_processing, "do_resize"))
        self.assertTrue(hasattr(image_processing, "size"))

    def test_image_processor_from_dict_with_kwargs(self):
        image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
        self.assertEqual(image_processor.size, {"height": 18, "width": 18})

        image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
        self.assertEqual(image_processor.size, {"height": 42, "width": 42})

    def test_call_pil(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PIL images
        image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False)
        for image in image_inputs:
            self.assertIsInstance(image, Image.Image)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                1,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.image_processor_tester.batch_size,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )

    def test_call_numpy(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random numpy tensors
        image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, numpify=True)
        for image in image_inputs:
            self.assertIsInstance(image, np.ndarray)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                1,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.image_processor_tester.batch_size,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )

    def test_call_pytorch(self):
        # Initialize image_processing
        image_processing = self.image_processing_class(**self.image_processor_dict)
        # create random PyTorch tensors
        image_inputs = prepare_image_inputs(self.image_processor_tester, equal_resolution=False, torchify=True)
        for image in image_inputs:
            self.assertIsInstance(image, torch.Tensor)

        # Test not batched input
        encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                1,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )

        # Test batched
        encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
        self.assertEqual(
            encoded_images.shape,
            (
                self.image_processor_tester.batch_size,
                self.image_processor_tester.num_channels,
                self.image_processor_tester.size["height"],
                self.image_processor_tester.size["width"],
            ),
        )
tests/models/efficientnet/test_modeling_efficientnet.py (new file, mode 100644)

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch EfficientNet model. """

import inspect
import unittest

from transformers import EfficientNetConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor


if is_torch_available():
    import torch

    from transformers import EfficientNetForImageClassification, EfficientNetModel
    from transformers.models.efficientnet.modeling_efficientnet import EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
    from PIL import Image

    from transformers import AutoImageProcessor


class EfficientNetModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        image_size=32,
        num_channels=3,
        kernel_sizes=[3, 3, 5],
        in_channels=[32, 16, 24],
        out_channels=[16, 24, 40],
        strides=[1, 1, 2],
        num_block_repeats=[1, 1, 2],
        expand_ratios=[1, 6, 6],
        is_training=True,
        use_labels=True,
        intermediate_size=37,
        hidden_act="gelu",
        num_labels=10,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.kernel_sizes = kernel_sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.strides = strides
        self.num_block_repeats = num_block_repeats
        self.expand_ratios = expand_ratios
        self.is_training = is_training
        self.hidden_act = hidden_act
        self.num_labels = num_labels
        self.use_labels = use_labels

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

        labels = None
        if self.use_labels:
            labels = ids_tensor([self.batch_size], self.num_labels)

        config = self.get_config()
        return config, pixel_values, labels

    def get_config(self):
        return EfficientNetConfig(
            num_channels=self.num_channels,
            kernel_sizes=self.kernel_sizes,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            strides=self.strides,
            num_block_repeats=self.num_block_repeats,
            expand_ratios=self.expand_ratios,
            hidden_act=self.hidden_act,
            num_labels=self.num_labels,
        )

    def create_and_check_model(self, config, pixel_values, labels):
        model = EfficientNetModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values)
        # expected last hidden states: B, C, H // 4, W // 4
        self.parent.assertEqual(
            result.last_hidden_state.shape,
            (self.batch_size, config.hidden_dim, self.image_size // 4, self.image_size // 4),
        )

    def create_and_check_for_image_classification(self, config, pixel_values, labels):
        model = EfficientNetForImageClassification(config)
        model.to(torch_device)
        model.eval()
        result = model(pixel_values, labels=labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values, labels = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
class EfficientNetModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as EfficientNet does not use input_ids,
    inputs_embeds, attention_mask and seq_length.
    """

    all_model_classes = (EfficientNetModel, EfficientNetForImageClassification) if is_torch_available() else ()

    fx_compatible = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    has_attentions = False

    def setUp(self):
        self.model_tester = EfficientNetModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=EfficientNetConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.create_and_test_config_common_properties()
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def create_and_test_config_common_properties(self):
        return

    @unittest.skip(reason="EfficientNet does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="EfficientNet does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass

    @unittest.skip(reason="EfficientNet does not use feedforward chunking")
    def test_feed_forward_chunking(self):
        pass

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            num_blocks = sum(config.num_block_repeats) * 4
            self.assertEqual(len(hidden_states), num_blocks)

            # EfficientNet's feature maps are of shape (batch_size, num_channels, height, width)
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [self.model_tester.image_size // 2, self.model_tester.image_size // 2],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_for_image_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in EFFICIENTNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = EfficientNetModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


# We will verify our results on an image of cute cats
def prepare_img():
    image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
    return image


@require_torch
@require_vision
class EfficientNetModelIntegrationTest(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return AutoImageProcessor.from_pretrained("google/efficientnet-b7") if is_vision_available() else None

    @slow
    def test_inference_image_classification_head(self):
        model = EfficientNetForImageClassification.from_pretrained("google/efficientnet-b7").to(torch_device)

        image_processor = self.default_image_processor
        image = prepare_img()
        inputs = image_processor(images=image, return_tensors="pt").to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(**inputs)

        # verify the logits
        expected_shape = torch.Size((1, 1000))
        self.assertEqual(outputs.logits.shape, expected_shape)

        expected_slice = torch.tensor([0.0001, 0.0002, 0.0002]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
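Assuming an editable install of the repository, the new suites can be run from the repository root with pytest; a sketch:

import pytest

# Run the fast EfficientNet suites; the @slow tests additionally require
# RUN_SLOW=1 in the environment and the google/efficientnet-b7 checkpoint
# on the Hub.
pytest.main(
    [
        "tests/models/efficientnet/test_image_processing_efficientnet.py",
        "tests/models/efficientnet/test_modeling_efficientnet.py",
    ]
)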