Commit 9c89e882 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Misc improvements to torch hub interface

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/750

Differential Revision: D16410986

Pulled By: myleott

fbshipit-source-id: 8ee6b4371d6ae5b041b00a54a6039a422345795e
parent 62b5498b
...@@ -6,12 +6,57 @@ ...@@ -6,12 +6,57 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import os
import torch import torch
from fairseq import utils from fairseq import utils
from fairseq.data import encoders from fairseq.data import encoders
def from_pretrained(
    model_name_or_path,
    checkpoint_file='model.pt',
    data_name_or_path='.',
    archive_map=None,
    **kwargs,
):
    """Load a pre-trained model ensemble, its args and its task from an archive.

    Args:
        model_name_or_path (str): shorthand name resolved through
            *archive_map*, or a path/URL understood by
            ``file_utils.load_archive_file``.
        checkpoint_file (str): checkpoint filename(s) inside the model
            archive; multiple names separated by ':' load an ensemble.
        data_name_or_path (str, optional): where to load data/dictionaries
            from. Values starting with '.' are resolved relative to the
            model archive; ``None`` skips setting the ``data`` override.
        archive_map (dict, optional): mapping of shorthand names to
            paths/URLs (e.g. the result of ``MyModel.hub_models()``).
        **kwargs: extra arg overrides forwarded to the checkpoint loader.

    Returns:
        dict: ``{'args': args, 'task': task, 'models': models}``.
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # convenience hack for loading data and BPE codes from model archive
    # NOTE: guard against None here — the archive_map lookup above explicitly
    # permits data_name_or_path=None, so calling .startswith on it would raise
    if data_name_or_path is not None:
        if data_name_or_path.startswith('.'):
            kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
        else:
            kwargs['data'] = file_utils.load_archive_file(data_name_or_path)

    # auto-detect BPE / sentencepiece artifacts shipped inside the archive
    # (fname shadows nothing; avoid naming the loop variable `file`)
    for fname, arg in {
        'code': 'bpe_codes',
        'bpecodes': 'bpe_codes',
        'sentencepiece.bpe.model': 'sentencepiece_vocab',
    }.items():
        path = os.path.join(model_path, fname)
        if os.path.exists(path):
            kwargs[arg] = path

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
        arg_overrides=kwargs,
    )

    return {
        'args': args,
        'task': task,
        'models': models,
    }
class Generator(object): class Generator(object):
"""PyTorch Hub API for generating sequences from a pre-trained translation """PyTorch Hub API for generating sequences from a pre-trained translation
or language model.""" or language model."""
......
...@@ -144,7 +144,7 @@ class BaseFairseqModel(nn.Module): ...@@ -144,7 +144,7 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_prepare_for_onnx_export_) self.apply(apply_prepare_for_onnx_export_)
@classmethod @classmethod
def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs): def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path='.', **kwargs):
""" """
Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
file. Downloads and caches the pre-trained model file if needed. file. Downloads and caches the pre-trained model file if needed.
...@@ -165,40 +165,16 @@ class BaseFairseqModel(nn.Module): ...@@ -165,40 +165,16 @@ class BaseFairseqModel(nn.Module):
at the given path/URL. Can start with '.' or './' to reuse the at the given path/URL. Can start with '.' or './' to reuse the
model archive path. model archive path.
""" """
from fairseq import checkpoint_utils, file_utils, hub_utils from fairseq import hub_utils
x = hub_utils.from_pretrained(
if hasattr(cls, 'hub_models'): model_name_or_path,
archive_map = cls.hub_models() checkpoint_file,
if model_name_or_path in archive_map: data_name_or_path,
model_name_or_path = archive_map[model_name_or_path] archive_map=cls.hub_models(),
if data_name_or_path is not None and data_name_or_path in archive_map: **kwargs,
data_name_or_path = archive_map[data_name_or_path]
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
if data_name_or_path is not None:
if data_name_or_path.startswith('.'):
kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
else:
kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
for file, arg in {
'code': 'bpe_codes',
'bpecodes': 'bpe_codes',
'sentencepiece.bpe.model': 'sentencepiece_vocab',
}.items():
path = os.path.join(model_path, file)
if os.path.exists(path):
kwargs[arg] = path
models, args, task = checkpoint_utils.load_model_ensemble_and_task(
[os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
arg_overrides=kwargs,
) )
print(x['args'])
print(args) return hub_utils.Generator(x['args'], x['task'], x['models'])
return hub_utils.Generator(args, task, models)
@classmethod @classmethod
def hub_models(cls): def hub_models(cls):
......
...@@ -22,7 +22,7 @@ def setup_registry( ...@@ -22,7 +22,7 @@ def setup_registry(
# maintain a registry of all registries # maintain a registry of all registries
if registry_name in REGISTRIES: if registry_name in REGISTRIES:
    raise ValueError('Cannot setup duplicate registry: {}'.format(registry_name)) return  # registry already exists
REGISTRIES[registry_name] = { REGISTRIES[registry_name] = {
'registry': REGISTRY, 'registry': REGISTRY,
'default': default, 'default': default,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment