"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "326de4191578dfb55cb968880d40d703075e331e"
Commit 5bdee18e authored by Myle Ott, committed by Facebook Github Bot

Iterate on torch.hub interface

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/793

Differential Revision: D15758755

Pulled By: myleott

fbshipit-source-id: b93e4ac11bde36a0b59b4d6d1c84d31c3124d767
parent eea4d20b
@@ -8,6 +8,27 @@ Description | Dataset | Model | Test set(s)
---|---|---|---
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
## Example usage
Interactive generation from the full ensemble via PyTorch Hub:
```
>>> import torch
>>> en2de_ensemble = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer',
... model_name_or_path='transformer.wmt18.en-de',
... checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
... data_name_or_path='.',
... tokenizer='moses',
... aggressive_dash_splits=True,
... bpe='subword_nmt',
... )
>>> len(en2de_ensemble.models)
5
>>> print(en2de_ensemble.generate('Hello world!'))
Hallo Welt!
```
## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
...
@@ -7,8 +7,34 @@ Description | Parameters | Dataset | Model and Test set(s)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2)
## Example usage
Interactive generation via PyTorch Hub:
```
>>> import torch
>>> lm = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer_lm',
... model_name_or_path='transformer_lm.wiki103.adaptive',
... data_name_or_path='./data-bin',
... tokenizer='moses',
... aggressive_dash_splits=True,
... no_escape=True,
... beam=1,
... sampling=True,
... sampling_topk=10,
... temperature=0.8,
... )
>>> lm.generate('Barack Obama', verbose=True)
```
Available models are listed in the ``hub_models()`` method in each model file, for example:
[transformer_lm.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer_lm.py).
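The available entry points and named model archives can also be inspected programmatically (a rough sketch; the exact listing depends on your fairseq and PyTorch versions):
```
>>> import torch
>>> torch.hub.list('pytorch/fairseq')  # entry points defined in hubconf.py
[...]
>>> from fairseq.models.transformer_lm import TransformerLanguageModel
>>> TransformerLanguageModel.hub_models().keys()
dict_keys(['transformer_lm.gbw.adaptive_huge', 'transformer_lm.wiki103.adaptive'])
```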
## Training a new model with the CLI tools
These scripts provide an example of pre-processing data for the Language Modeling task.
### prepare-wikitext-103.sh
@@ -45,10 +71,8 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/transformer_wiki103/checkpoint_best.pt' \
  --sample-break-mode complete --max-tokens 3072 --context-window 2560 --softmax-batch 1024
```
Train a convolutional language model ([Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](conv_lm/README.md)):
```
# If it runs out of memory, try to reduce max-tokens and tokens-per-sample
@@ -63,5 +87,4 @@ $ fairseq-train --task language_modeling data-bin/wikitext-103 \
# Evaluate:
$ fairseq-eval-lm data-bin/wikitext-103 --path 'checkpoints/fconv_wiki103/checkpoint_best.pt'
```
@@ -21,7 +21,6 @@ curl https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz | tar xvz
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
## Example usage
```
...
@@ -11,7 +11,30 @@ Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14
Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) | newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381); WMT'18 winner) | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2) | See NOTE in the archive
## Example usage (torch.hub)
Interactive generation via PyTorch Hub:
```
>>> import torch
>>> en2de = torch.hub.load(
... 'pytorch/fairseq',
... 'transformer',
... model_name_or_path='transformer.wmt16.en-de',
... data_name_or_path='.',
... tokenizer='moses',
... aggressive_dash_splits=True,
... bpe='subword_nmt',
... )
>>> print(en2de.models[0].__class__)
<class 'fairseq.models.transformer.TransformerModel'>
>>> print(en2de.generate('Hello world!'))
Hallo Welt!
```
Available models are listed in the ``hub_models()`` method in each model file, for example:
[transformer.py](https://github.com/pytorch/fairseq/blob/master/fairseq/models/transformer.py).
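If fairseq is installed locally, the same models can be loaded without going through ``torch.hub`` by calling ``from_pretrained`` directly (a sketch; it assumes the downloaded archive provides the default ``model.pt`` checkpoint):
```
>>> from fairseq.models.transformer import TransformerModel
>>> en2de = TransformerModel.from_pretrained(
...     'transformer.wmt16.en-de',
...     checkpoint_file='model.pt',
...     data_name_or_path='.',
...     tokenizer='moses',
...     aggressive_dash_splits=True,
...     bpe='subword_nmt',
... )
>>> print(en2de.generate('Hello world!'))
Hallo Welt!
```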
## Example usage (CLI tools)
Generation with the binarized test sets can be run in batch mode as follows, e.g. for WMT 2014 English-French on a GTX-1080ti:
```
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from fairseq import registry
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY = registry.setup_registry(
'--tokenizer',
default='space',
)
build_bpe, register_bpe, BPE_REGISTRY = registry.setup_registry(
'--bpe',
default=None,
)
# automatically import any Python files in the transforms/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.data.transforms.' + module)
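
# Illustrative sketch (not part of this commit): a user-defined tokenizer can hook
# into the registry above and then be selected with ``--tokenizer lower_space`` or the
# equivalent torch.hub keyword argument. The name and class below are hypothetical.
@register_tokenizer('lower_space')
class LowerSpaceTokenizer(object):

    def __init__(self, args):
        self.args = args

    def encode(self, x: str) -> str:
        # lowercase and collapse whitespace before dictionary lookup
        return ' '.join(x.lower().split())

    def decode(self, x: str) -> str:
        return x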
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import file_utils
from fairseq.data.transforms import register_bpe

# get_encoder is provided by the GPT-2 BPE utilities added in this commit;
# the exact module path below is assumed.
from fairseq.data.transforms.gpt2_bpe_utils import get_encoder
@register_bpe('gpt2')
class GPT2BPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--gpt2-encoder-json', type=str,
default='https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json',
help='path to encoder.json')
parser.add_argument('--gpt2-vocab-bpe', type=str,
default='https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe',
help='path to vocab.bpe')
# fmt: on
def __init__(self, args):
encoder_json = file_utils.cached_path(args.gpt2_encoder_json)
vocab_bpe = file_utils.cached_path(args.gpt2_vocab_bpe)
self.bpe = get_encoder(encoder_json, vocab_bpe)
def encode(self, x: str) -> str:
return ' '.join(map(str, self.bpe.encode(x)))
def decode(self, x: str) -> str:
return self.bpe.decode(map(int, x.split()))
"""Byte pair encoding utilities from GPT-2"""
import os
import json
import regex as re
from functools import lru_cache
@lru_cache()
def bytes_to_unicode():
"""
Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
class Encoder:
def __init__(self, encoder, bpe_merges, errors='replace'):
self.encoder = encoder
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def get_encoder(encoder_json_path, vocab_bpe_path):
with open(encoder_json_path, 'r') as f:
encoder = json.load(f)
with open(vocab_bpe_path, 'r', encoding="utf-8") as f:
bpe_data = f.read()
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
return Encoder(
encoder=encoder,
bpe_merges=bpe_merges,
)
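
# Illustrative round trip (not part of this commit). The encoder.json / vocab.bpe
# paths are placeholders for the downloaded GPT-2 vocabulary files (see the
# --gpt2-encoder-json / --gpt2-vocab-bpe defaults above).
if __name__ == '__main__':
    enc = get_encoder('encoder.json', 'vocab.bpe')
    ids = enc.encode("Hello world!")          # list of integer token ids
    print(ids)
    assert enc.decode(ids) == "Hello world!"  # byte-level BPE is lossless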
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data.transforms import register_tokenizer
@register_tokenizer('moses')
class MosesTokenizer(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('-s', '--source-lang', default='en', metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default='en', metavar='TARGET',
help='target language')
parser.add_argument('--aggressive-dash-splits', action='store_true', default=False,
help='triggers dash split rules')
parser.add_argument('--no-escape', action='store_true', default=False,
help='don\'t perform HTML escaping on apostrophe, quotes, etc.')
# fmt: on
def __init__(self, args):
self.args = args
try:
from sacremoses import MosesTokenizer, MosesDetokenizer
self.tok = MosesTokenizer(args.source_lang)
self.detok = MosesDetokenizer(args.target_lang)
except ImportError:
raise ImportError('Please install Moses tokenizer with: pip install sacremoses')
def encode(self, x: str) -> str:
return self.tok.tokenize(
x,
aggressive_dash_splits=self.args.aggressive_dash_splits,
return_str=True,
escape=(not self.args.no_escape),
)
def decode(self, x: str) -> str:
return self.detok.detokenize(x.split())
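
# Illustrative round trip (not part of this commit); requires `pip install sacremoses`.
# The Namespace fields mirror the options declared in add_args above.
if __name__ == '__main__':
    from argparse import Namespace
    tok = MosesTokenizer(Namespace(
        source_lang='en', target_lang='de',
        aggressive_dash_splits=True, no_escape=True,
    ))
    print(tok.encode('Hello, world!'))   # e.g. "Hello , world !"
    print(tok.decode('Hallo , Welt !'))  # e.g. "Hallo, Welt!"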
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data.transforms import register_tokenizer
@register_tokenizer('nltk')
class NLTKTokenizer(object):
def __init__(self, source_lang=None, target_lang=None):
try:
from nltk.tokenize import word_tokenize
self.word_tokenize = word_tokenize
except ImportError:
raise ImportError('Please install nltk with: pip install nltk')
def encode(self, x: str) -> str:
return ' '.join(self.word_tokenize(x))
def decode(self, x: str) -> str:
return x
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import file_utils
from fairseq.data.transforms import register_bpe
@register_bpe('sentencepiece')
class SentencepieceBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--sentencepiece-vocab', type=str,
help='path to sentencepiece vocab')
# fmt: on
def __init__(self, args):
vocab = file_utils.cached_path(args.sentencepiece_vocab)
try:
import sentencepiece as spm
self.sp = spm.SentencePieceProcessor()
self.sp.Load(vocab)
except ImportError:
raise ImportError('Please install sentencepiece with: pip install sentencepiece')
def encode(self, x: str) -> str:
return ' '.join(self.sp.EncodeAsPieces(x))
def decode(self, x: str) -> str:
return x.replace(' ', '').replace('\u2581', ' ').strip()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import re
from fairseq.data.transforms import register_tokenizer
@register_tokenizer('space')
class SpaceTokenizer(object):
def __init__(self, source_lang=None, target_lang=None):
self.space_tok = re.compile(r"\s+")
def encode(self, x: str) -> str:
return self.space_tok.sub(" ", x).strip()
def decode(self, x: str) -> str:
return x
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import file_utils
from fairseq.data.transforms import register_bpe
@register_bpe('subword_nmt')
class SubwordNMTBPE(object):
@staticmethod
def add_args(parser):
# fmt: off
parser.add_argument('--bpe-codes', type=str,
help='path to subword NMT BPE')
parser.add_argument('--bpe-separator', default='@@',
help='BPE separator')
# fmt: on
def __init__(self, args):
codes = file_utils.cached_path(args.bpe_codes)
try:
from subword_nmt import apply_bpe
bpe_parser = apply_bpe.create_parser()
bpe_args = bpe_parser.parse_args([
'--codes', codes,
'--separator', args.bpe_separator,
])
self.bpe = apply_bpe.BPE(
bpe_args.codes,
bpe_args.merges,
bpe_args.separator,
None,
bpe_args.glossaries,
)
self.bpe_symbol = bpe_args.separator + ' '
except ImportError:
raise ImportError('Please install subword_nmt with: pip install subword-nmt')
def encode(self, x: str) -> str:
return self.bpe.process_line(x)
def decode(self, x: str) -> str:
return (x + ' ').replace(self.bpe_symbol, '').rstrip()
@@ -50,45 +50,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
ARCHIVE_MAP = {
# Pre-trained models
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2',
'conv.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
'conv.wmt14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2',
'conv.wmt17.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2',
'conv.stories': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2',
# Test sets with dictionaries
'data.newstest1213.en-de': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2',
'data.newstest14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2',
'data.newstest14.en-fr.joined': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2',
'data.newstest14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2',
'data.newstest14.en-de.joined': 'https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2',
'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
}
def load_archive_file(name_or_path):
if name_or_path in ARCHIVE_MAP:
archive_file = ARCHIVE_MAP[name_or_path]
else:
archive_file = name_or_path
def load_archive_file(archive_file):
    # redirect to the cache, if necessary
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=None)
    except EnvironmentError:
        print(
            "Archive name '{}' was not found in archive name list. "
            "We assumed '{}' was a path or URL but couldn't find any file "
            "associated to this path or URL.".format(
                archive_file,
                archive_file,
            )
        )
        return None
    if resolved_archive_file == archive_file:
@@ -116,7 +91,7 @@ def load_archive_file(name_or_path):
def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the URL's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
...
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import utils
from fairseq.data import transforms
class Generator(object):
"""PyTorch Hub API for generating sequences from a pre-trained translation
or language model."""
def __init__(self, args, task, models):
self.args = args
self.task = task
self.models = models
self.src_dict = task.source_dictionary
self.tgt_dict = task.target_dictionary
self.use_cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)
# optimize model for generation
for model in self.models:
model.make_generation_fast_(
beamable_mm_beam_size=(
None if getattr(args, 'no_beamable_mm', False)
else getattr(args, 'beam', 5)
),
need_attn=getattr(args, 'print_alignment', False),
)
if self.use_cuda:
if getattr(args, 'fp16', False):
model.half()
model.cuda()
self.generator = self.task.build_generator(args)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(getattr(args, 'replace_unk', None))
self.tokenizer = transforms.build_tokenizer(args)
self.bpe = transforms.build_bpe(args)
def generate(self, src_str, verbose=False):
def preprocess(s):
if self.tokenizer is not None:
s = self.tokenizer.encode(s)
if self.bpe is not None:
s = self.bpe.encode(s)
return s
def postprocess(s):
if self.bpe is not None:
s = self.bpe.decode(s)
if self.tokenizer is not None:
s = self.tokenizer.decode(s)
return s
src_str = preprocess(src_str)
tokens = self.src_dict.encode_line(src_str, add_if_not_exist=False).long()
if verbose:
src_str_with_unk = self.src_dict.string(tokens)
print('S\t{}'.format(src_str_with_unk))
dataset = self.task.build_dataset_for_inference([tokens], [tokens.numel()])
sample = dataset.collater([dataset[0]])
if self.use_cuda:
sample = utils.move_to_cuda(sample)
translations = self.task.inference_step(self.generator, self.models, sample)
# Process top predictions
for hypo in translations[0][:min(len(translations), getattr(self.args, 'nbest', 1))]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=self.align_dict,
tgt_dict=self.tgt_dict,
)
hypo_str = postprocess(hypo_str)
if verbose:
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('P\t{}'.format(
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if getattr(self.args, 'print_alignment', False):
print('A\t{}'.format(
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
return hypo_str
@@ -144,38 +144,65 @@ class BaseFairseqModel(nn.Module):
        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_name_or_path=None, **kwargs):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a :class:`fairseq.hub_utils.Generator`,
        which can be used to generate translations or sample from language
        models. The underlying :class:`~fairseq.models.FairseqModel` can be
        accessed via the *generator.models* attribute.

        Other models may override this to implement custom PyTorch Hub APIs.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import checkpoint_utils, file_utils, hub_utils

        if hasattr(cls, 'hub_models'):
            archive_map = cls.hub_models()
            if model_name_or_path in archive_map:
                model_name_or_path = archive_map[model_name_or_path]
            if data_name_or_path is not None and data_name_or_path in archive_map:
                data_name_or_path = archive_map[data_name_or_path]

        model_path = file_utils.load_archive_file(model_name_or_path)

        # convenience hack for loading data and BPE codes from model archive
        if data_name_or_path is not None:
            if data_name_or_path.startswith('.'):
                kwargs['data'] = os.path.abspath(os.path.join(model_path, data_name_or_path))
            else:
                kwargs['data'] = file_utils.load_archive_file(data_name_or_path)
        for file, arg in {
            'code': 'bpe_codes',
            'bpecodes': 'bpe_codes',
            'sentencepiece.bpe.model': 'sentencepiece_vocab',
        }.items():
            path = os.path.join(model_path, file)
            if os.path.exists(path):
                kwargs[arg] = path

        models, args, task = checkpoint_utils._load_model_ensemble(
            [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')],
            arg_overrides=kwargs,
        )

        print(args)
        return hub_utils.Generator(args, task, models)

    @classmethod
    def hub_models(cls):
        return {}
class FairseqEncoderDecoderModel(BaseFairseqModel):
...
@@ -43,6 +43,14 @@ class FConvModel(FairseqEncoderDecoderModel):
:prog:
"""
@classmethod
def hub_models(cls):
return {
'conv.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2',
'conv.wmt14.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2',
'conv.wmt17.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2',
}
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
...
@@ -31,6 +31,15 @@ from fairseq.modules import (
@register_model('fconv_self_att')
class FConvModelSelfAtt(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
return {
'conv.stories': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2',
# Test set containing dictionaries
'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
}
def __init__(self, encoder, decoder, pretrained_encoder=None):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@@ -85,6 +94,7 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
trained_encoder, trained_decoder = None, None
pretrained = eval(args.pretrained)
if pretrained:
@@ -102,7 +112,6 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
for param in trained_encoder.parameters():
param.requires_grad = False
"""Build a new model instance."""
encoder = FConvEncoder(
task.source_dictionary,
embed_dim=args.encoder_embed_dim,
...
@@ -49,6 +49,14 @@ class TransformerModel(FairseqEncoderDecoderModel):
:prog:
"""
@classmethod
def hub_models(cls):
return {
'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2',
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.bz2',
}
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
...
@@ -26,6 +26,13 @@ DEFAULT_MAX_TARGET_POSITIONS = 1024
@register_model('transformer_lm')
class TransformerLanguageModel(FairseqLanguageModel):
@classmethod
def hub_models(cls):
return {
'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
}
def __init__(self, decoder):
super().__init__(decoder)
...
@@ -183,28 +183,28 @@ class FairseqTask(object):
        return criterions.build_criterion(args, self)
    def build_generator(self, args):
        if getattr(args, 'score_reference', False):
            from fairseq.sequence_scorer import SequenceScorer
            return SequenceScorer(self.target_dictionary)
        else:
            from fairseq.sequence_generator import SequenceGenerator
            return SequenceGenerator(
                self.target_dictionary,
                beam_size=getattr(args, 'beam', 5),
                max_len_a=getattr(args, 'max_len_a', 0),
                max_len_b=getattr(args, 'max_len_b', 200),
                min_len=getattr(args, 'min_len', 1),
                stop_early=(not getattr(args, 'no_early_stop', False)),
                normalize_scores=(not getattr(args, 'unnormalized', False)),
                len_penalty=getattr(args, 'lenpen', 1),
                unk_penalty=getattr(args, 'unkpen', 0),
                sampling=getattr(args, 'sampling', False),
                sampling_topk=getattr(args, 'sampling_topk', -1),
                temperature=getattr(args, 'temperature', 1.),
                diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
                diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),
                match_source_len=getattr(args, 'match_source_len', False),
                no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0),
            )
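
    # Note (illustrative, not part of this file): the getattr defaults above let the
    # torch.hub entry points call build_generator() with a sparse argparse.Namespace
    # that overrides only a few options, e.g.
    #     task.build_generator(Namespace(beam=1, sampling=True, sampling_topk=10, temperature=0.8))
    # rather than requiring every generation flag to be set on args.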
    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
...
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import namedtuple
import html
import os
import torch
from sacremoses import MosesTokenizer, MosesDetokenizer
from subword_nmt import apply_bpe
from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import data_utils
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
class Generator(object):
def __init__(self, task, models, args, src_bpe=None, bpe_symbol='@@ '):
self.task = task
self.models = models
self.src_dict = task.source_dictionary
self.tgt_dict = task.target_dictionary
self.src_bpe = src_bpe
self.use_cuda = torch.cuda.is_available() and not args.cpu
self.args = args
# optimize model for generation
for model in self.models:
model.make_generation_fast_(
beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
need_attn=args.print_alignment,
)
if args.fp16:
model.half()
if self.use_cuda:
model.cuda()
self.generator = self.task.build_generator(args)
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
self.align_dict = utils.load_align_dict(args.replace_unk)
self.max_positions = utils.resolve_max_positions(
self.task.max_positions(),
*[model.max_positions() for model in models]
)
self.in_transforms = []
self.out_transforms = []
if getattr(args, 'moses', False):
tokenizer = MosesTokenizer(lang=args.source_lang or 'en')
detokenizer = MosesDetokenizer(lang=args.target_lang or 'en')
self.in_transforms.append(lambda s: tokenizer.tokenize(s, return_str=True))
self.out_transforms.append(lambda s: detokenizer.detokenize(s.split()))
elif getattr(args, 'nltk', False):
from nltk.tokenize import word_tokenize
self.in_transforms.append(lambda s: ' '.join(word_tokenize(s)))
if getattr(args, 'gpt2_bpe', False):
from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
encoder_json = os.path.join(os.path.dirname(src_bpe), 'encoder.json')
vocab_bpe = src_bpe
encoder = get_encoder(encoder_json, vocab_bpe)
self.in_transforms.append(lambda s: ' '.join(map(str, encoder.encode(s))))
self.out_transforms.append(lambda s: ' '.join(t for t in s.split() if t != '<unk>'))
self.out_transforms.append(lambda s: encoder.decode(map(int, s.strip().split())))
elif getattr(args, 'sentencepiece', False):
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.Load(src_bpe)
self.in_transforms.append(lambda s: ' '.join(sp.EncodeAsPieces(s)))
self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, 'sentencepiece'))
elif src_bpe is not None:
bpe_parser = apply_bpe.create_parser()
bpe_args = bpe_parser.parse_args(['--codes', self.src_bpe])
bpe = apply_bpe.BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator, None, bpe_args.glossaries)
self.in_transforms.append(lambda s: bpe.process_line(s))
self.out_transforms.append(lambda s: data_utils.process_bpe_symbol(s, bpe_symbol))
def generate(self, src_str, verbose=False):
def preprocess(s):
for transform in self.in_transforms:
s = transform(s)
return s
def postprocess(s):
for transform in self.out_transforms:
s = transform(s)
return s
src_str = preprocess(src_str)
for batch in self.make_batches([src_str], self.args, self.task, self.max_positions):
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
if self.use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
sample = {
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
},
}
translations = self.task.inference_step(self.generator, self.models, sample)
src_tokens = utils.strip_pad(src_tokens, self.tgt_dict.pad())
if self.src_dict is not None:
src_str = self.src_dict.string(src_tokens)
src_str = postprocess(src_str)
if verbose:
print('S\t{}'.format(src_str))
# Process top predictions
for hypo in translations[0][:min(len(translations), self.args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
align_dict=self.align_dict,
tgt_dict=self.tgt_dict,
)
hypo_str = postprocess(hypo_str)
if verbose:
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('P\t{}'.format(
' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
))
if self.args.print_alignment:
print('A\t{}'.format(
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
return html.unescape(hypo_str)
@classmethod
def from_pretrained(cls, parser, *args, model_name_or_path, data_name_or_path, checkpoint_file='model.pt', extra_task_args=None, **kwargs):
from fairseq import file_utils
model_path = file_utils.load_archive_file(model_name_or_path)
data_path = file_utils.load_archive_file(data_name_or_path)
checkpoint_path = os.path.join(model_path, checkpoint_file)
task_name = kwargs.get('task', 'translation')
# set data and parse
model_args = options.parse_args_and_arch(
parser,
input_args=[data_path, '--task', task_name] + (extra_task_args or [])
)
# override any kwargs passed in
if kwargs is not None:
for arg_name, arg_val in kwargs.items():
setattr(model_args, arg_name, arg_val)
utils.import_user_module(args)
if model_args.buffer_size < 1:
model_args.buffer_size = 1
if model_args.max_tokens is None and model_args.max_sentences is None:
model_args.max_sentences = 1
assert not model_args.sampling or model_args.nbest == model_args.beam, \
'--sampling requires --nbest to be equal to --beam'
assert not model_args.max_sentences or model_args.max_sentences <= model_args.buffer_size, \
'--max-sentences/--batch-size cannot be larger than --buffer-size'
print(model_args)
task = tasks.setup_task(model_args)
print("loading model checkpoint from {}".format(checkpoint_path))
model, _model_args = checkpoint_utils.load_model_ensemble(
[checkpoint_path],
task=task,
arg_overrides=kwargs,
)
src_bpe = None
for bpe in ['bpecodes', 'vocab.bpe', 'sentencepiece.bpe.model']:
path = os.path.join(model_path, bpe)
if os.path.exists(path):
src_bpe = path
break
return cls(task, model, model_args, src_bpe, kwargs.get('remove_bpe', '@@ '))
def make_batches(self, lines, args, task, max_positions):
tokens = [
task.source_dictionary.encode_line(src_str, add_if_not_exist=False).long()
for src_str in lines
]
lengths = torch.LongTensor([t.numel() for t in tokens])
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(tokens, lengths),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions,
).next_epoch_itr(shuffle=False)
for batch in itr:
yield Batch(
ids=batch['id'],
src_tokens=batch['net_input']['src_tokens'],
src_lengths=batch['net_input']['src_lengths'],
)