# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools

from fairseq.hub_utils import BPEHubInterface as bpe  # noqa
from fairseq.hub_utils import TokenizerHubInterface as tokenizer  # noqa
from fairseq.models import MODEL_REGISTRY


# Package names required to load models through this hubconf; torch.hub
# reads the module-level ``dependencies`` list before loading an entrypoint.
dependencies = [
    'numpy',
    'regex',
    'requests',
    'torch',
]


# Register one torch.hub entrypoint per named pre-trained model: each name
# returned by a registered class's ``hub_models()`` becomes a module-level
# callable that forwards to ``from_pretrained(name, ...)``.  To simplify the
# interface, only these named models are exposed; the per-class
# ``from_pretrained`` constructors are not registered under their registry key.
#
# Loop variables are underscore-prefixed so they stay out of the public
# namespace: torch.hub treats callables whose names start with ``_`` as
# helpers, and ``_cls`` / ``_entry`` are callables that would otherwise be
# listed as entrypoints.
for _cls in MODEL_REGISTRY.values():
    for _model_name in _cls.hub_models().keys():
        _entry = functools.partial(_cls.from_pretrained, _model_name)
        # partial objects do not get a __doc__ of their own; copy it so that
        # torch.hub.help() shows from_pretrained's documentation instead of
        # the generic functools.partial docstring.
        _entry.__doc__ = _cls.from_pretrained.__doc__
        globals()[_model_name] = _entry