model_zoo.py 2.01 KB
Newer Older
Zhang's avatar
v0.4.2  
Zhang committed
1
2
# pylint: disable=wildcard-import, unused-wildcard-import

Hang Zhang's avatar
Hang Zhang committed
3
4
from .backbone import *
from .sseg import *
Hang Zhang's avatar
Hang Zhang committed
5
from .deepten import *
Zhang's avatar
v0.4.2  
Zhang committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# Names exported by ``from <this module> import *``: only the model factory.
__all__ = [
    "get_model",
]

def get_model(name, **kwargs):
    """Return a pre-defined model built by the constructor registered under ``name``.

    The lookup is case-insensitive: ``name`` is lower-cased before it is
    matched against the registry below.

    Parameters
    ----------
    name : str
        Name of the model (for example ``'resnet50'`` or
        ``'encnet_resnet50_ade'``).
    **kwargs
        Forwarded unchanged to the selected constructor. Keywords previously
        documented here (support may vary per constructor -- verify against
        the individual constructor):

        pretrained : bool
            Whether to load the pretrained weights for model.
        root : str, default '~/.encoding/models'
            Location for keeping the model parameters.

    Returns
    -------
    Module
        The object returned by the matching constructor.

    Raises
    ------
    ValueError
        If ``name`` does not match any registered model. The message names
        the rejected value and lists every valid name.
    """
    # Built on each call, so the constructor names are resolved at call time.
    # They are not defined in this file; presumably they come from the
    # wildcard imports of .backbone, .sseg and .deepten at the top of the module.
    models = {
        # resnet
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        # resnest
        'resnest50': resnest50,
        'resnest101': resnest101,
        'resnest200': resnest200,
        'resnest269': resnest269,
        # resnet other variants
        'resnet50s': resnet50s,
        'resnet101s': resnet101s,
        'resnet152s': resnet152s,
        'resnet50d': resnet50d,
        'resnext50_32x4d': resnext50_32x4d,
        'resnext101_32x8d': resnext101_32x8d,
        # other segmentation backbones
        'xception65': xception65,
        'wideresnet38': wideresnet38,
        'wideresnet50': wideresnet50,
        # deepten paper
        'deepten_resnet50_minc': get_deepten_resnet50_minc,
        # segmentation models
        'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
        'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
        'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
        'encnet_resnet50_ade': get_encnet_resnet50_ade,
        'encnet_resnet101_ade': get_encnet_resnet101_ade,
        'fcn_resnet50_ade': get_fcn_resnet50_ade,
        'psp_resnet50_ade': get_psp_resnet50_ade,
        'deeplab_resnest50_ade': get_deeplab_resnest50_ade,
        'deeplab_resnest101_ade': get_deeplab_resnest101_ade,
    }
    name = name.lower()
    if name not in models:
        # Say explicitly that the name is unknown, then list the valid choices
        # (sorted, one per line) so the caller can fix the typo directly.
        available = '\n\t'.join(sorted(models))
        raise ValueError('Unknown model name %r. Available models:\n\t%s'
                         % (name, available))
    return models[name](**kwargs)