# text_generation.py
1
import enum
2
import warnings
3

4
from ..utils import add_end_docstrings, is_tf_available, is_torch_available
5
from .base import Pipeline, build_pipeline_init_args
6
7


8
9
10
if is_torch_available():
    from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

11
12
13
if is_tf_available():
    import tensorflow as tf

14
15
    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

16

17
18
19
20
21
22
class ReturnType(enum.Enum):
    """What `TextGenerationPipeline.postprocess` places in each output record.

    - `TENSORS`: the raw token ids, under the `generated_token_ids` key.
    - `NEW_TEXT`: only the decoded text produced after the prompt, under `generated_text`.
    - `FULL_TEXT`: the original prompt text followed by the new text, under `generated_text`.
    """

    # Members are declared in this order with the values 0, 1 and 2 respectively.
    TENSORS, NEW_TEXT, FULL_TEXT = range(3)


23
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="openai-community/gpt2")

    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return suggestions, and only the newly created text making it easier for prompting suggestions.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
    generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
    text generation parameters in [Text generation strategies](../generation_strategies) and [Text
    generation](text_generation).

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. openai-community/gpt2). See the list of
    available models on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """

    def __init__(self, *args, **kwargs):
        """
        Build the pipeline, check that the model is a causal LM for the active framework, and set up a default prompt
        prefix when the caller did not pass one.

        The default prefix is `model.config.prefix` if it is set. Otherwise, for XLNet / Transformer-XL model classes,
        it is `XL_PREFIX`. Otherwise there is no prefix.
        """
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        if "prefix" not in self._preprocess_params:
            # This is very specific. The logic is quite complex and needs to be done
            # as a "default".
            # It also defines both some preprocess_kwargs and generate_kwargs
            # which is why we cannot put them in their respective methods.
            prefix = None
            if self.model.config.prefix is not None:
                prefix = self.model.config.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recalculate some generate_kwargs linked to prefix (e.g. `prefix_length`), keeping any values
                # already configured at init time.
                preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        """
        Split user keyword arguments into `(preprocess_params, forward_params, postprocess_params)`.

        Any keyword not named explicitly is treated as a `generate` argument. It is forwarded to both `preprocess`
        (which reads `max_length` / `max_new_tokens` for the "hole" strategy) and `_forward`.

        Raises:
            ValueError: if `handle_long_generation` is not `None` or `"hole"`; if mutually exclusive return options
                are combined; or if `stop_sequence` encodes to no tokens.
        """
        preprocess_params = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
        }
        if max_length is not None:
            generate_kwargs["max_length"] = max_length

        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
            )
            # `_forward` uses this to extend the length limits so the prefix does not eat into the generation budget.
            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
                    " [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation

        preprocess_params.update(generate_kwargs)
        # NOTE: `forward_params` aliases `generate_kwargs`, so later writes to `generate_kwargs` (e.g. `eos_token_id`
        # below) reach `_forward` but not `preprocess`.
        forward_params = generate_kwargs

        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            if not stop_sequence_ids:
                # Previously an opaque IndexError on `stop_sequence_ids[0]`.
                raise ValueError(f"`stop_sequence` {stop_sequence!r} does not encode to any token.")
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
                    " the stop sequence will be used as the stop sequence string in the interim."
                )
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize.

        Delegates to the base-class implementation. For Transformer-XL it first sets
        `add_space_before_punct_symbol=True`.
        """
        # Parse arguments
        if self.model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            text_inputs (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Only meaningful if
                *return_text* is set to True.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipeline does not handle long generation (ones that exceed in one form or the other
                the model maximum length). There is no perfect way to address this (more info:
                https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around that problem depending on your use case.

                - `None` : default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceeds the model capacity)

            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        return super().__call__(text_inputs, **kwargs)

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=False,
        truncation=None,
        padding=False,
        max_length=None,
        **generate_kwargs,
    ):
        """
        Tokenize `prefix + prompt_text` for the active framework and remember the raw `prompt_text`.

        With `handle_long_generation="hole"`, the tokens are left-truncated so that the kept input plus the requested
        new tokens fit within `tokenizer.model_max_length`.

        Raises:
            ValueError: (hole strategy only) if the number of new tokens cannot be inferred, or if it alone does not
                fit in the model's max length.
        """
        inputs = self.tokenizer(
            prefix + prompt_text,
            return_tensors=self.framework,
            truncation=truncation,
            padding=padding,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
        )
        # Kept un-tokenized so `postprocess` can rebuild the full text without the prefix.
        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                # Keep only the right-most `keep_length` tokens (the end of the prompt, closest to the generation).
                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """
        Run `model.generate` on the tokenized prompt.

        The output is reshaped to `(in_b, out_b // in_b, ...)` so that all sequences generated for one prompt are
        grouped together. Returns the generated ids, the (possibly `None`) input ids, and the raw prompt text.
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying
        # generate_kwargs, as some of the parameterization may come from the initialization of the pipeline.
        # (`generate_kwargs` is this call's own `**kwargs` dict, so edits below do not leak into later calls.)
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            has_max_new_tokens = "max_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].max_new_tokens is not None
            )
            if not has_max_new_tokens:
                generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length
                generate_kwargs["max_length"] += prefix_length
            has_min_new_tokens = "min_new_tokens" in generate_kwargs or (
                "generation_config" in generate_kwargs
                and generate_kwargs["generation_config"].min_new_tokens is not None
            )
            if not has_min_new_tokens and "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """
        Turn the generated ids for one prompt into a list of output records.

        Args:
            model_outputs: the dict returned by `_forward`.
            return_type: which `ReturnType` to produce.
            clean_up_tokenization_spaces: forwarded to `tokenizer.decode`.

        Returns:
            One dict per generated sequence: `{"generated_token_ids": [...]}` for `ReturnType.TENSORS`, otherwise
            `{"generated_text": str}`.

        Raises:
            ValueError: if `return_type` is not a supported `ReturnType` member.
        """
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()

        if return_type == ReturnType.TENSORS:
            return [{"generated_token_ids": sequence} for sequence in generated_sequence]

        if return_type not in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
            # Previously this surfaced as an UnboundLocalError on the first sequence.
            raise ValueError(f"Unsupported `return_type`: {return_type!r}")

        # Remove the decoded model input (prompt plus any prefix, e.g. XL_PREFIX for XLNet / Transformer-XL) from the
        # front of each decoded sequence. This assumes `generate` returns the input ids followed by the new tokens,
        # which the slicing below relies on. The length is the same for every sequence, so it is decoded only once.
        if input_ids is None:
            prompt_length = 0
        else:
            prompt_length = len(
                self.tokenizer.decode(
                    input_ids[0],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )
            )

        records = []
        for sequence in generated_sequence:
            # Decode text
            text = self.tokenizer.decode(
                sequence,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            )

            all_text = text[prompt_length:]
            if return_type == ReturnType.FULL_TEXT:
                # Re-attach the user's original prompt (without the prefix).
                all_text = prompt_text + all_text

            records.append({"generated_text": all_text})

        return records