Commit aa645e71 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 436852559
parent 1e788762
......@@ -13,37 +13,30 @@
# limitations under the License.
"""Image classification configuration definition."""
import os
from typing import List, Optional
import dataclasses
import os
from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.projects.vit.configs import backbones
from official.vision.configs import common
from official.vision.configs import image_classification as img_cls_cfg
from official.projects.vit.configs import backbones
from official.vision.tasks import image_classification
# pytype: disable=wrong-keyword-args
DataConfig = img_cls_cfg.DataConfig
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
class ImageClassificationModel(img_cls_cfg.ImageClassificationModel):
"""The model config."""
num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='vit', vit=backbones.VisionTransformer())
dropout_rate: float = 0.0
norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment