"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "75e1eed8d190afa5be30fba05cd872d79b492a24"
Unverified Commit 5f3efdf7 authored by Joao Gante, committed by GitHub

Generate: `group_beam_search` requires `diversity_penalty>0.0` (#24456)

* add exception

* update docs
parent 43479ef9
@@ -301,8 +301,9 @@ the `num_beams` greater than 1, and set `do_sample=True` to use this decoding st
The diverse beam search decoding strategy is an extension of the beam search strategy that allows for generating a more diverse
set of beam sequences to choose from. To learn how it works, refer to [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf).
This approach has three main parameters: `num_beams`, `num_beam_groups`, and `diversity_penalty`.
The diversity penalty ensures the outputs are distinct across groups, and beam search is used within each group.
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
@@ -328,9 +329,9 @@ The groups are selected to ensure they are distinct enough compared to the other
>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The aim of this project is to create a new type of living system, one that is more sustainable and efficient than the current one.'
```
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
...
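The `diversity_penalty` the docs now require works by down-weighting, at each decoding step, tokens that earlier beam groups have already selected at that step (Hamming diversity). A minimal standalone sketch of that idea, not the library's actual `HammingDiversityLogitsProcessor` implementation:

```python
from collections import Counter


def apply_diversity_penalty(scores, tokens_chosen_by_previous_groups, diversity_penalty):
    """Subtract `diversity_penalty` times the number of earlier groups that
    already picked each token at this step (simplified sketch)."""
    counts = Counter(tokens_chosen_by_previous_groups)
    return [s - diversity_penalty * counts[tok] for tok, s in enumerate(scores)]


# Token 0 was chosen twice and token 1 once by earlier groups this step,
# so their scores drop and the current group is pushed toward token 2:
print(apply_diversity_penalty([5.0, 4.0, 3.0], [0, 0, 1], 1.5))  # [2.0, 2.5, 3.0]
```

With `diversity_penalty=0.0` the scores are returned unchanged, which is why every group would then decode the same beams.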
@@ -1669,6 +1669,11 @@ class GenerationMixin:
        if generation_config.num_beams % generation_config.num_beam_groups != 0:
            raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
        if generation_config.diversity_penalty == 0.0:
            raise ValueError(
                "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
            )
        if stopping_criteria.max_length is None:
            raise ValueError("`max_length` needs to be a stopping_criteria for now.")
...
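The validation added above can be exercised in isolation. This hypothetical helper mirrors the two checks from `GenerationMixin` (a standalone sketch, not the library code) and shows the new failure mode:

```python
def validate_group_beam_search(num_beams, num_beam_groups, diversity_penalty):
    # Mirrors the pre-existing divisibility check and the new penalty check.
    if num_beams % num_beam_groups != 0:
        raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
    if diversity_penalty == 0.0:
        raise ValueError(
            "`diversity_penalty` should be greater than `0.0`, otherwise your beam groups will be identical."
        )


# `diversity_penalty` defaults to 0.0, so omitting it now raises:
try:
    validate_group_beam_search(num_beams=5, num_beam_groups=5, diversity_penalty=0.0)
except ValueError as err:
    print(err)

# A positive penalty satisfies the check:
validate_group_beam_search(num_beams=5, num_beam_groups=5, diversity_penalty=1.0)
```

Before this commit, the zero-penalty case silently produced `num_beam_groups` identical groups instead of failing fast.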
@@ -2366,6 +2366,7 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
            num_beams=2,
            num_beam_groups=2,
            num_return_sequences=2,
            diversity_penalty=1.0,
            eos_token_id=None,
            return_dict_in_generate=True,
            output_scores=True,
...