<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Generation

Each framework has a generate method for auto-regressive text generation, implemented in its respective `GenerationMixin` class:

- PyTorch [`~generation.GenerationMixin.generate`] is implemented in [`~generation.GenerationMixin`].
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
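
For instance, a minimal end-to-end call of the PyTorch method could look as follows. This is a quick sketch, using the `gpt2` checkpoint purely for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Tokenize a prompt and auto-regressively generate a continuation
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
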
Regardless of your framework of choice, you can parameterize the generate method with a [`~generation.GenerationConfig`]
class instance. Please refer to this class for the complete list of generation parameters, which control the behavior
of the generation method.

All models have a default generation configuration that will be used if you don't provide one. If you have a loaded
model instance `model`, you can inspect the default generation configuration with `model.generation_config`. If you'd
like to set a new default generation configuration, you can create a new [`~generation.GenerationConfig`] instance and
store it with `save_pretrained`, making sure to leave its `config_file_name` argument empty.

```python
from transformers import AutoModelForCausalLM, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("my_account/my_model")

# Inspect the default generation configuration
print(model.generation_config)

# Set a new default generation configuration
generation_config = GenerationConfig(
    max_new_tokens=50, do_sample=True, top_k=50, eos_token_id=model.config.eos_token_id
)
generation_config.save_pretrained("my_account/my_model", push_to_hub=True)
```

<Tip>

If you inspect a serialized [`~generation.GenerationConfig`] file or print a class instance, you will notice that
default values are omitted. Some attributes, like `max_length`, have a conservative default value to avoid running
into resource limitations. Make sure you double-check the defaults in the documentation.

</Tip>
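
As a quick sketch of this behavior: attributes left unset are omitted when a [`~generation.GenerationConfig`] instance is printed, but they still resolve to their default values when read directly:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(do_sample=True)

# Printing only shows the attributes that differ from their default values
print(generation_config)

# Unset attributes still resolve to their defaults when read; max_length
# has a conservative default (20 at the time of writing)
print(generation_config.max_length)
```
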

You can also store several generation parametrizations in a single directory, making use of the `config_file_name`
argument in `save_pretrained`. You can later instantiate them with `from_pretrained`. This is useful if you want to
store several generation configurations for a single model (e.g. one for creative text generation with sampling, and
another for summarization with beam search).

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

translation_generation_config = GenerationConfig(
    num_beams=4,
    early_stopping=True,
    decoder_start_token_id=0,
    eos_token_id=model.config.eos_token_id,
    pad_token_id=model.config.pad_token_id,
)
# If you were working on a model for which you had the right Hub permissions, you could store a named generation
# config as follows
translation_generation_config.save_pretrained("t5-small", "translation_generation_config.json", push_to_hub=True)

# You could then use the named generation config file to parameterize generation
generation_config = GenerationConfig.from_pretrained("t5-small", "translation_generation_config.json")
inputs = tokenizer("translate English to French: Configuration files are easy to use!", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
# ['Les fichiers de configuration sont faciles à utiliser !']
```

Finally, you can specify ad hoc modifications to the generation configuration in use by passing the attributes you
wish to override directly to the generate method (e.g. `model.generate(inputs, max_new_tokens=512)`). Each
framework's `generate` method docstring (available below) has a few illustrative examples of the different strategies
to parameterize it.
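
As a minimal sketch of such an ad hoc override, again using `gpt2` as a stand-in checkpoint:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Uses the model's default generation configuration as-is
outputs = model.generate(**inputs)

# Overrides do_sample and max_new_tokens for this call only, without
# modifying model.generation_config
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```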

## GenerationConfig

[[autodoc]] generation.GenerationConfig
	- from_pretrained
	- from_model_config
	- save_pretrained

## GenerationMixin

[[autodoc]] generation.GenerationMixin
	- generate
	- greedy_search
	- sample
	- beam_search
	- beam_sample
	- contrastive_search
	- group_beam_search
	- constrained_beam_search

## TFGenerationMixin

[[autodoc]] generation.TFGenerationMixin
	- generate

## FlaxGenerationMixin

[[autodoc]] generation.FlaxGenerationMixin
	- generate