# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import itertools
import numpy as np
import os

from fairseq import options, utils
from fairseq.data import (
    data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
    IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
)

from . import FairseqTask, register_task


@register_task('translation')
class TranslationTask(FairseqTask):
    """
    Translate from one (source) language to another (target) language.

    Args:
        src_dict (Dictionary): dictionary for the source language
        tgt_dict (Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`train.py <train>`,
        :mod:`generate.py <generate>` and :mod:`interactive.py <interactive>`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('data', nargs='+', help='path(s) to data directorie(s)')
        parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
        parser.add_argument('--raw-text', action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
        parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
                            help='pad the target on the left')
        parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the source sequence')
        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                            help='max number of tokens in the target sequence')
        parser.add_argument('--upsample-primary', default=1, type=int,
                            help='amount to upsample primary dataset')
        # fmt: on

    @staticmethod
    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
        """Build a model from a checkpoint file and load its weights.

        Args:
            path (str): path to the checkpoint file
            src_dict_path (str): path to the source-language dictionary file
            tgt_dict_path (str): path to the target-language dictionary file
            arg_overrides (dict, optional): args to override in the
                checkpoint's saved args

        Returns:
            the model with the checkpoint's weights loaded
        """
        state = utils.load_checkpoint_to_cpu(path)
        args = utils.override_model_args(state['args'], arg_overrides)
        src_dict = Dictionary.load(src_dict_path)
        tgt_dict = Dictionary.load(tgt_dict_path)
        # both dictionaries must agree on the special-symbol indices
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()

        task = TranslationTask(args, src_dict, tgt_dict)
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        return model

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        # source/target dictionaries used to binarize raw text and to
        # build the model's embeddings
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        # the pad flags arrive as the strings 'True'/'False'
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically (from filenames in the first data dir)
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)

    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """

        def split_exists(split, src, tgt, lang, data_path):
            # e.g. 'train.en-de.en', in raw or binarized form depending on args
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text:
                return IndexedRawTextDataset.exists(filename)
            return IndexedDataset.exists(filename)

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            if IndexedDataset.exists(path):
                return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        src_datasets = []
        tgt_datasets = []

        data_paths = self.args.data

        for data_idx, data_path in enumerate(data_paths):
            # each directory may contain sharded splits: train, train1, train2, ...
            for shard in itertools.count():
                split_k = split + (str(shard) if shard > 0 else '')

                # infer langcode (files may be named src-tgt or tgt-src)
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                elif shard > 0 or data_idx > 0:
                    # no more shards in this directory; only the very first
                    # lookup is allowed to fail hard
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

                src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            # upsample the first (primary) dataset relative to the rest
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.src_dict,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.args.max_source_positions, self.args.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict