Commit 4868c182 authored by Myle Ott, committed by Facebook Github Bot

More generator features for demo (#791)

Summary:
- make it possible to load file_utils.py without the dependencies
- add some more demo features
Pull Request resolved: https://github.com/pytorch/fairseq/pull/791

Differential Revision: D15739950

Pulled By: myleott

fbshipit-source-id: 38df5209973a6fe2e3651575b97134e096aaf5bf
parent a58c1127
@@ -22,10 +22,6 @@ import shutil
 import tarfile
 import tempfile
-import boto3
-from botocore.exceptions import ClientError
-import requests
-from tqdm import tqdm

 try:
     from torch.hub import _get_torch_home
@@ -212,6 +208,7 @@ def s3_request(func):
     @wraps(func)
     def wrapper(url, *args, **kwargs):
+        from botocore.exceptions import ClientError
         try:
             return func(url, *args, **kwargs)
         except ClientError as exc:
@@ -226,6 +223,7 @@ def s3_request(func):
 @s3_request
 def s3_etag(url):
     """Check ETag on S3 object."""
+    import boto3
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)
     s3_object = s3_resource.Object(bucket_name, s3_path)
@@ -235,12 +233,15 @@ def s3_etag(url):
 @s3_request
 def s3_get(url, temp_file):
     """Pull a file directly from S3."""
+    import boto3
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


 def http_get(url, temp_file):
+    import requests
+    from tqdm import tqdm
     req = requests.get(url, stream=True)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None
@@ -270,6 +271,7 @@ def get_from_cache(url, cache_dir=None):
         etag = s3_etag(url)
     else:
         try:
+            import requests
             response = requests.head(url, allow_redirects=True)
             if response.status_code != 200:
                 etag = None
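These hunks defer the boto3, botocore, requests, and tqdm imports from module level into the functions that actually use them, so `fairseq.file_utils` can be imported even when those optional packages are missing. A minimal sketch of the same pattern (the helper names here are illustrative, not fairseq's):

```python
def fetch_url(url):
    # Imported here rather than at module level: `import this_module`
    # now succeeds without `requests`; the cost is paid on first call.
    import requests
    return requests.get(url).content


def fetch_s3(bucket, key, fileobj):
    try:
        import boto3  # only needed for s3:// URLs
    except ImportError:
        raise ImportError('S3 support requires boto3 (pip install boto3)')
    boto3.resource('s3').Bucket(bucket).download_fileobj(key, fileobj)
```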
@@ -147,7 +147,7 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
     return ' '.join(hypo_tokens)


-def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
+def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe=None):
     from fairseq import tokenizer
     hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
     if align_dict is not None:
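With `remove_bpe` now defaulting to `None`, callers that do no BPE post-processing can drop the argument; existing call sites are unaffected because only a default was added. A usage sketch (the hypothesis and dictionary variables are placeholders from a typical generation loop):

```python
from fairseq import utils

# remove_bpe is simply omitted now that it defaults to None;
# hypo['tokens'], src_str and tgt_dict come from a generation loop.
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
    hypo['tokens'].int().cpu(), src_str,
    alignment=None, align_dict=None, tgt_dict=tgt_dict,
)
```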
@@ -7,14 +7,15 @@
 # can be found in the PATENTS file in the same directory.

 from collections import namedtuple
-import torch
 import html
 import os
+import torch

 from sacremoses import MosesTokenizer, MosesDetokenizer
 from subword_nmt import apply_bpe

-from fairseq import checkpoint_utils, options, tasks, utils, file_utils
+from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.data import data_utils


 Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
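`Batch` is a bare namedtuple pairing example ids with padded source tokens and their true lengths; the ids preserve the original input order so results can be reported in that order after length-based sorting. For illustration (all values invented):

```python
import torch
from collections import namedtuple

Batch = namedtuple('Batch', 'ids src_tokens src_lengths')

batch = Batch(
    ids=torch.tensor([0, 1]),                         # original input positions
    src_tokens=torch.tensor([[4, 5, 6], [7, 8, 1]]),  # 1 = pad, illustrative ids
    src_lengths=torch.tensor([3, 2]),                 # lengths before padding
)
```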
@@ -30,8 +31,6 @@ class Generator(object):
         self.use_cuda = torch.cuda.is_available() and not args.cpu
         self.args = args
-        self.args.remove_bpe = bpe_symbol
-

         # optimize model for generation
         for model in self.models:
             model.make_generation_fast_(
@@ -54,23 +53,52 @@ class Generator(object):
             *[model.max_positions() for model in models]
         )

-        if hasattr(args, 'source_lang'):
-            self.tokenizer = MosesTokenizer(lang=args.source_lang)
-        else:
-            self.tokenizer = MosesTokenizer()
-        if src_bpe is not None:
+        self.in_transforms = []
+        self.out_transforms = []
+
+        if getattr(args, 'moses', False):
+            tokenizer = MosesTokenizer(lang=args.source_lang or 'en')
+            detokenizer = MosesDetokenizer(lang=args.target_lang or 'en')
+            self.in_transforms.append(lambda s: tokenizer.tokenize(s, return_str=True))
+            self.out_transforms.append(lambda s: detokenizer.detokenize(s.split()))
+        elif getattr(args, 'nltk', False):
+            from nltk.tokenize import word_tokenize
+            self.in_transforms.append(lambda s: ' '.join(word_tokenize(s)))
+
+        if getattr(args, 'gpt2_bpe', False):
+            from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
+            encoder_json = os.path.join(os.path.dirname(src_bpe), 'encoder.json')
+            vocab_bpe = src_bpe
+            encoder = get_encoder(encoder_json, vocab_bpe)
+            self.in_transforms.append(lambda s: ' '.join(map(str, encoder.encode(s))))
+            self.out_transforms.append(lambda s: ' '.join(t for t in s.split() if t != '<unk>'))
+            self.out_transforms.append(lambda s: encoder.decode(map(int, s.strip().split())))
+        elif getattr(args, 'sentencepiece', False):
+            import sentencepiece as spm
+            sp = spm.SentencePieceProcessor()
+            sp.Load(src_bpe)
+            self.in_transforms.append(lambda s: ' '.join(sp.EncodeAsPieces(s)))
+            self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, 'sentencepiece'))
+        elif src_bpe is not None:
             bpe_parser = apply_bpe.create_parser()
             bpe_args = bpe_parser.parse_args(['--codes', self.src_bpe])
-            self.bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
-        else:
-            self.bpe = None
+            bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
+            self.in_transforms.append(lambda s: bpe.process_line(s))
+            self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, bpe_symbol))

     def generate(self, src_str, verbose=False):
-        src_str = self.tokenizer.tokenize(src_str, return_str=True)
-        if self.bpe:
-            src_str = self.bpe.process_line(src_str)
+        def preprocess(s):
+            for transform in self.in_transforms:
+                s = transform(s)
+            return s
+
+        def postprocess(s):
+            for transform in self.out_transforms:
+                s = transform(s)
+            return s
+
+        src_str = preprocess(src_str)
         for batch in self.make_batches([src_str], self.args, self.task, self.max_positions):
             src_tokens = batch.src_tokens
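This rewrite replaces the single hard-wired Moses-then-BPE path with two lists of `str -> str` transforms: raw input is folded through `in_transforms` before binarization, and detokenized output through `out_transforms` afterwards, so each of the Moses, NLTK, GPT-2 BPE, sentencepiece, and subword-nmt options just appends its piece of the pipeline. A self-contained sketch of the pattern, with toy transforms standing in for the real tokenizers and BPE codecs above:

```python
# Toy stand-ins for the real transforms registered in __init__.
in_transforms = [
    str.lower,                             # stand-in for tokenization
    lambda s: s.replace('ing', 'i@@ ng'),  # stand-in for BPE splitting
]
out_transforms = [
    lambda s: s.replace('@@ ', ''),        # stand-in for BPE removal
]

def preprocess(s):
    for transform in in_transforms:
        s = transform(s)
    return s

def postprocess(s):
    for transform in out_transforms:
        s = transform(s)
    return s

assert preprocess('Talking') == 'talki@@ ng'
assert postprocess(preprocess('Talking')) == 'talking'
```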
@@ -89,7 +117,8 @@ class Generator(object):
                 src_tokens = utils.strip_pad(src_tokens, self.tgt_dict.pad())
                 if self.src_dict is not None:
-                    src_str = self.src_dict.string(src_tokens, self.args.remove_bpe)
+                    src_str = self.src_dict.string(src_tokens)
+                    src_str = postprocess(src_str)
                 if verbose:
                     print('S\t{}'.format(src_str))
@@ -101,8 +130,8 @@ class Generator(object):
                     alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                     align_dict=self.align_dict,
                     tgt_dict=self.tgt_dict,
-                    remove_bpe=self.args.remove_bpe,
                 )
+                hypo_str = postprocess(hypo_str)
                 if verbose:
                     print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                     print('P\t{}'.format(
@@ -116,15 +145,20 @@ class Generator(object):
         return html.unescape(hypo_str)

     @classmethod
-    def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, **kwargs):
+    def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, checkpoint_file='model.pt', extra_task_args=None, **kwargs):
+        from fairseq import file_utils
+
         model_path = file_utils.load_archive_file(model_name_or_path)
         data_path = file_utils.load_archive_file(data_name_or_path)
-        checkpoint_path = os.path.join(model_path, 'model.pt')
+        checkpoint_path = os.path.join(model_path, checkpoint_file)
         task_name = kwargs.get('task', 'translation')

         # set data and parse
-        model_args = options.parse_args_and_arch(parser, input_args=[data_path, '--task', task_name])
+        model_args = options.parse_args_and_arch(
+            parser,
+            input_args=[data_path, '--task', task_name] + (extra_task_args or [])
+        )

         # override any kwargs passed in
         if kwargs is not None:
@@ -148,10 +182,18 @@ class Generator(object):
         task = tasks.setup_task(model_args)

         print("loading model checkpoint from {}".format(checkpoint_path))
-        model, _model_args = checkpoint_utils.load_model_ensemble([checkpoint_path], task=task)
-        src_bpe = os.path.join(model_path, 'bpecodes')
-        if not os.path.exists(src_bpe):
-            src_bpe = None
+        model, _model_args = checkpoint_utils.load_model_ensemble(
+            [checkpoint_path],
+            task=task,
+            arg_overrides=kwargs,
+        )
+
+        src_bpe = None
+        for bpe in ['bpecodes', 'vocab.bpe', 'sentencepiece.bpe.model']:
+            path = os.path.join(model_path, bpe)
+            if os.path.exists(path):
+                src_bpe = path
+                break

         return cls(task, model, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))
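Taken together, `from_pretrained` now extracts the model and data archives, parses the task arguments (plus any `extra_task_args`), loads the checkpoint ensemble with `kwargs` as `arg_overrides`, and probes the archive for whichever BPE artifact it ships (`bpecodes`, `vocab.bpe`, or `sentencepiece.bpe.model`). A hedged usage sketch — the archive names are placeholders, not published files:

```python
from fairseq import options

parser = options.get_generation_parser(interactive=True)
generator = Generator.from_pretrained(
    parser,
    model_name_or_path='my-model.tar.bz2',  # placeholder archive name
    data_name_or_path='my-data.tar.bz2',    # placeholder archive name
    checkpoint_file='model.pt',             # new: pick a file inside the archive
    extra_task_args=['--source-lang', 'en', '--target-lang', 'fr'],
)
print(generator.generate('Hello world!', verbose=True))
```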
@@ -34,9 +34,11 @@ def buffered_read(input, buffer_size):
         yield buffer


-def make_batches(lines, args, task, max_positions):
+def make_batches(lines, args, task, max_positions, encode_fn):
     tokens = [
-        task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
+        task.source_dictionary.encode_line(
+            encode_fn(src_str), add_if_not_exist=False
+        ).long()
         for src_str in lines
     ]
     lengths = torch.LongTensor([t.numel() for t in tokens])
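`make_batches` now threads an `encode_fn` through to `encode_line`, so any string-level encoding (identity, BPE, GPT-2 byte-pair ids) runs immediately before binarization. A minimal sketch of the hook, with the dictionary lookup elided:

```python
def make_batches_sketch(lines, encode_fn):
    # Mirrors the new hook: encode each raw line, then split into symbols
    # (the real code maps symbols to indices via task.source_dictionary).
    return [encode_fn(line).split() for line in lines]

print(make_batches_sketch(['talking heads'], lambda x: x.replace('ing', 'i@@ ng')))
# [['talki@@', 'ng', 'heads']]
```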
@@ -99,6 +101,18 @@ def main(args):
     # Initialize generator
     generator = task.build_generator(args)

+    # Hack to support GPT-2 BPE
+    if args.remove_bpe == 'gpt2':
+        from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
+        decoder = get_encoder(
+            'fairseq/gpt2_bpe/encoder.json',
+            'fairseq/gpt2_bpe/vocab.bpe',
+        )
+        encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))
+    else:
+        decoder = None
+        encode_fn = lambda x: x
+
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)
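Under GPT-2 BPE the model sees space-separated byte-pair ids rather than subword strings, so `encode_fn` turns text into an id string and `decoder.decode` maps ids back to text after generation. A round-trip sketch, assuming the `decoder`/`encode_fn` built above (the ids shown are illustrative):

```python
ids_str = encode_fn('Hello world!')  # e.g. '15496 995 0' (illustrative ids)
text = decoder.decode(map(int, ids_str.strip().split()))
assert text == 'Hello world!'
```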
@@ -114,7 +128,7 @@ def main(args):
     start_id = 0
     for inputs in buffered_read(args.input, args.buffer_size):
         results = []
-        for batch in make_batches(inputs, args, task, max_positions):
+        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
             src_tokens = batch.src_tokens
             src_lengths = batch.src_lengths
             if use_cuda:
@@ -148,6 +162,8 @@ def main(args):
                     tgt_dict=tgt_dict,
                     remove_bpe=args.remove_bpe,
                 )
+                if decoder is not None:
+                    hypo_str = decoder.decode(map(int, hypo_str.strip().split()))
                 print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                 print('P-{}\t{}'.format(
                     id,