# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

from .intern_image import InternImage
from .intern_image_meta_former import InternImageMetaFormer


def build_model(config):
    """Build the classification model selected by ``config.MODEL.TYPE``.

    Supported types are ``'intern_image'`` (builds ``InternImage``) and
    ``'intern_image_meta_former'`` (builds ``InternImageMetaFormer``). Both
    are constructed from the same set of keyword arguments; see
    ``_intern_image_kwargs``.

    Args:
        config: Configuration object exposing ``MODEL.TYPE`` and, for the
            supported types, the fields read by ``_intern_image_kwargs``.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not a supported
            model type.
    """
    model_type = config.MODEL.TYPE
    # Resolve the model class before reading any other config field, so an
    # unsupported type fails with NotImplementedError even when the
    # INTERN_IMAGE section is absent.
    if model_type == 'intern_image':
        model_cls = InternImage
    elif model_type == 'intern_image_meta_former':
        model_cls = InternImageMetaFormer
    else:
        raise NotImplementedError(f'Unknown model: {model_type}')

    model = model_cls(**_intern_image_kwargs(config))
    return model


def _intern_image_kwargs(config):
    """Collect the constructor kwargs shared by both InternImage variants.

    Reads ``config.MODEL.INTERN_IMAGE.*``, ``config.MODEL.NUM_CLASSES``,
    ``config.MODEL.DROP_PATH_RATE`` and ``config.TRAIN.USE_CHECKPOINT``.

    Returns:
        dict: Keyword arguments for ``InternImage`` /
        ``InternImageMetaFormer``.
    """
    intern_cfg = config.MODEL.INTERN_IMAGE
    return dict(
        core_op=intern_cfg.CORE_OP,
        num_classes=config.MODEL.NUM_CLASSES,
        channels=intern_cfg.CHANNELS,
        depths=intern_cfg.DEPTHS,
        groups=intern_cfg.GROUPS,
        layer_scale=intern_cfg.LAYER_SCALE,
        offset_scale=intern_cfg.OFFSET_SCALE,
        post_norm=intern_cfg.POST_NORM,
        mlp_ratio=intern_cfg.MLP_RATIO,
        with_cp=config.TRAIN.USE_CHECKPOINT,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        res_post_norm=intern_cfg.RES_POST_NORM,  # for InternImage-H/G
        dw_kernel_size=intern_cfg.DW_KERNEL_SIZE,  # for InternImage-H/G
        use_clip_projector=intern_cfg.USE_CLIP_PROJECTOR,  # for InternImage-H/G
        level2_post_norm=intern_cfg.LEVEL2_POST_NORM,  # for InternImage-H/G
        level2_post_norm_block_ids=intern_cfg.LEVEL2_POST_NORM_BLOCK_IDS,  # for InternImage-H/G
        center_feature_scale=intern_cfg.CENTER_FEATURE_SCALE,  # for InternImage-H/G
        remove_center=intern_cfg.REMOVE_CENTER,
    )