"notebooks/conditioning_sht.ipynb" did not exist on "c68a4ed5ed0fa67a9e5ad7853b00a1eb57d8848a"
translation.py 8.26 KB
Newer Older
Myle Ott's avatar
Myle Ott committed
1
2
3
4
5
6
7
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Alexei Baevski's avatar
Alexei Baevski committed
8
import itertools
Myle Ott's avatar
Myle Ott committed
9
10
import os

11
from fairseq import options, utils
Myle Ott's avatar
Myle Ott committed
12
from fairseq.data import (
Myle Ott's avatar
Myle Ott committed
13
14
15
16
17
18
19
    ConcatDataset,
    data_utils,
    Dictionary,
    IndexedCachedDataset,
    IndexedDataset,
    IndexedRawTextDataset,
    LanguagePairDataset,
Myle Ott's avatar
Myle Ott committed
20
21
22
23
24
25
26
)

from . import FairseqTask, register_task


@register_task('translation')
class TranslationTask(FairseqTask):
Myle Ott's avatar
Myle Ott committed
27
28
29
30
31
32
33
34
35
    """
    Translate from one (source) language to another (target) language.

    Args:
        src_dict (Dictionary): dictionary for the source language
        tgt_dict (Dictionary): dictionary for the target language

    .. note::

Myle Ott's avatar
Myle Ott committed
36
37
        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.
Myle Ott's avatar
Myle Ott committed
38
39
40
41
42
43
44
45

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """
Myle Ott's avatar
Myle Ott committed
46
47
48
49

    @staticmethod
    def add_args(parser):
        """Register the translation task's command-line options on ``parser``.

        Args:
            parser: an argparse-style parser exposing ``add_argument``;
                it is extended in place
        """
        # fmt: off
        add = parser.add_argument

        # where the data lives and which language pair to translate
        add('data', nargs='+',
            help='path(s) to data directorie(s)')
        add('-s', '--source-lang', metavar='SRC', default=None,
            help='source language')
        add('-t', '--target-lang', metavar='TARGET', default=None,
            help='target language')

        # how the on-disk datasets are read
        add('--lazy-load', action='store_true',
            help='load the dataset lazily')
        add('--raw-text', action='store_true',
            help='load raw text dataset')

        # padding sides are parsed as strings here; setup_task() turns them
        # into booleans with options.eval_bool
        add('--left-pad-source', metavar='BOOL', type=str, default='True',
            help='pad the source on the left')
        add('--left-pad-target', metavar='BOOL', type=str, default='False',
            help='pad the target on the left')

        # sequence length limits
        add('--max-source-positions', metavar='N', type=int, default=1024,
            help='max number of tokens in the source sequence')
        add('--max-target-positions', metavar='N', type=int, default=1024,
            help='max number of tokens in the target sequence')

        # sampling ratio applied to the first dataset when several are combined
        add('--upsample-primary', type=int, default=1,
            help='amount to upsample primary dataset')
        # fmt: on
Myle Ott's avatar
Myle Ott committed
71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
    @staticmethod
    def load_pretrained_model(path, src_dict_path, tgt_dict_path, arg_overrides=None):
        """Rebuild a model from a checkpoint and load its saved weights.

        The checkpoint is read with ``utils.load_checkpoint_to_cpu``; its
        stored ``'args'`` are passed through ``utils.override_model_args``
        together with ``arg_overrides`` and then used to build a fresh model
        via a :class:`TranslationTask`.

        Args:
            path: checkpoint location handed to ``utils.load_checkpoint_to_cpu``
            src_dict_path: file passed to ``Dictionary.load`` for the source side
            tgt_dict_path: file passed to ``Dictionary.load`` for the target side
            arg_overrides: forwarded unchanged to ``utils.override_model_args``
                (default: None)

        Returns:
            the model built by the task, with the checkpoint's ``'model'``
            state loaded using ``strict=True``
        """
        checkpoint = utils.load_checkpoint_to_cpu(path)
        saved_args, weights = checkpoint['args'], checkpoint['model']
        model_args = utils.override_model_args(saved_args, arg_overrides)

        source_vocab = Dictionary.load(src_dict_path)
        target_vocab = Dictionary.load(tgt_dict_path)
        # both dictionaries must agree on the pad / eos / unk symbols
        for special in ('pad', 'eos', 'unk'):
            assert getattr(source_vocab, special)() == getattr(target_vocab, special)()

        task = TranslationTask(model_args, source_vocab, target_vocab)
        net = task.build_model(model_args)
        # let the model adapt the saved state before the strict load
        net.upgrade_state_dict(weights)
        net.load_state_dict(weights, strict=True)
        return net

Myle Ott's avatar
Myle Ott committed
90
91
92
93
94
95
96
    def __init__(self, args, src_dict, tgt_dict):
        """Create the task from parsed arguments and the two dictionaries.

        Args:
            args: parsed arguments, forwarded to the parent constructor
            src_dict (Dictionary): dictionary for the source language
            tgt_dict (Dictionary): dictionary for the target language
        """
        super().__init__(args)
        self.src_dict, self.tgt_dict = src_dict, tgt_dict

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        The padding flags on ``args`` are converted in place with
        ``options.eval_bool``. If either language is unset, the pair is
        inferred from the first data directory. The dictionaries are then
        read from ``dict.<lang>.txt`` files in that same directory.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            an instance of ``cls`` built from ``args`` and both dictionaries

        Raises:
            Exception: if the language pair is neither given nor inferable
        """
        for flag in ('left_pad_source', 'left_pad_target'):
            setattr(args, flag, options.eval_bool(getattr(args, flag)))

        def pair_incomplete():
            return args.source_lang is None or args.target_lang is None

        primary_dir = args.data[0]

        # find language pair automatically
        if pair_incomplete():
            args.source_lang, args.target_lang = data_utils.infer_language_pair(primary_dir)
        if pair_incomplete():
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        dict_name = 'dict.{}.txt'
        src_dict = cls.load_dictionary(os.path.join(primary_dir, dict_name.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(os.path.join(primary_dir, dict_name.format(args.target_lang)))

        # both dictionaries must agree on the pad / eos / unk symbols
        for special in ('pad', 'eos', 'unk'):
            assert getattr(src_dict, special)() == getattr(tgt_dict, special)()

        for lang, dictionary in ((args.source_lang, src_dict), (args.target_lang, tgt_dict)):
            print('| [{}] dictionary: {} types'.format(lang, len(dictionary)))

        return cls(args, src_dict, tgt_dict)

Peng-Jen Chen's avatar
Peng-Jen Chen committed
122
    def load_dataset(self, split, combine=False, **kwargs):
        """Load a given dataset split.

        Every directory in ``self.args.data`` is searched for files named
        ``<split>.<src>-<tgt>.<lang>`` (either language order). The resulting
        :class:`LanguagePairDataset` is stored in ``self.datasets[split]``.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            combine (bool, optional): when True, keep loading the numbered
                shards ``<split>1``, ``<split>2``, ... of each data
                directory until one is missing; otherwise load only the
                unnumbered split of each directory (default: False)

        Raises:
            FileNotFoundError: if the split is missing from the first data
                directory, or if one side of a located language pair has no
                binarized dataset on disk
        """

        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text:
                return IndexedRawTextDataset.exists(filename)
            return IndexedDataset.exists(filename)

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            if IndexedDataset.exists(path):
                if self.args.lazy_load:
                    return IndexedDataset(path, fix_lua_indexing=True)
                return IndexedCachedDataset(path, fix_lua_indexing=True)
            # split_exists() only probes the source-side file, so the other
            # side can still be absent; fail here with the offending path
            # instead of handing None to the code below.
            raise FileNotFoundError('Dataset not found: {}'.format(path))

        src_datasets = []
        tgt_datasets = []

        for dk, data_path in enumerate(self.args.data):
            for k in itertools.count():
                # shard 0 has no numeric suffix: train, train1, train2, ...
                split_k = split + (str(k) if k > 0 else '')

                # infer langcode
                src, tgt = self.args.source_lang, self.args.target_lang
                if split_exists(split_k, src, tgt, src, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, src, tgt))
                elif split_exists(split_k, tgt, src, src, data_path):
                    prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
                elif k > 0 or dk > 0:
                    # ran out of shards, or an extra data directory simply
                    # does not provide this split
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

                src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
                tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))

                print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))

                if not combine:
                    break

        assert len(src_datasets) == len(tgt_datasets)

        if len(src_datasets) == 1:
            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
        else:
            # only the first loaded dataset is upsampled; the rest keep ratio 1
            sample_ratios = [1] * len(src_datasets)
            sample_ratios[0] = self.args.upsample_primary
            src_dataset = ConcatDataset(src_datasets, sample_ratios)
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

        self.datasets[split] = LanguagePairDataset(
            src_dataset, src_dataset.sizes, self.src_dict,
            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
            left_pad_source=self.args.left_pad_source,
            left_pad_target=self.args.left_pad_target,
            max_source_positions=self.args.max_source_positions,
            max_target_positions=self.args.max_target_positions,
        )

195
    def max_positions(self):
        """Return the (source, target) length limits configured for the task."""
        limits = self.args
        return (limits.max_source_positions, limits.max_target_positions)

Myle Ott's avatar
Myle Ott committed
199
200
    @property
    def source_dictionary(self):
        """The source-language :class:`~fairseq.data.Dictionary` of this task."""
        source_vocab = self.src_dict
        return source_vocab

    @property
    def target_dictionary(self):
        """The target-language :class:`~fairseq.data.Dictionary` of this task."""
        target_vocab = self.tgt_dict
        return target_vocab