import enum
import warnings
from typing import Dict, List

from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import Pipeline, build_pipeline_init_args


if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES


class ReturnType(enum.Enum):
    TENSORS = 0
    NEW_TEXT = 1
    FULL_TEXT = 2


class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""
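    # Illustrative sketch of the expected format (hypothetical messages, not taken from the original file):
    # Chat([{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello! How can I help?"}])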

    def __init__(self, messages: List[Dict[str, str]]):
        for message in messages:
            if not ("role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        self.messages = messages


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
    where each dict contains "role" and "content" keys.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="openai-community/gpt2")
    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return multiple suggestions, and only the newly generated text, making it easier to use for prompting.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```
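
    Chat mode, as a minimal sketch (the model name and messages below are purely illustrative; any chat model with a
    chat template should work the same way):

    ```python
    >>> from transformers import pipeline

    >>> chatbot = pipeline(model="HuggingFaceH4/zephyr-7b-beta")
    >>> messages = [{"role": "user", "content": "Tell me a joke about cats."}]
    >>> outputs = chatbot(messages, max_new_tokens=64, return_full_text=False)
    ```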

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
    generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
    text generation parameters in [Text generation strategies](../generation_strategies) and [Text
    generation](text_generation).

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of available models
    on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        if "prefix" not in self._preprocess_params:
            # This is very specific. The logic is quite complex and needs to be done
            # as a "default".
            # It also defines both some preprocess_kwargs and generate_kwargs
            # which is why we cannot put them in their respective methods.
            prefix = None
            if self.model.config.prefix is not None:
                prefix = self.model.config.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recalculate some generate_kwargs linked to prefix.
                preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        preprocess_params = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
        }
        if max_length is not None:
            generate_kwargs["max_length"] = max_length

        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
            )
            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for the `handle_long_generation` parameter, expected"
                    " one of [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation

        preprocess_params.update(generate_kwargs)
        forward_params = generate_kwargs

        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
                    " the stop sequence will be used as the stop sequence string in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            text_inputs (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete. Chats can also be passed as a list of
                dicts with "role" and "content" keys (or as a list of such chats).
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Only meaningful if
                *return_text* is set to True.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipeline does not handle long generation (generation that exceeds, in one form or
                another, the model maximum length). There is no perfect way to address this (more info:
                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around that problem depending on your use case.

                - `None`: default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceeds the model capacity)

            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
            # We have one or more prompts in list-of-dicts format, so this is chat mode
            if isinstance(text_inputs[0], dict):
                return super().__call__(Chat(text_inputs), **kwargs)
            else:
                chats = [Chat(chat) for chat in text_inputs]  # 🐈 🐈 🐈
                return super().__call__(chats, **kwargs)
        else:
            return super().__call__(text_inputs, **kwargs)

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        if isinstance(prompt_text, Chat):
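            # Render the conversation with the tokenizer's chat template; add_generation_prompt=True appends the
            # tokens that cue the model to produce the next assistant turn.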
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors=self.framework,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text,
                truncation=truncation,
                padding=padding,
                max_length=max_length,
                add_special_tokens=add_special_tokens,
                return_tensors=self.framework,
            )
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
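            # Illustrative summary of the "hole" strategy implemented below: if the prompt plus the expected number of
            # new tokens would not fit in the tokenizer's model_max_length, keep only the rightmost part of the prompt
            # so that generation still fits within the model context.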
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
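        # `generate` returns all sequences stacked along the first dimension (e.g. when num_return_sequences > 1),
        # so regroup them below as (input_batch_size, sequences_per_input, sequence_length).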
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        records = []
        for sequence in generated_sequence:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                # Decode text
                text = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )

                # Compute the length of the decoded prompt (including the XLNet / Transfo-XL padding prefix, if any)
                # so that it can be stripped from the decoded output below.
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

                all_text = text[prompt_length:]
                if return_type == ReturnType.FULL_TEXT:
                    if isinstance(prompt_text, str):
                        all_text = prompt_text + all_text
                    elif isinstance(prompt_text, Chat):
                        all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]

                record = {"generated_text": all_text}
            records.append(record)

        return records