'''
This is an example hubconf.py for the pytorch/vision repo.

## Users can get this published model by calling:
hub_model = hub.load(
    'pytorch/vision:master', # repo_owner/repo_name:branch
    'resnet18', # entrypoint
    1234, # args for callable [not applicable to resnet]
    pretrained=True) # kwargs for callable

## Protocol on repo owner side
1. Published models must live on a branch or tag. They cannot point at an arbitrary commit.
2. The repo owner should define the following in hubconf.py:
  2.1 A function/entrypoint with the signature "def resnet18(pretrained=False, *args, **kwargs):".
  2.2 The pretrained argument lets users load pretrained weights provided by the repo owner.
  2.3 Args and kwargs are passed through to the underlying callable, here _resnet18.
  2.4 The function's docstring serves as a help message, explaining what the model does and which
      arguments are allowed.
  2.5 dependencies is an optional list, provided by the repo owner, of the packages required
      to run the model.
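
The protocol above can be sketched without torch. This is only an illustration, not the hub
implementation: make_hubconf, tiny_model and load_entrypoint are hypothetical names, and the
"model" is a plain dict. It shows an entrypoint with the recommended signature, a dependency
check against the optional dependencies list, and kwargs being forwarded to the entrypoint.

```python
import importlib.util
import types


def make_hubconf():
    # Hypothetical stand-in for a hubconf.py module; no real model is built.
    hubconf = types.ModuleType("hubconf")
    hubconf.dependencies = ["math"]  # optional list of required packages

    def tiny_model(pretrained=False, *args, **kwargs):
        # Toy entrypoint: returns a dict describing the call instead of a model.
        return {"pretrained": pretrained, "args": args, "kwargs": kwargs}

    hubconf.tiny_model = tiny_model
    return hubconf


def load_entrypoint(hubconf, name, *args, **kwargs):
    # Assumed behavior: verify declared dependencies, then call the entrypoint.
    declared = getattr(hubconf, "dependencies", [])
    missing = [d for d in declared if importlib.util.find_spec(d) is None]
    if missing:
        raise RuntimeError("Missing dependencies: " + ", ".join(missing))
    entry = getattr(hubconf, name, None)
    if not callable(entry):
        raise RuntimeError("Cannot find callable " + name + " in hubconf")
    return entry(*args, **kwargs)


model = load_entrypoint(make_hubconf(), "tiny_model", pretrained=True, num_classes=10)
```

Here model records pretrained=True and num_classes=10, mirroring how kwargs reach the callable.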

## Hub_dir

hub_dir specifies where intermediate files and folders are saved. By default this is ~/.torch/hub.
Users can change it either by setting the environment variable TORCH_HUB_DIR or by calling hub.set_dir(PATH_TO_HUB_DIR).
By default, files are not cleaned up after loading, so that the cache can be reused next time.
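
A minimal sketch of how hub_dir could be resolved. resolve_hub_dir is a hypothetical helper, not
part of the hub API, and the precedence shown (explicit path, then TORCH_HUB_DIR, then the
default) is an assumption; the text above does not say which setting wins.

```python
import os

DEFAULT_HUB_DIR = "~/.torch/hub"


def resolve_hub_dir(explicit=None, environ=None):
    # explicit plays the role of a path passed to hub.set_dir (assumed to win).
    environ = os.environ if environ is None else environ
    if explicit is not None:
        chosen = explicit
    else:
        # Fall back to the environment variable, then to the default location.
        chosen = environ.get("TORCH_HUB_DIR", DEFAULT_HUB_DIR)
    return os.path.expanduser(chosen)


default_dir = resolve_hub_dir(environ={})
env_dir = resolve_hub_dir(environ={"TORCH_HUB_DIR": "/tmp/hub"})
explicit_dir = resolve_hub_dir("/data/hub", {"TORCH_HUB_DIR": "/tmp/hub"})
```

Passing environ as a plain dict keeps the sketch testable without touching the real environment.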

## Cache logic

The cache is used by default if it exists in hub_dir.
Users can force a fresh reload by calling hub.load(..., force_reload=True).
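
The cache rule above can be sketched as follows. ensure_repo and fake_fetch are hypothetical
names, and the one-folder-per-repo layout under hub_dir is an assumption made for illustration;
the real download step is replaced by a stub that only creates a folder.

```python
import os
import shutil
import tempfile


def ensure_repo(hub_dir, repo_name, fetch, force_reload=False):
    repo_dir = os.path.join(hub_dir, repo_name)
    if force_reload and os.path.isdir(repo_dir):
        shutil.rmtree(repo_dir)  # discard the cached copy
    if not os.path.isdir(repo_dir):
        fetch(repo_dir)  # cache miss: fetch into repo_dir
    return repo_dir


calls = []


def fake_fetch(repo_dir):
    # Stand-in for a real download: create the folder and record the call.
    os.makedirs(repo_dir)
    calls.append(repo_dir)


hub_dir = tempfile.mkdtemp()
ensure_repo(hub_dir, "vision_master", fake_fetch)                     # fetches
ensure_repo(hub_dir, "vision_master", fake_fetch)                     # cache hit
ensure_repo(hub_dir, "vision_master", fake_fetch, force_reload=True)  # fetches again
fetch_count = len(calls)
shutil.rmtree(hub_dir)
```

Only the first call and the force_reload call reach fake_fetch; the second call reuses the cache.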
'''

import torch.utils.model_zoo as model_zoo

# Optional list of dependencies required by the package
dependencies = ['torch', 'math']


def resnet18(pretrained=False, *args, **kwargs):
    """
    ResNet-18 model.
    pretrained (bool): a recommended kwarg for all entrypoints; if True, loads pretrained weights
    args & kwargs are forwarded to torchvision's resnet18 constructor
    """
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model


def resnet50(pretrained=False, *args, **kwargs):
    """
    ResNet-50 model.
    pretrained (bool): a recommended kwarg for all entrypoints; if True, loads pretrained weights
    args & kwargs are forwarded to torchvision's resnet50 constructor
    """
    from torchvision.models.resnet import resnet50 as _resnet50
    model = _resnet50(*args, **kwargs)
    checkpoint = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
    if pretrained:
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model