#!/usr/bin/env python3
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet/XLM)
"""

import argparse
import logging

import numpy as np
Aymeric Augustin's avatar
Aymeric Augustin committed
25
import torch
26

Aymeric Augustin's avatar
Aymeric Augustin committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)
41
42


# Configure the root logger once at import time (INFO and above, timestamped) so
# everything main() logs is visible.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)
logger = logging.getLogger(__name__)

# Cap used by adjust_length_to_model when a negative --length is requested and the
# model config provides no positive max_position_embeddings.
MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop

# --model_type value -> (LM-head model class, matching tokenizer class).
# Insertion order is the order listed in the --model_type help text.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}

# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
# prepare_xlnet_input / prepare_transfoxl_input prepend it to the prompt when --padding_text
# is empty, and main() cuts the decoded (padded) prompt off the generated text again.
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
    """Seed NumPy and PyTorch from ``args.seed``.

    When ``args.n_gpu`` is positive, every CUDA device is seeded as well.
    """
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.n_gpu <= 0:
        return
    torch.cuda.manual_seed_all(seed)

#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    """Check a CTRL prompt for common pitfalls and return it unchanged.

    Args:
        args: parsed CLI namespace; only ``temperature`` is read.
        _: the model (unused; kept for the shared ``prepare_*_input`` signature).
        tokenizer: tokenizer providing ``encode`` and a ``control_codes`` mapping.
        prompt_text: the raw user prompt.

    Returns:
        str: ``prompt_text``, unmodified.
    """
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    # An empty prompt has no first token; it certainly does not start with a
    # control code, so warn instead of failing with IndexError on [0].
    if not encoded_prompt or encoded_prompt[0] not in tokenizer.control_codes.values():
        logger.warning("WARNING! You are not starting your generation from a control code so you won't get good results")
    return prompt_text


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    """Select the XLM language embedding and return ``prompt_text`` unchanged.

    This only acts when ``model.config`` has a ``lang2id`` mapping and a truthy
    ``use_lang_emb`` flag. In that case a language is chosen and its id is stored
    in ``model.config.lang_id``. The choice is ``args.xlm_language`` when it names
    a configured language. Otherwise the user is prompted interactively until a
    configured language is entered.

    Configs without language embeddings are left untouched. ``tokenizer`` is unused.
    """
    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
    # (XLM masked-language-modeling checkpoints would also need tokenizer.mask_token_id.)
    config = model.config
    has_lang_emb = getattr(config, "use_lang_emb", False)
    if not (hasattr(config, "lang2id") and has_lang_emb):
        return prompt_text

    choices = config.lang2id.keys()
    language = args.xlm_language if args.xlm_language in choices else None
    while language not in choices:
        language = input("Using XLM. Select language in " + str(list(choices)) + " >>> ")
    config.lang_id = config.lang2id[language]
    return prompt_text

def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    """Return ``prompt_text`` with padding text prepended for XLNet.

    ``args.padding_text`` is used when it is non-empty; otherwise the module-level
    ``PADDING_TEXT`` is used. The model and tokenizer arguments are unused and exist
    only to match the other ``prepare_*_input`` helpers.
    """
    padding = args.padding_text or PADDING_TEXT
    return padding + prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Return ``prompt_text`` with padding text prepended for Transformer-XL.

    An empty ``args.padding_text`` falls back to the module-level ``PADDING_TEXT``.
    The model and tokenizer arguments are unused (shared helper signature).
    """
    if not args.padding_text:
        return PADDING_TEXT + prompt_text
    return args.padding_text + prompt_text


# Dispatch table: --model_type -> function(args, model, tokenizer, prompt_text) that
# returns the text to encode. Model types without an entry are encoded as-is by main().
PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,  # logs CTRL-specific hints; prompt returned unchanged
    "xlm": prepare_xlm_input,  # may set model.config.lang_id; prompt returned unchanged
    "xlnet": prepare_xlnet_input,  # prepends padding text
    "transfo-xl": prepare_transfoxl_input,  # prepends padding text
}


def adjust_length_to_model(length, max_sequence_length):
    """Clamp the requested generation length to what the model supports.

    Args:
        length: requested number of tokens; a negative value means "as long as possible".
        max_sequence_length: the model's position limit; values <= 0 mean "no known limit".

    Returns:
        ``max_sequence_length`` when the model has a positive limit and ``length`` is
        negative or exceeds it. ``MAX_LENGTH`` when ``length`` is negative and the
        model has no limit. Otherwise ``length`` unchanged.
    """
    model_is_bounded = max_sequence_length > 0
    if model_is_bounded and (length < 0 or length > max_sequence_length):
        return max_sequence_length  # No generation bigger than model size
    if length < 0:
        return MAX_LENGTH  # avoid infinite loop
    return length


def main():
    """Command-line entry point: sample text continuations from a pretrained LM.

    Steps:
        1. Parse the CLI arguments.
        2. Load the tokenizer/model pair selected by ``--model_type``.
        3. Apply any model-specific prompt preparation from ``PREPROCESSING_FUNCTIONS``.
        4. Sample ``--num_return_sequences`` continuations with ``model.generate``.
        5. Print each continuation.

    Returns:
        list of str: one entry per sample. Each entry is the original prompt followed
        by the decoded continuation, cut at ``--stop_token`` when one is given.

    Raises:
        KeyError: if ``--model_type`` is not a key of ``MODEL_CLASSES``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name of a pre-trained model of the selected --model_type",
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    # With --no_cuda, report zero GPUs so set_seed() does not touch CUDA.
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    set_seed(args)

    # Initialize the model and tokenizer
    args.model_type = args.model_type.lower()
    try:
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(
                args.model_type
            )
        ) from None

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
    if prepare_input is not None:
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
        # add_space_before_punct_symbol is a Transformer-XL tokenizer option; only
        # forward it to that tokenizer instead of every preprocessed model's tokenizer.
        tokenizer_kwargs = {"add_space_before_punct_symbol": True} if args.model_type == "transfo-xl" else {}
        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
        )
    else:
        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    # An empty prompt encodes to a zero-length tensor; pass None so generate()
    # starts the sequence itself rather than receiving an empty input_ids.
    input_ids = encoded_prompt if encoded_prompt.size(-1) > 0 else None

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # Length of the decoded (possibly padded) prompt. It is the same for every sample,
    # so compute it once; it marks where the generated continuation begins.
    decoded_prompt_length = len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True))

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Drop the excess text that was used for pre-processing (the decoded prompt/padding).
        continuation = text[decoded_prompt_length:]

        # Remove all text after the stop token. Only the continuation is searched, so a
        # stop token that also occurs in the prompt or padding text does not cut the
        # output. str.find returns -1 when the token is absent; using that as a slice
        # end would silently drop the last character, so it is checked explicitly.
        if args.stop_token:
            stop_index = continuation.find(args.stop_token)
            if stop_index != -1:
                continuation = continuation[:stop_index]

        # Add the original (un-preprocessed) prompt back at the beginning of the sequence.
        total_sequence = prompt_text + continuation

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()