Commit 4868c182 authored by Myle Ott, committed by Facebook Github Bot

More generator features for demo (#791)

Summary:
- make it possible to import file_utils.py without its optional dependencies (boto3, requests, tqdm) installed
- add more demo generator features (Moses/NLTK tokenization, GPT-2 BPE and sentencepiece support, configurable checkpoint file)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/791

Differential Revision: D15739950

Pulled By: myleott

fbshipit-source-id: 38df5209973a6fe2e3651575b97134e096aaf5bf
parent a58c1127
@@ -22,10 +22,6 @@ import shutil
 import tarfile
 import tempfile
 
-import boto3
-from botocore.exceptions import ClientError
-import requests
-from tqdm import tqdm
 
 try:
     from torch.hub import _get_torch_home
@@ -212,6 +208,7 @@ def s3_request(func):
 
     @wraps(func)
     def wrapper(url, *args, **kwargs):
+        from botocore.exceptions import ClientError
         try:
             return func(url, *args, **kwargs)
         except ClientError as exc:
@@ -226,6 +223,7 @@ def s3_request(func):
 @s3_request
 def s3_etag(url):
     """Check ETag on S3 object."""
+    import boto3
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)
     s3_object = s3_resource.Object(bucket_name, s3_path)
@@ -235,12 +233,15 @@ def s3_etag(url):
 @s3_request
 def s3_get(url, temp_file):
     """Pull a file directly from S3."""
+    import boto3
     s3_resource = boto3.resource("s3")
     bucket_name, s3_path = split_s3_path(url)
     s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
 
 
 def http_get(url, temp_file):
+    import requests
+    from tqdm import tqdm
     req = requests.get(url, stream=True)
     content_length = req.headers.get('Content-Length')
     total = int(content_length) if content_length is not None else None
@@ -270,6 +271,7 @@ def get_from_cache(url, cache_dir=None):
         etag = s3_etag(url)
     else:
         try:
+            import requests
             response = requests.head(url, allow_redirects=True)
             if response.status_code != 200:
                 etag = None
......
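These hunks move boto3, botocore, requests, and tqdm from module scope into the functions that use them, so file_utils.py can be imported even when those packages are not installed; an ImportError only surfaces if an S3 or HTTP download is actually attempted. A minimal, self-contained sketch of the same deferred-import pattern (the function and names below are illustrative, not part of fairseq):

# Deferred-import sketch: `requests` is only required when the function is
# actually called, so importing this module succeeds without it installed.

def http_download(url, dest_path):
    """Stream `url` into `dest_path`; needs `requests` only at call time."""
    import requests  # ImportError is raised here, not at module import
    response = requests.get(url, stream=True)
    response.raise_for_status()
    with open(dest_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)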
@@ -147,7 +147,7 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
     return ' '.join(hypo_tokens)
 
 
-def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
+def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe=None):
     from fairseq import tokenizer
     hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
     if align_dict is not None:
......
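Giving remove_bpe a default of None keeps every existing call site of post_process_prediction working while letting the demo omit BPE handling entirely. A toy stand-in (not the fairseq function itself) showing the compatibility point:

# Toy stand-in for the signature change: the new trailing default means old
# callers that always passed remove_bpe keep working, and new callers may
# simply omit it.

def post_process(tokens, remove_bpe=None):
    text = ' '.join(tokens)
    if remove_bpe is not None:
        text = text.replace(remove_bpe, '')
    return text

print(post_process(['hel@@', 'lo']))         # 'hel@@ lo' (BPE left intact)
print(post_process(['hel@@', 'lo'], '@@ '))  # 'hello'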
@@ -7,14 +7,15 @@
 # can be found in the PATENTS file in the same directory.
 
 from collections import namedtuple
-import torch
 import html
 import os
 
+import torch
 from sacremoses import MosesTokenizer, MosesDetokenizer
 from subword_nmt import apply_bpe
 
-from fairseq import checkpoint_utils, options, tasks, utils, file_utils
+from fairseq import checkpoint_utils, options, tasks, utils
+from fairseq.data import data_utils
 
 
 Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
@@ -30,8 +31,6 @@ class Generator(object):
         self.use_cuda = torch.cuda.is_available() and not args.cpu
         self.args = args
 
-        self.args.remove_bpe = bpe_symbol
-
         # optimize model for generation
         for model in self.models:
             model.make_generation_fast_(
@@ -54,23 +53,52 @@ class Generator(object):
             *[model.max_positions() for model in models]
         )
 
-        if hasattr(args, 'source_lang'):
-            self.tokenizer = MosesTokenizer(lang=args.source_lang)
-        else:
-            self.tokenizer = MosesTokenizer()
-
-        if src_bpe is not None:
+        self.in_transforms = []
+        self.out_transforms = []
+
+        if getattr(args, 'moses', False):
+            tokenizer = MosesTokenizer(lang=args.source_lang or 'en')
+            detokenizer = MosesDetokenizer(lang=args.target_lang or 'en')
+            self.in_transforms.append(lambda s: tokenizer.tokenize(s, return_str=True))
+            self.out_transforms.append(lambda s: detokenizer.detokenize(s.split()))
+        elif getattr(args, 'nltk', False):
+            from nltk.tokenize import word_tokenize
+            self.in_transforms.append(lambda s: ' '.join(word_tokenize(s)))
+
+        if getattr(args, 'gpt2_bpe', False):
+            from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
+            encoder_json = os.path.join(os.path.dirname(src_bpe), 'encoder.json')
+            vocab_bpe = src_bpe
+            encoder = get_encoder(encoder_json, vocab_bpe)
+            self.in_transforms.append(lambda s: ' '.join(map(str, encoder.encode(s))))
+            self.out_transforms.append(lambda s: ' '.join(t for t in s.split() if t != '<unk>'))
+            self.out_transforms.append(lambda s: encoder.decode(map(int, s.strip().split())))
+        elif getattr(args, 'sentencepiece', False):
+            import sentencepiece as spm
+            sp = spm.SentencePieceProcessor()
+            sp.Load(src_bpe)
+            self.in_transforms.append(lambda s: ' '.join(sp.EncodeAsPieces(s)))
+            self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, 'sentencepiece'))
+        elif src_bpe is not None:
             bpe_parser = apply_bpe.create_parser()
             bpe_args = bpe_parser.parse_args(['--codes', self.src_bpe])
-            self.bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
-        else:
-            self.bpe = None
+            bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
+            self.in_transforms.append(lambda s: bpe.process_line(s))
+            self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, bpe_symbol))
 
     def generate(self, src_str, verbose=False):
-        src_str = self.tokenizer.tokenize(src_str, return_str=True)
-        if self.bpe:
-            src_str = self.bpe.process_line(src_str)
+        def preprocess(s):
+            for transform in self.in_transforms:
+                s = transform(s)
+            return s
+
+        def postprocess(s):
+            for transform in self.out_transforms:
+                s = transform(s)
+            return s
+
+        src_str = preprocess(src_str)
         for batch in self.make_batches([src_str], self.args, self.task, self.max_positions):
             src_tokens = batch.src_tokens
@@ -89,7 +117,8 @@ class Generator(object):
                 src_tokens = utils.strip_pad(src_tokens, self.tgt_dict.pad())
                 if self.src_dict is not None:
-                    src_str = self.src_dict.string(src_tokens, self.args.remove_bpe)
+                    src_str = self.src_dict.string(src_tokens)
+                    src_str = postprocess(src_str)
                 if verbose:
                     print('S\t{}'.format(src_str))
@@ -101,8 +130,8 @@ class Generator(object):
                     alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                     align_dict=self.align_dict,
                     tgt_dict=self.tgt_dict,
-                    remove_bpe=self.args.remove_bpe,
                 )
+                hypo_str = postprocess(hypo_str)
                 if verbose:
                     print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                     print('P\t{}'.format(
@@ -116,15 +145,20 @@ class Generator(object):
         return html.unescape(hypo_str)
 
     @classmethod
-    def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, **kwargs):
+    def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, checkpoint_file='model.pt', extra_task_args=None, **kwargs):
+        from fairseq import file_utils
         model_path = file_utils.load_archive_file(model_name_or_path)
         data_path = file_utils.load_archive_file(data_name_or_path)
-        checkpoint_path = os.path.join(model_path, 'model.pt')
+        checkpoint_path = os.path.join(model_path, checkpoint_file)
 
         task_name = kwargs.get('task', 'translation')
 
         # set data and parse
-        model_args = options.parse_args_and_arch(parser, input_args=[data_path, '--task', task_name])
+        model_args = options.parse_args_and_arch(
+            parser,
+            input_args=[data_path, '--task', task_name] + (extra_task_args or [])
+        )
 
         # override any kwargs passed in
         if kwargs is not None:
@@ -148,10 +182,18 @@ class Generator(object):
         task = tasks.setup_task(model_args)
 
         print("loading model checkpoint from {}".format(checkpoint_path))
-        model, _model_args = checkpoint_utils.load_model_ensemble([checkpoint_path], task=task)
+        model, _model_args = checkpoint_utils.load_model_ensemble(
+            [checkpoint_path],
+            task=task,
+            arg_overrides=kwargs,
+        )
 
-        src_bpe = os.path.join(model_path, 'bpecodes')
-        if not os.path.exists(src_bpe):
-            src_bpe = None
+        src_bpe = None
+        for bpe in ['bpecodes', 'vocab.bpe', 'sentencepiece.bpe.model']:
+            path = os.path.join(model_path, bpe)
+            if os.path.exists(path):
+                src_bpe = path
+                break
 
         return cls(task, model, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))
......
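The rewritten Generator replaces its hard-coded tokenizer/BPE attributes with two ordered lists of string transforms: every input line runs through in_transforms before batching, and every source or hypothesis string runs through out_transforms before printing. A self-contained sketch of that composition pattern, with toy transforms standing in for the real Moses and BPE steps:

# Toy version of the in_transforms/out_transforms pipeline: transforms are
# plain str -> str callables applied in registration order.

in_transforms = [
    lambda s: s.lower(),                        # stand-in for Moses tokenization
    lambda s: s.replace('hello', 'hel@@ lo'),   # stand-in for BPE encoding
]
out_transforms = [
    lambda s: s.replace('@@ ', ''),             # stand-in for BPE removal
]

def preprocess(s):
    for transform in in_transforms:
        s = transform(s)
    return s

def postprocess(s):
    for transform in out_transforms:
        s = transform(s)
    return s

assert preprocess('Hello world') == 'hel@@ lo world'
assert postprocess(preprocess('Hello world')) == 'hello world'

The from_pretrained changes are orthogonal: checkpoint_file and extra_task_args let callers pick a non-default checkpoint and pass extra task flags, and the BPE resource is now auto-detected from whichever of bpecodes, vocab.bpe, or sentencepiece.bpe.model is present in the model archive.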
@@ -34,9 +34,11 @@ def buffered_read(input, buffer_size):
         yield buffer
 
 
-def make_batches(lines, args, task, max_positions):
+def make_batches(lines, args, task, max_positions, encode_fn):
     tokens = [
-        task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
+        task.source_dictionary.encode_line(
+            encode_fn(src_str), add_if_not_exist=False
+        ).long()
         for src_str in lines
     ]
     lengths = torch.LongTensor([t.numel() for t in tokens])
@@ -99,6 +101,18 @@ def main(args):
     # Initialize generator
     generator = task.build_generator(args)
 
+    # Hack to support GPT-2 BPE
+    if args.remove_bpe == 'gpt2':
+        from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
+        decoder = get_encoder(
+            'fairseq/gpt2_bpe/encoder.json',
+            'fairseq/gpt2_bpe/vocab.bpe',
+        )
+        encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))
+    else:
+        decoder = None
+        encode_fn = lambda x: x
+
     # Load alignment dictionary for unknown word replacement
     # (None if no unknown word replacement, empty if no path to align dictionary)
     align_dict = utils.load_align_dict(args.replace_unk)
@@ -114,7 +128,7 @@ def main(args):
     start_id = 0
     for inputs in buffered_read(args.input, args.buffer_size):
         results = []
-        for batch in make_batches(inputs, args, task, max_positions):
+        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
             src_tokens = batch.src_tokens
             src_lengths = batch.src_lengths
             if use_cuda:
@@ -148,6 +162,8 @@ def main(args):
                 tgt_dict=tgt_dict,
                 remove_bpe=args.remove_bpe,
             )
+            if decoder is not None:
+                hypo_str = decoder.decode(map(int, hypo_str.strip().split()))
             print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
             print('P-{}\t{}'.format(
                 id,
......
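The interactive-script hook relies on a simple contract: encode_fn turns raw text into a space-separated string of integer token ids before dictionary encoding, and decoder.decode inverts that on each hypothesis string. A runnable sketch of the round trip using a toy encoder (the real one is loaded via get_encoder from fairseq.gpt2_bpe.gpt2_encoding):

# Toy encoder demonstrating the encode_fn / decoder.decode round trip used
# above; ids are assigned per whitespace token so the example runs anywhere.

class ToyEncoder:
    def __init__(self):
        self.ids, self.tokens = {}, {}

    def encode(self, text):
        out = []
        for tok in text.split():
            if tok not in self.ids:
                self.ids[tok] = len(self.ids)
                self.tokens[self.ids[tok]] = tok
            out.append(self.ids[tok])
        return out

    def decode(self, ids):
        return ' '.join(self.tokens[i] for i in ids)

decoder = ToyEncoder()
encode_fn = lambda x: ' '.join(map(str, decoder.encode(x)))

hypo_str = encode_fn('the quick fox')                          # e.g. '0 1 2'
restored = decoder.decode(map(int, hypo_str.strip().split()))
assert restored == 'the quick fox'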