#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import logging
from typing import Tuple

import numpy as np
import torch

from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need masked token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def sparse_model_config(model_config):
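    # Different model families expose the embedding width and head count under different
    # config attribute names (hidden_size / n_embed / n_embd, num_attention_heads / n_head),
    # so probe each in turn.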
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Could not determine the embedding size and number of attention heads from the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    num_layer = model_config.n_layer

    return num_layer, num_head, num_embedding_size_per_head


def prepare_jit_inputs(inputs, model, tokenizer):
    num_batch = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True)
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
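    # Build zero-filled dummy past_key_values for the trace: BLOOM caches keys/values with
    # the head dimension folded into the batch dimension, while the other decoder-only
    # models use a (batch, num_heads, seq_len, head_dim) layout.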
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
                torch.zeros(int(num_attention_heads * num_batch), num_embedding_size_per_head, 1)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(int(num_attention_heads * num_batch), 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
                torch.zeros(num_batch, num_attention_heads, 1, num_embedding_size_per_head)
                .to(model.config.torch_dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )

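    # Prepend a masked-out position so the attention mask covers the single dummy cache
    # token created above in addition to the real input tokens.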
    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1).to(dummy_input["attention_mask"].dtype),
            dummy_input["attention_mask"],
        ],
        -1,
    )

    if model.config.use_cache:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            past_key_values,
            dummy_input["attention_mask"].to(model.device),
        )
    else:
        jit_inputs = (
            dummy_input["input_ids"].to(model.device),
            dummy_input["attention_mask"].to(model.device),
        )

    return jit_inputs


class _ModelFallbackWrapper(GenerationMixin):
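    """Route generation calls to the jit-traced model when a past_key_values cache is
    available, and fall back to the original eager model otherwise (e.g. on the first
    generation step, before any cache exists)."""
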
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        if kwargs["past_key_values"] is None:
            return self._default(*args, **kwargs)
        trace_graph_inputs = []
        kwargs.pop("position_ids", None)
        for k, v in kwargs.items():
            if v is not None and not isinstance(v, bool):
                trace_graph_inputs.append(v)
        trace_graph_inputs = tuple(trace_graph_inputs)
        outputs = self._optimized(*trace_graph_inputs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output

    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference"
    )
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bit inference: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            f"the model type {args.model_type} is not supported. You are welcome to add it and open a PR :)"
        )

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)

        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}

        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
        )
    else:
        prefix = args.prefix if args.prefix else args.padding_text
        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    if args.jit:
        jit_input_texts = ["jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
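        # Warm-up passes: running the frozen module a couple of times lets TorchScript's
        # profiling executor specialize and optimize the graph before real generation.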
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()