Commit 9c89e882 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Misc improvements to torch hub interface

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/750

Differential Revision: D16410986

Pulled By: myleott

fbshipit-source-id: 8ee6b4371d6ae5b041b00a54a6039a422345795e
parent 62b5498b
......@@ -6,12 +6,57 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import torch
from fairseq import utils
from fairseq.data import encoders
def from_pretrained(
model_name_or_path,
checkpoint_file='model.pt',
data_name_or_path='.',
archive_map=None,
**kwargs,
):
from fairseq import checkpoint_utils, file_utils
if archive_map is not None:
if model_name_or_path in archive_map:
model_name_or_path = archive_map[model_name_or_path]
if data_name_or_path is not None and data_name_or_path in archive_map:
data_name_or_path = archive_map[data_name_or_path]
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
if data_name_or_path.startswith('.'):
kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
else:
kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
for file, arg in {
'code': 'bpe_codes',
'bpecodes': 'bpe_codes',
'sentencepiece.bpe.model': 'sentencepiece_vocab',
}.items():
path = os.path.join(model_path, file)
if os.path.exists(path):
kwargs[arg] = path
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
arg_overrides=kwargs,
)
return {
'args': args,
'task': task,
'models': models,
}
class Generator(object):
"""PyTorch Hub API for generating sequences from a pre-trained translation
or language model."""
......
......@@ -144,7 +144,7 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_prepare_for_onnx_export_)
@classmethod
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs):
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):
"""
Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
file. Downloads and caches the pre-trained model file if needed.
......@@ -165,40 +165,16 @@ class BaseFairseqModel(nn.Module):
at the given path/URL. Can start with '.' or './' to reuse the
model archive path.
"""
from fairseq import checkpoint_utils, file_utils, hub_utils
if hasattr(cls, 'hub_models'):
archive_map = cls.hub_models()
if model_name_or_path in archive_map:
model_name_or_path = archive_map[model_name_or_path]
if data_name_or_path is not None and data_name_or_path in archive_map:
data_name_or_path = archive_map[data_name_or_path]
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
if data_name_or_path is not None:
if data_name_or_path.startswith('.'):
kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
else:
kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
for file, arg in {
'code': 'bpe_codes',
'bpecodes': 'bpe_codes',
'sentencepiece.bpe.model': 'sentencepiece_vocab',
}.items():
path = os.path.join(model_path, file)
if os.path.exists(path):
kwargs[arg] = path
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
arg_overrides=kwargs,
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
**kwargs,
)
print(args)
return hub_utils.Generator(args, task, models)
print(x['args'])
return hub_utils.Generator(x['args'], x['task'], x['models'])
@classmethod
def hub_models(cls):
......
......@@ -22,7 +22,7 @@ def setup_registry(
# maintain a registry of all registries
if registry_name in REGISTRIES:
raise ValueError('Canot setup duplicate registry: {}'.format(registry_name))
return # registry already exists
REGISTRIES[registry_name] = {
'registry': REGISTRY,
'default': default,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment