# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron.training import get_args
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.legacy.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, SegformerSegmentationHead


class SetrSegmentationModel(MegatronModule):
    """SETR-style semantic segmentation model.

    A ``VitBackbone`` (no class token, no final layer norm, drop-path rate
    0.1) produces hidden states that are fed to a ``SetrSegmentationHead``.

    The constructor requires both ``pre_process`` and ``post_process`` to be
    true. NOTE(review): presumably this means the model cannot be split
    across pipeline stages (``set_input_tensor`` is also a no-op) - confirm.
    """

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        """Build the ViT backbone and the SETR segmentation head.

        Args:
            num_classes: number of output classes passed to the head.
            pre_process: forwarded to ``VitBackbone``; must be true.
            post_process: forwarded to ``VitBackbone``; must be true.

        Raises:
            AssertionError: if ``pre_process`` or ``post_process`` is false.
        """
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        # Logical ``and`` (not bitwise ``&``) so truthy non-bool flags such
        # as 1 and 2 are not rejected by accident.
        assert pre_process and post_process, (
            "SetrSegmentationModel requires pre_process and post_process")
        # Head input width comes from the global args' hidden size.
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )

        self.head = SetrSegmentationHead(
            self.hidden_size,
            self.num_classes
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        # Intentionally a no-op: the given tensor is ignored by this model.
        pass

    def forward(self, input):
        """Run the backbone on ``input`` and return the head's output."""
        # NOTE: ``input`` shadows the builtin; the name is kept for
        # compatibility with keyword callers.
        # Original shape annotation: [b hw c] - presumably the layout of the
        # backbone output (batch, flattened spatial, channels); confirm.
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final


class SegformerSegmentationModel(MegatronModule):
    """SegFormer-style semantic segmentation model.

    An MiT-B5 backbone (``mit_b5()``) produces features that are decoded by a
    ``SegformerSegmentationHead`` configured with feature strides
    [4, 8, 16, 32], input channels [64, 128, 320, 512], embedding dim 768
    and dropout 0.1.

    NOTE(review): the head's ``in_channels`` presumably must match the
    per-stage channel widths of the chosen backbone - confirm if the
    backbone is ever changed.
    """

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        """Build the MiT-B5 backbone and the SegFormer head.

        Args:
            num_classes: stored on the instance; not passed to the head here.
            pre_process: stored on the instance; not otherwise used here.
            post_process: stored on the instance; not otherwise used here.
        """
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        # Stored for attribute access; not used elsewhere in this class.
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process

        self.backbone = mit_b5()
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        # Intentionally a no-op: the given tensor is ignored by this model.
        pass

    def forward(self, input):
        """Run the backbone on ``input`` and return the head's output."""
        # NOTE: ``input`` shadows the builtin; the name is kept for
        # compatibility with keyword callers.
        # Original shape annotation: [b hw c] - its exact meaning here is
        # unverified (the backbone output feeds a multi-stride head); confirm.
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)
        return hidden_states