Unverified commit bc7a6fdc, authored by NaN and committed by GitHub
Browse files

Fix Constrained beam search duplication and weird output issue (#17814)

* fix(ConstrainedBeamSearchScorer.step_sentence_constraint): avoid hypothesis duplication between topk and advance

* fix(GenerationMixin.constrained_beam_search): appropriately assign beam scores instead of token scores
parent c2c0d9db
...@@ -655,7 +655,13 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -655,7 +655,13 @@ class ConstrainedBeamSearchScorer(BeamScorer):
full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1) full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
# need to make new hypothesis that advance the constraints # need to make new hypothesis that advance the constraints
track_new = {"new_seqs": [], "new_states": [], "new_indices": [], "new_tokens": [], "new_scores": []} track_new = {
"new_seqs": full_hypotheses.tolist(),
"new_states": [],
"new_indices": [],
"new_tokens": [],
"new_scores": [],
}
for seq_idx, pre_seq in enumerate(this_batch_input_ids): for seq_idx, pre_seq in enumerate(this_batch_input_ids):
# pre_seq = ith sequence generated before this step. # pre_seq = ith sequence generated before this step.
......
...@@ -3220,10 +3220,10 @@ class GenerationMixin: ...@@ -3220,10 +3220,10 @@ class GenerationMixin:
next_token_scores_processed = logits_processor(input_ids, next_token_scores) next_token_scores_processed = logits_processor(input_ids, next_token_scores)
scores_for_all_vocab = next_token_scores_processed.clone()
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
scores_for_all_vocab = next_token_scores.clone()
# Store scores, attentions and hidden_states when required # Store scores, attentions and hidden_states when required
if return_dict_in_generate: if return_dict_in_generate:
if output_scores: if output_scores:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment