# Copyright 2022 The KerasCV Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """EfficientNet V2 models for KerasCV. Reference: - [EfficientNetV2: Smaller Models and Faster Training]( https://arxiv.org/abs/2104.00298) (ICML 2021) - [Based on the original keras.applications EfficientNetV2](https://github.com/keras-team/keras/blob/master/keras/applications/efficientnet_v2.py) """ import copy import math import tensorflow as tf from keras import backend from keras import layers from keras_cv.models import utils from keras_cv.models.weights import parse_weights DEFAULT_BLOCKS_ARGS = { "efficientnetv2-s": [ { "kernel_size": 3, "num_repeat": 2, "input_filters": 24, "output_filters": 24, "expand_ratio": 1, "se_ratio": 0.0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 4, "input_filters": 24, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0.0, "strides": 2, "conv_type": 1, }, { "conv_type": 1, "expand_ratio": 4, "input_filters": 48, "kernel_size": 3, "num_repeat": 4, "output_filters": 64, "se_ratio": 0, "strides": 2, }, { "conv_type": 0, "expand_ratio": 4, "input_filters": 64, "kernel_size": 3, "num_repeat": 6, "output_filters": 128, "se_ratio": 0.25, "strides": 2, }, { "conv_type": 0, "expand_ratio": 6, "input_filters": 128, "kernel_size": 3, "num_repeat": 9, "output_filters": 160, "se_ratio": 0.25, "strides": 1, }, { "conv_type": 0, "expand_ratio": 6, "input_filters": 160, "kernel_size": 3, "num_repeat": 15, "output_filters": 256, "se_ratio": 0.25, "strides": 2, }, ], "efficientnetv2-m": [ { "kernel_size": 3, "num_repeat": 3, "input_filters": 24, "output_filters": 24, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 24, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 48, "output_filters": 80, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 7, "input_filters": 80, "output_filters": 160, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 14, "input_filters": 160, "output_filters": 176, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 18, "input_filters": 176, "output_filters": 304, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 304, "output_filters": 512, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, ], "efficientnetv2-l": [ { "kernel_size": 3, "num_repeat": 4, "input_filters": 32, "output_filters": 32, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 7, "input_filters": 32, "output_filters": 64, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 7, "input_filters": 64, "output_filters": 96, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 10, "input_filters": 96, "output_filters": 192, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 19, "input_filters": 192, "output_filters": 224, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 25, "input_filters": 224, "output_filters": 384, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 7, "input_filters": 384, "output_filters": 640, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, ], "efficientnetv2-b0": [ { "kernel_size": 3, "num_repeat": 1, "input_filters": 32, "output_filters": 16, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 16, "output_filters": 32, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 32, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 3, "input_filters": 48, "output_filters": 96, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 96, "output_filters": 112, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 8, "input_filters": 112, "output_filters": 192, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, ], "efficientnetv2-b1": [ { "kernel_size": 3, "num_repeat": 1, "input_filters": 32, "output_filters": 16, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 16, "output_filters": 32, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 32, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 3, "input_filters": 48, "output_filters": 96, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 96, "output_filters": 112, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 8, "input_filters": 112, "output_filters": 192, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, ], "efficientnetv2-b2": [ { "kernel_size": 3, "num_repeat": 1, "input_filters": 32, "output_filters": 16, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 16, "output_filters": 32, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 32, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 3, "input_filters": 48, "output_filters": 96, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 96, "output_filters": 112, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 8, "input_filters": 112, "output_filters": 192, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, ], "efficientnetv2-b3": [ { "kernel_size": 3, "num_repeat": 1, "input_filters": 32, "output_filters": 16, "expand_ratio": 1, "se_ratio": 0, "strides": 1, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 16, "output_filters": 32, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 2, "input_filters": 32, "output_filters": 48, "expand_ratio": 4, "se_ratio": 0, "strides": 2, "conv_type": 1, }, { "kernel_size": 3, "num_repeat": 3, "input_filters": 48, "output_filters": 96, "expand_ratio": 4, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 5, "input_filters": 96, "output_filters": 112, "expand_ratio": 6, "se_ratio": 0.25, "strides": 1, "conv_type": 0, }, { "kernel_size": 3, "num_repeat": 8, "input_filters": 112, "output_filters": 192, "expand_ratio": 6, "se_ratio": 0.25, "strides": 2, "conv_type": 0, }, ], } CONV_KERNEL_INITIALIZER = { "class_name": "VarianceScaling", "config": { "scale": 2.0, "mode": "fan_out", "distribution": "truncated_normal", }, } DENSE_KERNEL_INITIALIZER = { "class_name": "VarianceScaling", "config": { "scale": 1.0 / 3.0, "mode": "fan_out", "distribution": "uniform", }, } BN_AXIS = 3 BASE_DOCSTRING = """Instantiates the {name} architecture. Reference: - [EfficientNetV2: Smaller Models and Faster Training]( https://arxiv.org/abs/2104.00298) (ICML 2021) This function returns a Keras image classification model, optionally loaded with weights pre-trained on ImageNet. For image classification use cases, see [this page for detailed examples]( https://keras.io/api/applications/#usage-examples-for-image-classification-models). For transfer learning use cases, make sure to read the [guide to transfer learning & fine-tuning]( https://keras.io/guides/transfer_learning/). Args: include_rescaling: whether or not to Rescale the inputs.If set to True, inputs will be passed through a `Rescaling(1/255.0)` layer. include_top: Whether to include the fully-connected layer at the top of the network. weights: one of `None` (random initialization), a pretrained weight file path, or a reference to pre-trained weights (e.g. 'imagenet/classification') (see available pre-trained weights in weights.py) input_shape: Optional shape tuple. It should have exactly 3 inputs channels. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. pooling: Optional pooling mode for feature extraction when `include_top` is `False`. Defaults to None. - `None` means that the output of the model will be the 4D tensor output of the last convolutional layer. - `avg` means that global average pooling will be applied to the output of the last convolutional layer, and thus the output of the model will be a 2D tensor. - `max` means that global max pooling will be applied. classes: Optional number of lasses to classify images into, only to be specified if `include_top` is True, and if no `weights` argument is specified. Defaults to None. classifier_activation: A `str` or callable. The activation function to use on the "top" layer. Ignored unless `include_top=True`. Set `classifier_activation=None` to return the logits of the "top" layer. Defaults to 'softmax'. When loading pretrained weights, `classifier_activation` can only be `None` or `"softmax"`. Returns: A `keras.Model` instance. """ def round_filters(filters, width_coefficient, min_depth, depth_divisor): """Round number of filters based on depth multiplier.""" filters *= width_coefficient minimum_depth = min_depth or depth_divisor new_filters = max( minimum_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor, ) return int(new_filters) def round_repeats(repeats, depth_coefficient): """Round number of repeats based on depth multiplier.""" return int(math.ceil(depth_coefficient * repeats)) def MBConvBlock( input_filters: int, output_filters: int, expand_ratio=1, kernel_size=3, strides=1, se_ratio=0.0, bn_momentum=0.9, activation="swish", survival_probability: float = 0.8, name=None, ): """MBConv block: Mobile Inverted Residual Bottleneck.""" if name is None: name = backend.get_uid("block0") def apply(inputs): # Expansion phase filters = input_filters * expand_ratio if expand_ratio != 1: x = layers.Conv2D( filters=filters, kernel_size=1, strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", data_format="channels_last", use_bias=False, name=name + "expand_conv", )(inputs) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name=name + "expand_bn", )(x) x = layers.Activation(activation, name=name + "expand_activation")(x) else: x = inputs # Depthwise conv x = layers.DepthwiseConv2D( kernel_size=kernel_size, strides=strides, depthwise_initializer=CONV_KERNEL_INITIALIZER, padding="same", data_format="channels_last", use_bias=False, name=name + "dwconv2", )(x) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name=name + "bn" )(x) x = layers.Activation(activation, name=name + "activation")(x) # Squeeze and excite if 0 < se_ratio <= 1: filters_se = max(1, int(input_filters * se_ratio)) se = layers.GlobalAveragePooling2D(name=name + "se_squeeze")(x) if BN_AXIS == 1: se_shape = (filters, 1, 1) else: se_shape = (1, 1, filters) se = layers.Reshape(se_shape, name=name + "se_reshape")(se) se = layers.Conv2D( filters_se, 1, padding="same", activation=activation, kernel_initializer=CONV_KERNEL_INITIALIZER, name=name + "se_reduce", )(se) se = layers.Conv2D( filters, 1, padding="same", activation="sigmoid", kernel_initializer=CONV_KERNEL_INITIALIZER, name=name + "se_expand", )(se) x = layers.multiply([x, se], name=name + "se_excite") # Output phase x = layers.Conv2D( filters=output_filters, kernel_size=1, strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", data_format="channels_last", use_bias=False, name=name + "project_conv", )(x) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name=name + "project_bn" )(x) if strides == 1 and input_filters == output_filters: if survival_probability: x = layers.Dropout( survival_probability, noise_shape=(None, 1, 1, 1), name=name + "drop", )(x) x = layers.add([x, inputs], name=name + "add") return x return apply def FusedMBConvBlock( input_filters: int, output_filters: int, expand_ratio=1, kernel_size=3, strides=1, se_ratio=0.0, bn_momentum=0.9, activation="swish", survival_probability: float = 0.8, name=None, ): """Fused MBConv Block: Fusing the proj conv1x1 and depthwise_conv into a conv2d.""" if name is None: name = backend.get_uid("block0") def apply(inputs): filters = input_filters * expand_ratio if expand_ratio != 1: x = layers.Conv2D( filters, kernel_size=kernel_size, strides=strides, kernel_initializer=CONV_KERNEL_INITIALIZER, data_format="channels_last", padding="same", use_bias=False, name=name + "expand_conv", )(inputs) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name=name + "expand_bn" )(x) x = layers.Activation( activation=activation, name=name + "expand_activation" )(x) else: x = inputs # Squeeze and excite if 0 < se_ratio <= 1: filters_se = max(1, int(input_filters * se_ratio)) se = layers.GlobalAveragePooling2D(name=name + "se_squeeze")(x) if BN_AXIS == 1: se_shape = (filters, 1, 1) else: se_shape = (1, 1, filters) se = layers.Reshape(se_shape, name=name + "se_reshape")(se) se = layers.Conv2D( filters_se, 1, padding="same", activation=activation, kernel_initializer=CONV_KERNEL_INITIALIZER, name=name + "se_reduce", )(se) se = layers.Conv2D( filters, 1, padding="same", activation="sigmoid", kernel_initializer=CONV_KERNEL_INITIALIZER, name=name + "se_expand", )(se) x = layers.multiply([x, se], name=name + "se_excite") # Output phase: x = layers.Conv2D( output_filters, kernel_size=1 if expand_ratio != 1 else kernel_size, strides=1 if expand_ratio != 1 else strides, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", use_bias=False, name=name + "project_conv", )(x) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name=name + "project_bn" )(x) if expand_ratio == 1: x = layers.Activation( activation=activation, name=name + "project_activation" )(x) # Residual: if strides == 1 and input_filters == output_filters: if survival_probability: x = layers.Dropout( survival_probability, noise_shape=(None, 1, 1, 1), name=name + "drop", )(x) x = layers.add([x, inputs], name=name + "add") return x return apply def EfficientNetV2( include_rescaling, include_top, width_coefficient, depth_coefficient, default_size, dropout_rate=0.2, drop_connect_rate=0.2, depth_divisor=8, min_depth=8, bn_momentum=0.9, activation="swish", blocks_args="default", model_name="efficientnet", weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): """Instantiates the EfficientNetV2 architecture using given scaling coefficients. Args: include_rescaling: whether or not to Rescale the inputs.If set to True, inputs will be passed through a `Rescaling(1/255.0)` layer. include_top: whether to include the fully-connected layer at the top of the network. width_coefficient: float, scaling coefficient for network width. depth_coefficient: float, scaling coefficient for network depth. default_size: integer, default input image size. dropout_rate: float, dropout rate before final classifier layer. drop_connect_rate: float, dropout rate at skip connections. depth_divisor: integer, a unit of network width. min_depth: integer, minimum number of filters. bn_momentum: float. Momentum parameter for Batch Normalization layers. activation: activation function. blocks_args: list of dicts, parameters to construct block modules. model_name: string, model name. weights: one of `None` (random initialization), or the path to the weights file to be loaded. input_shape: optional shape tuple, It should have exactly 3 inputs channels. input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model. pooling: optional pooling mode for feature extraction when `include_top` is `False`. - `None` means that the output of the model will be the 4D tensor output of the last convolutional layer. - `avg` means that global average pooling will be applied to the output of the last convolutional layer, and thus the output of the model will be a 2D tensor. - `max` means that global max pooling will be applied. classes: optional number of classes to classify images into, only to be specified if `include_top` is True, and if no `weights` argument is specified. classifier_activation: A `str` or callable. The activation function to use on the "top" layer. Ignored unless `include_top=True`. Set `classifier_activation=None` to return the logits of the "top" layer. Returns: A `keras.Model` instance. Raises: ValueError: in case of invalid argument for `weights`, or invalid input shape. ValueError: if `classifier_activation` is not `"softmax"` or `None` when using a pretrained top layer. """ if blocks_args == "default": blocks_args = DEFAULT_BLOCKS_ARGS[model_name] if weights and not tf.io.gfile.exists(weights): raise ValueError( "The `weights` argument should be either `None` or the path to the " "weights file to be loaded. Weights file not found at location: {weights}" ) if include_top and not classes: raise ValueError( "If `include_top` is True, you should specify `classes`. " f"Received: classes={classes}" ) if include_top and pooling: raise ValueError( f"`pooling` must be `None` when `include_top=True`." f"Received pooling={pooling} and include_top={include_top}. " ) # Determine proper input shape img_input = utils.parse_model_inputs(input_shape, input_tensor) x = img_input if include_rescaling: x = layers.Rescaling(scale=1 / 255.0)(x) # Build stem stem_filters = round_filters( filters=blocks_args[0]["input_filters"], width_coefficient=width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, ) x = layers.Conv2D( filters=stem_filters, kernel_size=3, strides=2, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", use_bias=False, name="stem_conv", )(x) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name="stem_bn", )(x) x = layers.Activation(activation, name="stem_activation")(x) # Build blocks blocks_args = copy.deepcopy(blocks_args) b = 0 blocks = float(sum(args["num_repeat"] for args in blocks_args)) for (i, args) in enumerate(blocks_args): assert args["num_repeat"] > 0 # Update block input and output filters based on depth multiplier. args["input_filters"] = round_filters( filters=args["input_filters"], width_coefficient=width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, ) args["output_filters"] = round_filters( filters=args["output_filters"], width_coefficient=width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, ) # Determine which conv type to use: block = {0: MBConvBlock, 1: FusedMBConvBlock}[args.pop("conv_type")] repeats = round_repeats( repeats=args.pop("num_repeat"), depth_coefficient=depth_coefficient ) for j in range(repeats): # The first block needs to take care of stride and filter size # increase. if j > 0: args["strides"] = 1 args["input_filters"] = args["output_filters"] x = block( activation=activation, bn_momentum=bn_momentum, survival_probability=drop_connect_rate * b / blocks, name="block{}{}_".format(i + 1, chr(j + 97)), **args, )(x) b += 1 # Build top top_filters = round_filters( filters=1280, width_coefficient=width_coefficient, min_depth=min_depth, depth_divisor=depth_divisor, ) x = layers.Conv2D( filters=top_filters, kernel_size=1, strides=1, kernel_initializer=CONV_KERNEL_INITIALIZER, padding="same", data_format="channels_last", use_bias=False, name="top_conv", )(x) x = layers.BatchNormalization( axis=BN_AXIS, momentum=bn_momentum, name="top_bn", )(x) x = layers.Activation(activation=activation, name="top_activation")(x) if include_top: x = layers.GlobalAveragePooling2D(name="avg_pool")(x) if dropout_rate > 0: x = layers.Dropout(dropout_rate, name="top_dropout")(x) x = layers.Dense( classes, activation=classifier_activation, kernel_initializer=DENSE_KERNEL_INITIALIZER, bias_initializer=tf.constant_initializer(0), name="predictions", )(x) else: if pooling == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) elif pooling == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) inputs = img_input # Create model. model = tf.keras.Model(inputs, x, **kwargs) # Load weights. if weights is not None: model.load_weights(weights) return model def EfficientNetV2B0( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.0, depth_coefficient=1.0, default_size=224, model_name="efficientnetv2-b0", weights=parse_weights(weights, include_top, "efficientnetv2b0"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2B1( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.0, depth_coefficient=1.1, default_size=240, model_name="efficientnetv2-b1", weights=parse_weights(weights, include_top, "efficientnetv2b1"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2B2( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.1, depth_coefficient=1.2, default_size=260, model_name="efficientnetv2-b2", weights=parse_weights(weights, include_top, "efficientnetv2b2"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2B3( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.2, depth_coefficient=1.4, default_size=300, model_name="efficientnetv2-b3", weights=parse_weights(weights, include_top, "efficientnetv2b3"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2S( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.0, depth_coefficient=1.0, default_size=384, model_name="efficientnetv2-s", weights=parse_weights(weights, include_top, "efficientnetv2s"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2M( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.0, depth_coefficient=1.0, default_size=480, model_name="efficientnetv2-m", weights=parse_weights(weights, include_top, "efficientnetv2m"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) def EfficientNetV2L( include_rescaling, include_top, weights=None, input_shape=(None, None, 3), input_tensor=None, pooling=None, classes=None, classifier_activation="softmax", **kwargs, ): return EfficientNetV2( include_rescaling=include_rescaling, include_top=include_top, width_coefficient=1.0, depth_coefficient=1.0, default_size=480, model_name="efficientnetv2-l", weights=parse_weights(weights, include_top, "efficientnetv2l"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, classes=classes, classifier_activation=classifier_activation, **kwargs, ) EfficientNetV2B0.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B0") EfficientNetV2B1.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B1") EfficientNetV2B2.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B2") EfficientNetV2B3.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2B3") EfficientNetV2S.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2S") EfficientNetV2M.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2M") EfficientNetV2L.__doc__ = BASE_DOCSTRING.format(name="EfficientNetV2L")