Commit a2aed890 authored by Nathan Ng's avatar Nathan Ng Committed by Facebook Github Bot
Browse files

Torch hub

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/621

Differential Revision: D15571435

Pulled By: myleott

fbshipit-source-id: 67d25b00c8c1bc69dbffd8521da56f7cc14eb75e
parent b35d9bca
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Utilities for working with the local dataset cache.
This file is adapted from `AllenNLP <https://github.com/allenai/allennlp>`_.
and `huggingface <https://github.com/huggingface>`_.
"""
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open
import json
import logging
import os
import shutil
import tarfile
import tempfile
import boto3
from botocore.exceptions import ClientError
import requests
from tqdm import tqdm
try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
except ImportError:
torch_cache_home = os.path.expanduser(
os.getenv('TORCH_HOME', os.path.join(
os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch')))
default_cache_path = os.path.join(torch_cache_home, 'pytorch_fairseq')
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_FAIRSEQ_CACHE = Path(
os.getenv('PYTORCH_FAIRSEQ_CACHE', default_cache_path))
except (AttributeError, ImportError):
PYTORCH_FAIRSEQ_CACHE = os.getenv(
'PYTORCH_FAIRSEQ_CACHE', default_cache_path)
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
ARCHIVE_MAP = {
# Pre-trained models
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2',
'conv.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
'conv.wmt14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2',
'conv.wmt17.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2',
'conv.stories': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2',
# Test sets with dictionaries
'data.newstest1213.en-de': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2',
'data.newstest14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2',
'data.newstest14.en-fr.joined': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2',
'data.newstest14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2',
'data.newstest14.en-de.joined': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2',
'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
}
def load_archive_file(name_or_path):
if name_or_path in ARCHIVE_MAP:
archive_file = ARCHIVE_MAP[name_or_path]
else:
archive_file = name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=None)
except EnvironmentError:
print(
"Archive name '{}' was not found in archive name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
name_or_path,
', '.join(ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
print("loading archive file {}".format(archive_file))
else:
print("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
# Extract archive to temp dir and replace .tar.bz2 if necessary
tempdir = None
if not os.path.isdir(resolved_archive_file):
tempdir = tempfile.mkdtemp()
print("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:bz2') as archive:
top_dir = os.path.commonprefix(archive.getnames())
archive.extractall(tempdir)
os.remove(resolved_archive_file)
shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file)
shutil.rmtree(tempdir)
return resolved_archive_file
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_FAIRSEQ_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
try:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
etag = None
else:
etag = response.headers.get("ETag")
except EnvironmentError:
etag = None
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
output_string = json.dumps(meta)
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename):
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
......@@ -5,6 +5,11 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
import argparse
import importlib
import os
......@@ -40,11 +45,6 @@ __all__ = [
]
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args, task):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
......
......@@ -8,6 +8,7 @@
Base classes for various fairseq models.
"""
import os
from typing import Dict, List, Optional
import torch
......@@ -142,6 +143,40 @@ class BaseFairseqModel(nn.Module):
self.apply(apply_prepare_for_onnx_export_)
@classmethod
def from_pretrained(cls, parser, *inputs, model_name_or_path, data_name_or_path, **kwargs):
"""
Instantiate a FairseqModel from a pre-trained model file or pytorch state dict.
Downloads and caches the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either
- a str with the name of a pre-trained model to load
- a path or url to a pretrained model state dict
"""
from fairseq import checkpoint_utils, file_utils, options, tasks
model_path = file_utils.load_archive_file(model_name_or_path)
data_path = file_utils.load_archive_file(data_name_or_path)
checkpoint_path = os.path.join(model_path, 'model.pt')
# set data and parse
model_args = options.parse_args_and_arch(parser, input_args=[data_path])
# override any kwargs passed in
if kwargs is not None:
for arg_name, arg_val in kwargs.items():
setattr(model_args, arg_name, arg_val)
print(model_args)
task = tasks.setup_task(model_args)
print("loading model checkpoint from {}".format(checkpoint_path))
model, _model_args = checkpoint_utils.load_model_ensemble([checkpoint_path], task=task)
return model[0]
class FairseqEncoderDecoderModel(BaseFairseqModel):
"""Base class for encoder-decoder models.
......@@ -222,7 +257,6 @@ class FairseqModel(FairseqEncoderDecoderModel):
stacklevel=4,
)
class FairseqMultiModel(BaseFairseqModel):
"""Base class for combining multiple encoder-decoder models."""
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import torch
import html
import os
from sacremoses import MosesTokenizer, MosesDetokenizer
from subword_nmt import apply_bpe
from fairseq import checkpoint_utils, options, tasks, utils, file_utils
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
class Generator(object):
def __init__(self, task, models, args, src_bpe=None, bpe_symbol='@@ '):
self.task = task
self.models = models
self.src_dict = task.source_dictionary
self.tgt_dict = task.target_dictionary
self.src_bpe = src_bpe
self.use_cuda = torch.cuda.is_available() and not args.cpu
self.args = args
self.args.remove_bpe = bpe_symbol
# optimize model for generation
for model in self.models:
model.make_generation_fast_(
beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if self.use_cuda:
model.cuda()
self.generator = self.task.build_generator(args)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(args.replace_unk)
self.max_positions = utils.resolve_max_positions(
self.task.max_positions(),
*[model.max_positions() for model in models]
)
if hasattr(args, 'source_lang'):
self.tokenizer = MosesTokenizer(lang=args.source_lang)
else:
self.tokenizer = MosesTokenizer()
if src_bpe is not None:
bpe_parser = apply_bpe.create_parser()
bpe_args = bpe_parser.parse_args(['--codes', self.src_bpe])
self.bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
else:
self.bpe = None
def generate(self, src_str, verbose=False):
src_str = self.tokenizer.tokenize(src_str, return_str=True)
if self.bpe:
src_str = self.bpe.process_line(src_str)
for batch in self.make_batches([src_str], self.args, self.task, self.max_positions):
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
if self.use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
sample = {
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
}
translations = self.task.inference_step(self.generator, self.models, sample)
src_tokens = utils.strip_pad(src_tokens, self.tgt_dict.pad())
if self.src_dict is not None:
src_str = self.src_dict.string(src_tokens, self.args.remove_bpe)
if verbose:
print('S\t{}'.format(src_str))
# Process top predictions
for hypo in translations[0][:min(len(translations), self.args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=self.align_dict,
tgt_dict=self.tgt_dict,
remove_bpe=self.args.remove_bpe,
)
if verbose:
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('P\t{}'.format(
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if self.args.print_alignment:
print('A\t{}'.format(
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
return html.unescape(hypo_str)
@classmethod
def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, **kwargs):
model_path = file_utils.load_archive_file(model_name_or_path)
data_path = file_utils.load_archive_file(data_name_or_path)
checkpoint_path = os.path.join(model_path, 'model.pt')
task_name = kwargs.get('task', 'translation')
# set data and parse
model_args = options.parse_args_and_arch(parser, input_args=[data_path, '--task', task_name])
# override any kwargs passed in
if kwargs is not None:
for arg_name, arg_val in kwargs.items():
setattr(model_args, arg_name, arg_val)
utils.import_user_module(args)
if model_args.buffer_size < 1:
model_args.buffer_size = 1
if model_args.max_tokens is None and model_args.max_sentences is None:
model_args.max_sentences = 1
assert not model_args.sampling or model_args.nbest == model_args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not model_args.max_sentences or model_args.max_sentences <= model_args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
print(model_args)
task = tasks.setup_task(model_args)
print("loading model checkpoint from {}".format(checkpoint_path))
model, _model_args = checkpoint_utils.load_model_ensemble([checkpoint_path], task=task)
src_bpe = os.path.join(model_path, 'bpecodes')
if not os.path.exists(src_bpe):
src_bpe = None
return cls(task, model, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))
def make_batches(self, lines, args, task, max_positions):
tokens = [
task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
for src_str in lines
]
lengths = torch.LongTensor([t.numel() for t in tokens])
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
ids=batch['id'],
src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'],
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.models.transformer import TransformerModel
from fairseq.models.fconv import FConvModel
from fairseq.models.fconv_self_att import FConvModelSelfAtt
from generator import Generator
from fairseq import options
dependencies = [
'torch',
'sacremoses',
'subword_nmt',
]
def transformer(*args, **kwargs):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
"""
parser = options.get_interactive_generation_parser()
model = TransformerModel.from_pretrained(parser, *args, **kwargs)
return model
def fconv(*args, **kwargs):
"""
A fully convolutional model, i.e. a convolutional encoder and a
convolutional decoder, as described in `"Convolutional Sequence to Sequence
Learning" (Gehring et al., 2017) <https://arxiv.org/abs/1705.03122>`_.
"""
parser = options.get_interactive_generation_parser()
model = FConvModel.from_pretrained(parser, *args, **kwargs)
return model
def fconv_self_att(*args, **kwargs):
parser = options.get_interactive_generation_parser()
model = FConvModelSelfAtt.from_pretrained(parser, *args, **kwargs)
return model
def generator(*args, **kwargs):
parser = options.get_generation_parser(interactive=True)
generator = Generator.from_pretrained(parser, *args, **kwargs)
return generator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment