example_net.py 1010 Bytes
Newer Older
limm's avatar
limm committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from mmpretrain.models import ResNet
from mmpretrain.registry import MODELS


# Register the model in the `MODELS` registry.
@MODELS.register_module()
class ExampleNet(ResNet):
    """A template backbone that reuses :class:`ResNet` as-is.

    The decorator above registers this class in the ``MODELS`` registry.
    Constructing it prints a greeting banner. Everything else is delegated
    to the ``ResNet`` parent class. To build your own backbone, replace the
    bodies of ``__init__`` and ``forward``, writing them as you would for
    any ordinary PyTorch module.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to
            ``ResNet.__init__``.
    """

    def __init__(self, **kwargs) -> None:
        # Announce construction with a banner before the parent builds its
        # layers, so it is easy to see that this custom class is in use.
        border = '#############################'
        greeting = '#     Hello MMPretrain!     #'
        print('\n'.join((border, greeting, border)))
        super().__init__(**kwargs)

    def forward(self, x):
        """Pass an image batch through the inherited ResNet computation.

        Args:
            x (torch.Tensor): Image batch with shape
                ``(batch_size, num_channels, height, width)``.

        Returns:
            Tuple[torch.Tensor]: The feature maps returned by
            ``ResNet.forward``, each one at a specified scale. A custom
            implementation should keep returning a tuple, even when it
            contains only the final feature map.
        """
        feature_maps = super().forward(x)
        return feature_maps