# text_generation.py
1
import enum
2
import warnings
3

4
from .. import MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING
5
from ..utils import add_end_docstrings, is_tf_available
6
7
8
from .base import PIPELINE_INIT_ARGS, Pipeline


9
10
11
12
if is_tf_available():
    import tensorflow as tf


13
14
15
16
17
18
class ReturnType(enum.Enum):
    """Output format selector for `TextGenerationPipeline.postprocess`."""

    # Raw generated token ids, no decoding performed.
    TENSORS = 0
    # Only the newly generated continuation (prompt stripped).
    NEW_TEXT = 1
    # The prompt followed by the generated continuation.
    FULL_TEXT = 2


19
20
21
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
    specified text prompt.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="gpt2")
    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return suggestions, and only the newly created text making it easier for prompting suggestions.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective, which includes the uni-directional models in the library (e.g. gpt2). See the list of available models
    on [huggingface.co/models](https://huggingface.co/models?filter=text-generation).
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e

    XL_PREFIX = """
    In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
    voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
    Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
    and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
    accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
    the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
    begging for his blessing. <eod> </s> <eos>
    """

62
    def __init__(self, *args, **kwargs):
        """Initialize the pipeline and install a default generation prefix when one applies."""
        super().__init__(*args, **kwargs)
        expected_models = TF_MODEL_FOR_CAUSAL_LM_MAPPING if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING
        self.check_model_type(expected_models)

        if "prefix" not in self._preprocess_params:
            # This default is very specific: resolving a prefix produces both preprocess kwargs
            # and generate kwargs at once, which is why it cannot live inside either method alone.
            default_prefix = self.model.config.prefix
            uses_xl_style_model = self.model.__class__.__name__ in (
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            )
            if default_prefix is None and uses_xl_style_model:
                # XLNet / Transformer-XL handle short prompts better when an article precedes them.
                default_prefix = self.XL_PREFIX
            if default_prefix is not None:
                # Re-run sanitization so the generate kwargs tied to the prefix are recomputed.
                preprocess_params, forward_params, _ = self._sanitize_parameters(
                    prefix=default_prefix, **self._forward_params
                )
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        **generate_kwargs,
    ):
        """
        Split the caller's keyword arguments into the three per-stage parameter dicts.

        Returns:
            A `(preprocess_params, forward_params, postprocess_params)` tuple. Note that
            `forward_params` *is* the `generate_kwargs` dict (same object), and the generate
            kwargs are also copied into `preprocess_params` (preprocess reads
            `max_new_tokens` / `max_length` for the "hole" strategy).

        Raises:
            ValueError: if `handle_long_generation` is not `None` or `"hole"`, or if mutually
                exclusive return flags are combined.
        """
        preprocess_params = {}
        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            # Tokenize the prefix once so the generation-length kwargs can account for it.
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=False, return_tensors=self.framework
            )
            prefix_length = prefix_inputs["input_ids"].shape[-1]

            if "max_new_tokens" in generate_kwargs:
                # max_new_tokens already counts only generated tokens; no adjustment needed.
                pass
            elif "max_length" in generate_kwargs:
                # max_length includes the prompt, so grow it by the prepended prefix length.
                generate_kwargs["max_length"] += prefix_length
            else:
                generate_kwargs["max_length"] = self.model.config.max_length + prefix_length

            if "min_length" in generate_kwargs:
                generate_kwargs["min_length"] += prefix_length

        if handle_long_generation is not None:
            if handle_long_generation not in {"hole"}:
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
                    " [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation

        # preprocess also receives a *snapshot* of the generate kwargs (copied here, so the
        # eos_token_id added below does NOT land in preprocess_params).
        preprocess_params.update(generate_kwargs)
        # Deliberate aliasing: later mutations of generate_kwargs are visible in forward_params.
        forward_params = generate_kwargs

        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces

        if stop_sequence is not None:
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            if len(stop_sequence_ids) > 1:
                warnings.warn(
                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
                    " the stop sequence will be used as the stop sequence string in the interim."
                )
            # Reaches forward_params through the aliasing noted above.
            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]

        return preprocess_params, forward_params, postprocess_params
156
157

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Transformer-XL's tokenizer expects an explicit space before punctuation symbols.
        if self.model.__class__.__name__ == "TransfoXLLMHeadModel":
            kwargs["add_space_before_punct_symbol"] = True

        return super()._parse_and_tokenize(*args, **kwargs)
167

168
    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            args (`str` or `List[str]`):
                One or several prompts (or one list of prompts) to complete.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*, defaults to `True`):
                Whether or not to return the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Only meaningful if
                *return_text* is set to True.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
            prefix (`str`, *optional*):
                Prefix added to prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipelines does not handle long generation (ones that exceed in one form or the other
                the model maximum length). There is no perfect way to adress this (more info
                :https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around that problem depending on your use case.

                - `None` : default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceed the model capacity)

            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./model#generative-models)).

        Return:
            A list or a list of list of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        # All real work happens in preprocess/_forward/postprocess via the base Pipeline.
        return super().__call__(text_inputs, **kwargs)
210

211
    def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
        """
        Tokenize `prefix + prompt_text` and, with the `"hole"` strategy, left-truncate the prompt so
        the requested number of new tokens still fits in the tokenizer's maximum length.
        """
        model_inputs = self.tokenizer(
            prefix + prompt_text, padding=False, add_special_tokens=False, return_tensors=self.framework
        )
        # Keep the raw prompt around; postprocess needs it to rebuild the full text.
        model_inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = model_inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                # Derive the budget from max_length (explicit or the model's default).
                new_tokens = generate_kwargs.get("max_length", self.model.config.max_length) - cur_len
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                # Drop tokens from the left so the newest part of the prompt survives.
                model_inputs["input_ids"] = model_inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in model_inputs:
                    model_inputs["attention_mask"] = model_inputs["attention_mask"][:, -keep_length:]

        return model_inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """Run `model.generate` and regroup its output rows per input example."""
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        prompt_text = model_inputs.pop("prompt_text")

        # An empty prompt means unconditional generation: pass no ids at all.
        if input_ids.shape[1] == 0:
            input_ids, attention_mask, in_b = None, None, 1
        else:
            in_b = input_ids.shape[0]

        # BS x SL
        generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)

        # generate may emit num_return_sequences rows per input; fold them back per example.
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))

        return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}

    def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
        """Turn generated token ids into the requested output records (tensors or decoded text)."""
        sequences = model_outputs["generated_sequence"][0].numpy().tolist()
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]

        records = []
        for sequence in sequences:
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
                decoded = self.tokenizer.decode(
                    sequence,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                )

                # Length of the decoded prompt; stripping it also removes the PADDING prompt
                # prepended for XLNet / Transfo-XL models.
                if input_ids is None:
                    prompt_length = 0
                else:
                    prompt_length = len(
                        self.tokenizer.decode(
                            input_ids[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                        )
                    )

                new_text = decoded[prompt_length:]
                if return_type == ReturnType.FULL_TEXT:
                    record = {"generated_text": prompt_text + new_text}
                else:
                    record = {"generated_text": new_text}
            records.append(record)

        return records