Unverified Commit b0138422 authored by Martin Schmitt's avatar Martin Schmitt Committed by GitHub
Browse files

Changed `num_beams` to `num_beams // num_beam_groups` when initialising...

Changed `num_beams` to `num_beams // num_beam_groups` when initialising `PrefixConstrainedLogitsProcessor` in `_get_logits_processor` to fix compatibility issue when constrained decoding is used together with grouped beam search (#10475)
parent 0c232519
...@@ -605,7 +605,7 @@ class GenerationMixin: ...@@ -605,7 +605,7 @@ class GenerationMixin:
if min_length is not None and eos_token_id is not None and min_length > -1: if min_length is not None and eos_token_id is not None and min_length > -1:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id)) processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None: if prefix_allowed_tokens_fn is not None:
processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams)) processors.append(PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, num_beams // num_beam_groups))
if forced_bos_token_id is not None: if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id)) processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None: if forced_eos_token_id is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment