Commit cb6c67bc authored by Myle Ott, committed by Facebook GitHub Bot

Make torch.hub interface automatically apply tokenization and BPE

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/926

Differential Revision: D18685772

Pulled By: myleott

fbshipit-source-id: 0f99d79ed6ee72e9d3ced786d75ab9504d0dfcf0
parent fb3e1e36
...@@ -55,8 +55,15 @@ Fairseq provides reference implementations of various sequence-to-sequence model
- mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
with a convenient `torch.hub` interface:
```python
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
en2de.translate('Hello world', beam=5)
# 'Hallo Welt'
```
See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
![Model](fairseq.gif)
...
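Per-model tokenizer and BPE defaults come from each architecture's `hub_models()` registry (see the diffs below) and can still be overridden explicitly; a minimal sketch, assuming `torch.hub.list` and the keyword-argument pass-through added to `hub_utils.from_pretrained`:

```python
import torch

# list all entry points exposed by fairseq's hubconf
print(torch.hub.list('pytorch/fairseq'))

# explicit tokenizer/bpe kwargs take precedence over the archive defaults
en2de = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
    tokenizer='moses', bpe='fastbpe',
)
print(en2de.translate('Hello world', beam=5))  # 'Hallo Welt'
```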
...@@ -6,7 +6,7 @@ The following commands provide an example of pre-processing data, training a mod
Description | Dataset | Model | Test set(s)
---|---|---|---
Stories with Convolutional Model <br> ([Fan et al., 2018](https://arxiv.org/abs/1805.04833)) | [WritingPrompts](https://dl.fbaipublicfiles.com/fairseq/data/writingPrompts.tar.gz) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.bz2) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2)
We provide sample stories generated by the [convolutional seq2seq model](https://dl.fbaipublicfiles.com/fairseq/data/seq2seq_stories.txt) and [fusion model](https://dl.fbaipublicfiles.com/fairseq/data/fusion_stories.txt) from [Fan et al., 2018](https://arxiv.org/abs/1805.04833). The corresponding prompts for the fusion model can be found [here](https://dl.fbaipublicfiles.com/fairseq/data/fusion_prompts.txt). Note that the file contains `<unk>` tokens, since we modeled a small full vocabulary (no BPE or pre-training); we did not use these `<unk>` prompts for human evaluation.
...
...@@ -30,6 +30,20 @@ def from_pretrained(
if data_name_or_path is not None and data_name_or_path in archive_map:
data_name_or_path = archive_map[data_name_or_path]
# allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
# for each model
if isinstance(model_name_or_path, dict):
for k, v in model_name_or_path.items():
if k == 'checkpoint_file':
checkpoint_file = v
elif (
k != 'path'
# only set kwargs that don't already have overrides
and k not in kwargs
):
kwargs[k] = v
model_name_or_path = model_name_or_path['path']
model_path = file_utils.load_archive_file(model_name_or_path)
# convenience hack for loading data and BPE codes from model archive
...
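With this change, a `hub_models()` entry can be either a plain URL or a dict whose `path` points to the archive and whose remaining keys become default keyword arguments for `from_pretrained`. A standalone sketch of that merge (hypothetical `resolve_hub_entry` helper, mirroring the logic above):

```python
def resolve_hub_entry(entry, checkpoint_file='model.pt', **kwargs):
    """Hypothetical sketch: split a hub_models() entry into (path, checkpoint_file, kwargs)."""
    if isinstance(entry, dict):
        for k, v in entry.items():
            if k == 'checkpoint_file':
                checkpoint_file = v
            elif k != 'path' and k not in kwargs:
                # archive defaults never clobber explicitly passed kwargs
                kwargs[k] = v
        entry = entry['path']
    return entry, checkpoint_file, kwargs

path, ckpt, overrides = resolve_hub_entry(
    {'path': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz',
     'tokenizer': 'moses', 'bpe': 'fastbpe'},
    bpe='subword_nmt',  # user override wins over the archive default
)
assert overrides == {'tokenizer': 'moses', 'bpe': 'subword_nmt'}
```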
...@@ -43,10 +43,18 @@ class FConvModel(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
def moses_subword(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'subword_nmt',
}
return {
'conv.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2'),
'conv.wmt14.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2'),
'conv.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2'),
}
def __init__(self, encoder, decoder):
...
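With moses + subword_nmt declared per entry, the convolutional WMT models can be used on raw text; a hedged usage sketch (model name taken from the registry above, `translate()` following the README example):

```python
import torch

en2fr = torch.hub.load('pytorch/fairseq', 'conv.wmt14.en-fr')
en2fr.eval()  # disable dropout

# tokenization and BPE are applied automatically before generation
print(en2fr.translate('Hello world!', beam=5))
```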
...@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import math
import os
import torch
import torch.nn as nn
...@@ -33,7 +34,18 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
return {
'conv.stories.pretrained': {
'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
'checkpoint_file': 'pretrained_checkpoint.pt',
'tokenizer': 'nltk',
},
'conv.stories': {
'path': 'https://dl.fbaipublicfiles.com/fairseq/models/stories_checkpoint.tar.gz',
'checkpoint_file': 'fusion_checkpoint.pt',
'tokenizer': 'nltk',
'pretrained': 'True',
'pretrained_checkpoint': './pretrained_checkpoint.pt',
},
# Test set containing dictionaries
'data.stories': 'https://dl.fbaipublicfiles.com/fairseq/data/stories_test.tar.bz2',
}
...@@ -97,6 +109,10 @@ class FConvModelSelfAtt(FairseqEncoderDecoderModel):
pretrained = eval(args.pretrained)
if pretrained:
print("| loading pretrained model")
if not os.path.exists(args.pretrained_checkpoint):
new_pretrained_checkpoint = os.path.join(args.data, args.pretrained_checkpoint)
if os.path.exists(new_pretrained_checkpoint):
args.pretrained_checkpoint = new_pretrained_checkpoint
trained_model = checkpoint_utils.load_model_ensemble(
filenames=[args.pretrained_checkpoint],
task=task,
...
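Because the stories archive now declares a `checkpoint_file`, an nltk tokenizer and the fusion-model flags, and the pretrained checkpoint is resolved relative to the downloaded data directory, loading it becomes a one-liner. A sketch, assuming the `sample()` helper exists on the generator hub interface (sampling parameters are illustrative):

```python
import torch

# loads fusion_checkpoint.pt plus its pretrained_checkpoint.pt dependency
stories = torch.hub.load('pytorch/fairseq', 'conv.stories')
stories.eval()

# the prompt is tokenized with nltk automatically; sample() is assumed here
print(stories.sample('A mysterious ship appeared on the horizon .', beam=1, sampling=True, sampling_topk=10))
```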
...@@ -53,18 +53,33 @@ class TransformerModel(FairseqEncoderDecoderModel):
@classmethod
def hub_models(cls):
# fmt: off
def moses_subword(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'subword_nmt',
}
def moses_fastbpe(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'fastbpe',
}
return {
'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
}
# fmt: on
...
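The same pattern applies to the transformer entries, so the WMT19 models translate raw strings directly; a short sketch using one of the single-model entries registered above:

```python
import torch

de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model')
de2en.eval()

# moses tokenization and fastbpe are applied (and reversed) internally
print(de2en.translate('Maschinelles Lernen ist großartig!', beam=5))
```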
...@@ -26,12 +26,20 @@ class TransformerLanguageModel(FairseqLanguageModel):
@classmethod
def hub_models(cls):
def moses_fastbpe(path):
return {
'path': path,
'tokenizer': 'moses',
'bpe': 'fastbpe',
}
return {
'transformer_lm.gbw.adaptive_huge': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2',
'transformer_lm.wiki103.adaptive': 'https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.tar.bz2',
'transformer_lm.wmt19.en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2'),
'transformer_lm.wmt19.de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2'),
'transformer_lm.wmt19.ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2'),
}
def __init__(self, decoder):
...
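Likewise, the WMT19 language models no longer need manual BPE preprocessing of the prompt; a hedged sketch, assuming the `sample()` helper on the hub interface (sampling parameters are illustrative):

```python
import torch

en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en')
en_lm.eval()

# the prompt is moses-tokenized and fastbpe-encoded automatically; sample() is assumed here
print(en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8))
```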
...@@ -80,7 +80,7 @@ class LinearizedConvolution(ConvTBC):
kw = self.kernel_size[0]
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
assert weight.size() == (self.out_channels, kw, self.in_channels)
self._linearized_weight = torch.nn.Parameter(weight.view(self.out_channels, -1))
return self._linearized_weight
def _clear_linearized_weight(self, *args):
...
...@@ -104,6 +104,16 @@ if 'clean' in sys.argv[1:]:
subprocess.run(['rm -f fairseq/*.so fairseq/**/*.so'], shell=True)
if 'test' in sys.argv[1:]:
try:
import fairseq.data.token_block_utils_fast
except (ImportError, ModuleNotFoundError):
raise Exception(
'Please install Cython components with `python setup.py build_ext --inplace` '
'before running unit tests.'
)
setup(
name='fairseq',
version='0.8.0',
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from multiprocessing import Manager
import random
import unittest
import torch
import torch.nn as nn
from fairseq import distributed_utils, optim
...@@ -143,3 +149,7 @@ class TestBMUF(unittest.TestCase):
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-4)
if __name__ == '__main__':
unittest.main()
...@@ -35,6 +35,7 @@ class TestMemoryEfficientFP16(unittest.TestCase):
fp16_scale_window=1,
fp16_scale_tolerance=1,
threshold_loss_scale=1,
min_loss_scale=1e-4,
),
params,
optimizer,
...