"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "2cf59c41a50dd042402ffdf1d7c18dc50109d316"
Commit 5b2be870 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Update PyTorch Hub interface

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/782

Differential Revision: D16542256

Pulled By: myleott

fbshipit-source-id: ea3279e7a1ce4687a5914f32b76787c419be1ffa
parent 3e0e5bec
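In practice, the updated interface is used roughly as follows (a minimal sketch; the `transformer.wmt14.en-fr` entry point, the `tokenizer`/`bpe` kwargs, and the `beam` flag are assumptions based on fairseq's named hub models and generation options, not part of this diff):

```python
import torch

# Loading a named model through PyTorch Hub now returns a
# GeneratorHubInterface rather than a bare model.
en2fr = torch.hub.load('pytorch/fairseq', 'transformer.wmt14.en-fr',
                       tokenizer='moses', bpe='subword_nmt')

# The interface is an nn.Module, so device/precision placement is now the
# caller's responsibility (the old use_cuda/fp16 constructor logic is gone).
en2fr.eval()
en2fr.cuda()

# translate() runs tokenize -> BPE -> binarize -> generate -> detokenize.
print(en2fr.translate('Hello world!', beam=5))
```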
@@ -301,30 +301,14 @@ def _upgrade_state_dict(state):
     if not hasattr(state['args'], 'task'):
         state['args'].task = 'translation'
 
-    def set_defaults(cls):
-        if not hasattr(cls, 'add_args'):
-            return
-        parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
-        cls.add_args(parser)
-        # copied from argparse.py:
-        defaults = argparse.Namespace()
-        for action in parser._actions:
-            if action.dest is not argparse.SUPPRESS:
-                if not hasattr(defaults, action.dest):
-                    if action.default is not argparse.SUPPRESS:
-                        setattr(defaults, action.dest, action.default)
-        for key, default_value in vars(defaults).items():
-            if not hasattr(state['args'], key):
-                setattr(state['args'], key, default_value)
-
     # set any missing default values in the task, model or other registries
-    set_defaults(tasks.TASK_REGISTRY[state['args'].task])
-    set_defaults(models.ARCH_MODEL_REGISTRY[state['args'].arch])
+    registry.set_defaults(state['args'], tasks.TASK_REGISTRY[state['args'].task])
+    registry.set_defaults(state['args'], models.ARCH_MODEL_REGISTRY[state['args'].arch])
     for registry_name, REGISTRY in registry.REGISTRIES.items():
         choice = getattr(state['args'], registry_name, None)
         if choice is not None:
             cls = REGISTRY['registry'][choice]
-            set_defaults(cls)
+            registry.set_defaults(state['args'], cls)
 
     return state
@@ -4,13 +4,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
+import copy
 import os
+from typing import List
 
 import torch
+from torch import nn
 
 from fairseq import utils
 from fairseq.data import encoders
-from fairseq.models import BaseFairseqModel
 
 
 def from_pretrained(
@@ -56,22 +59,19 @@ def from_pretrained(
     }
 
 
-class Generator(BaseFairseqModel):
-    """PyTorch Hub API for generating sequences from a pre-trained translation
-    or language model."""
+class GeneratorHubInterface(nn.Module):
+    """
+    PyTorch Hub interface for generating sequences from a pre-trained
+    translation or language model.
+    """
 
     def __init__(self, args, task, models):
+        super().__init__()
         self.args = args
         self.task = task
-        self.models = models
+        self.models = nn.ModuleList(models)
         self.src_dict = task.source_dictionary
         self.tgt_dict = task.target_dictionary
-        self.use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)
-        if self.use_cuda:
-            if getattr(args, 'fp16', False):
-                self.half()
-            self.cuda()
 
         # optimize model for generation
         for model in self.models:
@@ -83,8 +83,6 @@ class Generator(BaseFairseqModel):
             need_attn=getattr(args, 'print_alignment', False),
         )
 
-        self.generator = self.task.build_generator(args)
-
         # Load alignment dictionary for unknown word replacement
         # (None if no unknown word replacement, empty if no path to align dictionary)
         self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))
@@ -92,53 +90,122 @@ class Generator(BaseFairseqModel):
         self.tokenizer = encoders.build_tokenizer(args)
         self.bpe = encoders.build_bpe(args)
 
-    def generate(self, src_str, verbose=False):
+        # this is useful for determining the device
+        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
 
-        def preprocess(s):
-            if self.tokenizer is not None:
-                s = self.tokenizer.encode(s)
-            if self.bpe is not None:
-                s = self.bpe.encode(s)
-            return s
+    @property
+    def device(self):
+        return self._float_tensor.device
 
-        def postprocess(s):
-            if self.bpe is not None:
-                s = self.bpe.decode(s)
-            if self.tokenizer is not None:
-                s = self.tokenizer.decode(s)
-            return s
+    def translate(self, sentence: str, verbose: bool = False, **kwargs) -> str:
+        input = self.encode(sentence)
+        hypo = self.generate(input, verbose, **kwargs)
+        return self.decode(hypo)
 
-        src_str = preprocess(src_str)
-        tokens = self.src_dict.encode_line(src_str, add_if_not_exist=False).long()
-        if verbose:
-            src_str_with_unk = self.src_dict.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
+    def generate(self, tokens: torch.LongTensor, verbose: bool = False, **kwargs) -> torch.LongTensor:
+        sample = self._build_sample(tokens)
 
-        dataset = self.task.build_dataset_for_inference([tokens], [tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        if self.use_cuda:
-            sample = utils.move_to_cuda(sample)
+        # build generator using current args as well as any kwargs
+        gen_args = copy.copy(self.args)
+        for k, v in kwargs.items():
+            setattr(gen_args, k, v)
+        generator = self.task.build_generator(gen_args)
+        translations = self.task.inference_step(generator, self.models, sample)
 
-        translations = self.task.inference_step(self.generator, self.models, sample)
+        if verbose:
+            src_str_with_unk = self.string(tokens)
+            print('S\t{}'.format(src_str_with_unk))
 
         # Process top predictions
         for hypo in translations[0][:min(len(translations), getattr(self.args, 'nbest', 1))]:
-            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
-                hypo_tokens=hypo['tokens'].int().cpu(),
-                src_str=src_str,
-                alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
-                align_dict=self.align_dict,
-                tgt_dict=self.tgt_dict,
-            )
-            hypo_str = postprocess(hypo_str)
+            hypo_str = self.decode(hypo['tokens'])
             if verbose:
                 print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                 print('P\t{}'.format(
                     ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                 ))
-                if getattr(self.args, 'print_alignment', False):
+                if hypo['alignment'] is not None and getattr(self.args, 'print_alignment', False):
                     print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), alignment))
+                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
                     ))
-        return hypo_str
+        return hypo['tokens']
+
+    def encode(self, sentence: str) -> torch.LongTensor:
+        sentence = self.tokenize(sentence)
+        sentence = self.apply_bpe(sentence)
+        return self.binarize(sentence)
+
+    def decode(self, tokens: torch.LongTensor) -> str:
+        sentence = self.string(tokens)
+        sentence = self.remove_bpe(sentence)
+        return self.detokenize(sentence)
+
+    def tokenize(self, sentence: str) -> str:
+        if self.tokenizer is not None:
+            sentence = self.tokenizer.encode(sentence)
+        return sentence
+
+    def detokenize(self, sentence: str) -> str:
+        if self.tokenizer is not None:
+            sentence = self.tokenizer.decode(sentence)
+        return sentence
+
+    def apply_bpe(self, sentence: str) -> str:
+        if self.bpe is not None:
+            sentence = self.bpe.encode(sentence)
+        return sentence
+
+    def remove_bpe(self, sentence: str) -> str:
+        if self.bpe is not None:
+            sentence = self.bpe.decode(sentence)
+        return sentence
+
+    def binarize(self, sentence: str) -> torch.LongTensor:
+        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()
+
+    def string(self, tokens: torch.LongTensor) -> str:
+        return self.tgt_dict.string(tokens)
+
+    def _build_sample(self, src_tokens: torch.LongTensor):
+        assert torch.is_tensor(src_tokens)
+        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
+        sample = dataset.collater([dataset[0]])
+        sample = utils.apply_to_sample(
+            lambda tensor: tensor.to(self.device),
+            sample
+        )
+        return sample
+
+
+class BPEHubInterface(object):
+    """PyTorch Hub interface for Byte-Pair Encoding (BPE)."""
+
+    def __init__(self, bpe, **kwargs):
+        super().__init__()
+        args = argparse.Namespace(bpe=bpe, **kwargs)
+        self.bpe = encoders.build_bpe(args)
+        assert self.bpe is not None
+
+    def encode(self, sentence: str) -> str:
+        return self.bpe.encode(sentence)
+
+    def decode(self, sentence: str) -> str:
+        return self.bpe.decode(sentence)
+
+
+class TokenizerHubInterface(object):
+    """PyTorch Hub interface for tokenization."""
+
+    def __init__(self, tokenizer, **kwargs):
+        super().__init__()
+        args = argparse.Namespace(tokenizer=tokenizer, **kwargs)
+        self.tokenizer = encoders.build_tokenizer(args)
+        assert self.tokenizer is not None
+
+    def encode(self, sentence: str) -> str:
+        return self.tokenizer.encode(sentence)
+
+    def decode(self, sentence: str) -> str:
+        return self.tokenizer.decode(sentence)
@@ -147,12 +147,13 @@ class BaseFairseqModel(nn.Module):
         Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
         file. Downloads and caches the pre-trained model file if needed.
 
-        The base implementation returns a :class:`fairseq.hub_utils.Generator`,
-        which can be used to generate translations or sample from language
-        models. The underlying :class:`~fairseq.models.FairseqModel` can be
-        accessed via the *generator.models* attribute.
+        The base implementation returns a
+        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
+        generate translations or sample from language models. The underlying
+        :class:`~fairseq.models.FairseqModel` can be accessed via the
+        *generator.models* attribute.
 
-        Other models may override this to implement custom PyTorch Hub APIs.
+        Other models may override this to implement custom hub interfaces.
 
         Args:
             model_name_or_path (str): either the name of a pre-trained model to
@@ -172,7 +173,7 @@ class BaseFairseqModel(nn.Module):
             **kwargs,
         )
         print(x['args'])
-        return hub_utils.Generator(x['args'], x['task'], x['models'])
+        return hub_utils.GeneratorHubInterface(x['args'], x['task'], x['models'])
 
     @classmethod
     def hub_models(cls):
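For models that keep this base implementation, `from_pretrained` can also be called directly (a sketch; the checkpoint and data paths are placeholders, and the `checkpoint_file`/`data_name_or_path` kwargs are assumptions about `hub_utils.from_pretrained`'s signature, which this diff does not show):

```python
from fairseq.models.transformer import TransformerModel

# Downloads/caches the model file if needed and wraps the result in a
# GeneratorHubInterface; the raw models remain reachable via .models.
en2de = TransformerModel.from_pretrained(
    'checkpoints/',                        # placeholder model directory
    checkpoint_file='checkpoint_best.pt',  # placeholder checkpoint name
    data_name_or_path='data-bin/',         # placeholder data directory
)
print(en2de.translate('Hello world!'))
```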
@@ -3,6 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
+
 
 REGISTRIES = {}
@@ -35,6 +37,7 @@ def setup_registry(
             builder = getattr(cls, 'build_' + registry_name)
         else:
             builder = cls
+        set_defaults(args, cls)
         return builder(args, *extra_args, **extra_kwargs)
 
     def register_x(name):
@@ -57,3 +60,21 @@ def setup_registry(
         return register_x_cls
 
     return build_x, register_x, REGISTRY
+
+
+def set_defaults(args, cls):
+    """Helper to set default arguments based on *add_args*."""
+    if not hasattr(cls, 'add_args'):
+        return
+    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
+    cls.add_args(parser)
+    # copied from argparse.py:
+    defaults = argparse.Namespace()
+    for action in parser._actions:
+        if action.dest is not argparse.SUPPRESS:
+            if not hasattr(defaults, action.dest):
+                if action.default is not argparse.SUPPRESS:
+                    setattr(defaults, action.dest, action.default)
+    for key, default_value in vars(defaults).items():
+        if not hasattr(args, key):
+            setattr(args, key, default_value)
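This is the same defaults-harvesting logic that previously lived inline in `_upgrade_state_dict` (first hunk above), now shared via the registry module. It only back-fills attributes that are missing from `args` and never overwrites values that are already set; a toy illustration (the `MyModel` class is hypothetical, not part of fairseq):

```python
import argparse

class MyModel:
    @staticmethod
    def add_args(parser):
        parser.add_argument('--dropout', type=float, default=0.1)
        parser.add_argument('--hidden-dim', type=int, default=512)

# e.g. an args namespace restored from an older checkpoint
args = argparse.Namespace(dropout=0.3)
set_defaults(args, MyModel)
assert args.dropout == 0.3     # pre-existing value preserved
assert args.hidden_dim == 512  # missing value back-filled from add_args
```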
@@ -5,6 +5,8 @@
 
 import functools
 
+from fairseq.hub_utils import BPEHubInterface as bpe  # noqa
+from fairseq.hub_utils import TokenizerHubInterface as tokenizer  # noqa
 from fairseq.models import MODEL_REGISTRY
@@ -18,11 +20,11 @@ dependencies = [
 ]
 
 
-for model_type, _cls in MODEL_REGISTRY.items():
+for _model_type, _cls in MODEL_REGISTRY.items():
     for model_name in _cls.hub_models().keys():
         globals()[model_name] = functools.partial(
             _cls.from_pretrained,
             model_name_or_path=model_name,
         )
     # to simplify the interface we only expose named models
-    #globals()[model_type] = _cls.from_pretrained
+    # globals()[_model_type] = _cls.from_pretrained
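With these aliases, the BPE and tokenizer helpers become standalone hub entry points (a sketch; the `subword_nmt` backend name and its `bpe_codes` kwarg are assumptions about the options `encoders.build_bpe` reads, and the path is a placeholder):

```python
import torch

# Constructs a BPEHubInterface; extra kwargs are packed into the
# argparse.Namespace that encoders.build_bpe() reads its options from.
bpe = torch.hub.load('pytorch/fairseq', 'bpe', 'subword_nmt',
                     bpe_codes='/path/to/codes')  # hypothetical kwarg/path
print(bpe.encode('Hello world!'))
```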