Commit b8849274 authored by Fan Yang, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 436852559
parent c166ae02
@@ -13,37 +13,30 @@
 # limitations under the License.
 """Image classification configuration definition."""
-import os
-from typing import List, Optional
 import dataclasses
+import os
+from typing import Optional
 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.core import task_factory
 from official.modeling import hyperparams
 from official.modeling import optimization
-from official.projects.vit.configs import backbones
 from official.vision.configs import common
 from official.vision.configs import image_classification as img_cls_cfg
+from official.projects.vit.configs import backbones
 from official.vision.tasks import image_classification
+# pytype: disable=wrong-keyword-args
 DataConfig = img_cls_cfg.DataConfig
 @dataclasses.dataclass
-class ImageClassificationModel(hyperparams.Config):
+class ImageClassificationModel(img_cls_cfg.ImageClassificationModel):
   """The model config."""
-  num_classes: int = 0
-  input_size: List[int] = dataclasses.field(default_factory=list)
   backbone: backbones.Backbone = backbones.Backbone(
       type='vit', vit=backbones.VisionTransformer())
-  dropout_rate: float = 0.0
-  norm_activation: common.NormActivation = common.NormActivation(
-      use_sync_bn=False)
-  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
-  add_head_batch_norm: bool = False
-  kernel_initializer: str = 'random_uniform'
 @dataclasses.dataclass
...
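The visible part of the hunk replaces a standalone `hyperparams.Config` subclass with one that derives from the base vision classification config, so the ViT variant only overrides the backbone default instead of redeclaring `num_classes`, `input_size`, `dropout_rate`, and the other fields. Below is a minimal, self-contained sketch of that pattern; `BaseImageClassificationModel` and `Backbone` are simplified stand-ins for the real Model Garden classes, and the field defaults are copied from the removed lines for illustration only.

import dataclasses
from typing import List


@dataclasses.dataclass
class Backbone:
  """Stand-in for official.projects.vit.configs.backbones.Backbone."""
  type: str = 'resnet'


@dataclasses.dataclass
class BaseImageClassificationModel:
  """Stand-in for the base config in official.vision.configs.image_classification."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: Backbone = dataclasses.field(default_factory=Backbone)
  dropout_rate: float = 0.0
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'


@dataclasses.dataclass
class VitImageClassificationModel(BaseImageClassificationModel):
  """Overrides only the backbone default; every other field is inherited."""
  backbone: Backbone = dataclasses.field(
      default_factory=lambda: Backbone(type='vit'))


config = VitImageClassificationModel(num_classes=1000, input_size=[224, 224, 3])
print(config.backbone.type)       # 'vit'
print(config.kernel_initializer)  # 'random_uniform', inherited from the base class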