# model_zoo.py
# pylint: disable=wildcard-import, unused-wildcard-import

from .fcn import *
from .psp import *
from .encnet import *
from .deeplab import *

# Public API: a star-import of this module exposes only ``get_model``.
__all__ = ["get_model"]


def get_model(name, **kwargs):
    """Returns a pre-defined model by name.

    Parameters
    ----------
    name : str
        Name of the model, matched case-insensitively
        (e.g. ``'encnet_resnet50_ade'``).
    **kwargs
        Forwarded unchanged to the selected model builder. Options
        documented for the builders include:

        pretrained : bool
            Whether to load the pretrained weights for model.
        root : str, default '~/.encoding/models'
            Location for keeping the model parameters.

    Returns
    -------
    Module:
        The model built by the selected builder.

    Raises
    ------
    ValueError
        If ``name`` does not match any registered model. The message lists
        every available model name.
    """
    # Registry of model name -> builder function. The builders are presumably
    # provided by this module's ``from .<model> import *`` imports.
    models = {
        'fcn_resnet50_pcontext': get_fcn_resnet50_pcontext,
        'encnet_resnet50_pcontext': get_encnet_resnet50_pcontext,
        'encnet_resnet101_pcontext': get_encnet_resnet101_pcontext,
        'encnet_resnet50_ade': get_encnet_resnet50_ade,
        'encnet_resnet101_ade': get_encnet_resnet101_ade,
        'fcn_resnet50_ade': get_fcn_resnet50_ade,
        'psp_resnet50_ade': get_psp_resnet50_ade,
        'deeplab_resnet50_ade': get_deeplab_resnet50_ade,
    }
    # Registry keys are all lower-case, so normalise the lookup key.
    name = name.lower()
    if name not in models:
        # Report the offending name together with every valid choice so the
        # caller can correct a typo (the old code referenced an undefined
        # variable here and crashed with NameError instead).
        raise ValueError('Unknown model name "%s". Available models:\n\t%s'
                         % (name, '\n\t'.join(sorted(models))))
    return models[name](**kwargs)