#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""


import argparse
import inspect
import logging
from typing import Tuple

import numpy as np
import torch

from transformers import (
    AutoTokenizer,
    BloomForCausalLM,
    BloomTokenizerFast,
    CTRLLMHeadModel,
    CTRLTokenizer,
    GenerationMixin,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPTJForCausalLM,
    LlamaForCausalLM,
    LlamaTokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    OPTForCausalLM,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = 10000  # Hard-coded max length to avoid an infinite generation loop

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
    "gptj": (GPTJForCausalLM, AutoTokenizer),
    "bloom": (BloomForCausalLM, BloomTokenizerFast),
    "llama": (LlamaForCausalLM, LlamaTokenizer),
    "opt": (OPTForCausalLM, GPT2Tokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
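    """Seed NumPy and PyTorch (including all visible GPUs) so sampling is reproducible."""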
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
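    """Sanity-check CTRL settings: warn on high temperature and on prompts that do not start with a control code."""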
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
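    """Pick the language embedding for multilingual XLM checkpoints, asking interactively when --xlm_language is unset or unavailable."""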
    # kwargs = {"language": None, "mask_token_id": None}

    # Set the language
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]
        # kwargs["language"] = tokenizer.lang2id[language]

    # TODO: fix mask_token_id setup once configurations are synchronized between models and tokenizers
    # XLM masked-language modeling (MLM) models need a mask token
    # is_xlm_mlm = "mlm" in args.model_name_or_path
    # if is_xlm_mlm:
    #     kwargs["mask_token_id"] = tokenizer.mask_token_id

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
    prompt_text = prefix + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
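    """Clamp the requested length to the model's maximum sequence length, falling back to MAX_LENGTH when neither bound is usable."""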
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # No generation bigger than model size
    elif length < 0:
        length = MAX_LENGTH  # avoid infinite loop
    return length


def sparse_model_config(model_config):
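    """Read (num_layers, num_heads, head_dim) from a model config, trying the attribute names used by the different architectures."""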
    embedding_size = None
    if hasattr(model_config, "hidden_size"):
        embedding_size = model_config.hidden_size
    elif hasattr(model_config, "n_embed"):
        embedding_size = model_config.n_embed
    elif hasattr(model_config, "n_embd"):
        embedding_size = model_config.n_embd

    num_head = None
    if hasattr(model_config, "num_attention_heads"):
        num_head = model_config.num_attention_heads
    elif hasattr(model_config, "n_head"):
        num_head = model_config.n_head

    if embedding_size is None or num_head is None or num_head == 0:
        raise ValueError("Check the model config")

    num_embedding_size_per_head = int(embedding_size / num_head)
    if hasattr(model_config, "n_layer"):
        num_layer = model_config.n_layer
    elif hasattr(model_config, "num_hidden_layers"):
        num_layer = model_config.num_hidden_layers
    else:
        raise ValueError("Number of hidden layers couldn't be determined from the model config")

    return num_layer, num_head, num_embedding_size_per_head


def generate_past_key_values(model, batch_size, seq_len):
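    """Allocate an empty past_key_values cache in the layout the model expects: BLOOM fuses
    batch and heads into the leading dimension (with transposed keys), while other models
    use [batch, heads, seq_len, head_dim] for both keys and values."""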
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
        past_key_values = tuple(
            (
                torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
                .to(model.dtype)
                .to(model.device),
                torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    else:
        past_key_values = tuple(
            (
                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
                torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
                .to(model.dtype)
                .to(model.device),
            )
            for _ in range(num_block_layers)
        )
    return past_key_values


def prepare_jit_inputs(inputs, model, tokenizer):
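    """Build example inputs for torch.jit.trace: token ids, an empty KV cache of length 1,
    and an attention mask extended by one (masked) position to cover that cache."""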
    batch_size = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
    dummy_input = dummy_input.to(model.device)
    if model.config.use_cache:
        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
    dummy_input["attention_mask"] = torch.cat(
        [
            torch.zeros(dummy_input["attention_mask"].shape[0], 1)
            .to(dummy_input["attention_mask"].dtype)
            .to(model.device),
            dummy_input["attention_mask"],
        ],
        -1,
    )
    return dummy_input


class _ModelFallbackWrapper(GenerationMixin):
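    """Run generation through the traced (optimized) model while delegating anything it
    lacks, e.g. `prepare_inputs_for_generation` and config attributes, to the eager model."""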
    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
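        # The traced graph was exported with a KV cache input, so pass an empty cache on
        # the first step, then drop kwargs (position_ids, None values, bools) that the
        # traced signature cannot accept.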
        if kwargs["past_key_values"] is None and self._default.config.use_cache:
            kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
        kwargs.pop("position_ids", None)
        for k in list(kwargs.keys()):
            if kwargs[k] is None or isinstance(kwargs[k], bool):
                kwargs.pop(k)
        outputs = self._optimized(**kwargs)
        lm_logits = outputs[0]
        past_key_values = outputs[1]
        fixed_output = CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )
        return fixed_output

    def __getattr__(self, item):
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    logger.warning(f"device: {args.device}, n_gpu: {args.n_gpu}, 16-bits training: {args.fp16}")

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    if args.fp16:
        model.half()
    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)

        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}
        else:
            tokenizer_kwargs = {}

        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
        )
    else:
        prefix = args.prefix if args.prefix else args.padding_text
        encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    if args.jit:
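        # Trace the model once on dummy inputs, freeze the graph, and warm it up with two
        # calls; the eager model is kept as a fallback inside _ModelFallbackWrapper below.
        # return_dict is switched off so the trace sees plain tuple outputs, and the
        # TensorExpr fuser is disabled before tracing.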
        jit_input_texts = ["enable jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        if hasattr(model, "forward"):
            sig = inspect.signature(model.forward)
        else:
            sig = inspect.signature(model.__call__)
        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()