"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "3b0cb7945f5c9dbe55d5f76720ddf4e475c11169"
Unverified commit cb3c821c, authored by hukuda222, committed by GitHub

aligned sample_beam output selection with beam_search (#25375)



* aligned sample_beam specs with beam_search

* pull origin main

* Revert "pull origin main"

This reverts commit 06d356f1137bb52272e120a03636598c44449cf3.

* update test_utils.py

* fix format

* remove comment

---------
Co-authored-by: Shogo Fujita <shogo.fujita@legalontech.jp>
parent 704bf595
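
What this means for callers of generate(): beam sampling (do_sample=True with num_beams > 1) now selects its returned sequences the same way plain beam search does, by keeping the num_return_sequences best hypotheses of a single beam search per prompt instead of running num_return_sequences interleaved searches. A minimal usage sketch of the affected call path, assuming any standard seq2seq checkpoint (the t5-small name here is just an example, not part of this commit):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # example checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: Hello world", return_tensors="pt")

# Beam-sample path: do_sample=True with num_beams > 1.
# After this commit, the 2 returned sequences are the top-2 hypotheses of a
# single 4-beam search per prompt (num_beam_hyps_to_keep=2), matching how
# plain beam search (do_sample=False) selects its outputs.
torch.manual_seed(0)
out = model.generate(
    **inputs,
    do_sample=True,
    num_beams=4,
    num_return_sequences=2,
    max_length=20,
)
print(out.shape)  # (num_return_sequences, seq_len) for a single prompt
print(tokenizer.batch_decode(out, skip_special_tokens=True))
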
@@ -1691,18 +1691,19 @@ class GenerationMixin:
         # 12. prepare beam search scorer
         beam_scorer = BeamSearchScorer(
-            batch_size=batch_size * generation_config.num_return_sequences,
+            batch_size=batch_size,
             num_beams=generation_config.num_beams,
             device=inputs_tensor.device,
             length_penalty=generation_config.length_penalty,
             do_early_stopping=generation_config.early_stopping,
+            num_beam_hyps_to_keep=generation_config.num_return_sequences,
             max_length=generation_config.max_length,
         )

         # 13. interleave input_ids with `num_beams` additional sequences per batch
         input_ids, model_kwargs = self._expand_inputs_for_generation(
             input_ids=input_ids,
-            expand_size=generation_config.num_beams * generation_config.num_return_sequences,
+            expand_size=generation_config.num_beams,
             is_encoder_decoder=self.config.is_encoder_decoder,
             **model_kwargs,
         )
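
For context, num_beam_hyps_to_keep is an existing BeamSearchScorer argument that controls how many finished hypotheses the scorer returns per batch item; the hunk above routes num_return_sequences through it instead of inflating batch_size. A minimal sketch of the scorer construction under the new scheme (the values are illustrative):

import torch
from transformers import BeamSearchScorer

batch_size, num_beams, num_return_sequences = 2, 4, 2

# New scheme: one scorer entry per prompt; the scorer itself keeps the
# num_return_sequences best finished hypotheses of each prompt's beam search.
beam_scorer = BeamSearchScorer(
    batch_size=batch_size,                       # was batch_size * num_return_sequences
    num_beams=num_beams,
    device=torch.device("cpu"),
    num_beam_hyps_to_keep=num_return_sequences,  # newly passed through
)
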
@@ -438,7 +438,6 @@ class GenerationTesterMixin:
         input_ids,
         attention_mask,
         max_length,
-        num_return_sequences,
         beam_scorer,
         beam_kwargs,
         logits_warper,
@@ -463,7 +462,7 @@ class GenerationTesterMixin:
             **logits_warper_kwargs,
             **model_kwargs,
         )
-        # beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
+        # beam_search does not automatically interleave `batch_size` dim for `num_beams`
        torch.manual_seed(0)
        kwargs = {}
        if model.config.is_encoder_decoder:
@@ -471,13 +470,13 @@ class GenerationTesterMixin:
                 model,
                 input_ids,
                 attention_mask,
-                num_interleave=beam_scorer.num_beams * num_return_sequences,
+                num_interleave=beam_scorer.num_beams,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
             )
             kwargs["encoder_outputs"] = encoder_outputs
         elif attention_mask is not None:
-            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0)
+            attention_mask = attention_mask.repeat_interleave(beam_scorer.num_beams, dim=0)

         # prevent flaky generation test failures
         logits_processor = LogitsProcessorList()
@@ -486,7 +485,7 @@ class GenerationTesterMixin:
         with torch.no_grad():
             model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
             output_beam_sample = model.beam_sample(
-                input_ids.repeat_interleave(beam_scorer.num_beams * num_return_sequences, dim=0),
+                input_ids.repeat_interleave(beam_scorer.num_beams, dim=0),
                 beam_scorer,
                 max_length=max_length,
                 logits_warper=logits_warper,
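
The two test hunks above mirror that change: inputs handed to beam_sample are tiled only num_beams times per prompt. repeat_interleave(n, dim=0) repeats each row n times consecutively, which keeps all beams of one prompt adjacent in the batch; a small standalone illustration:

import torch

input_ids = torch.tensor([[1, 2], [3, 4]])  # two prompts, shape (batch=2, seq=2)
num_beams = 3

expanded = input_ids.repeat_interleave(num_beams, dim=0)
print(expanded)
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])
# rows 0-2 are the beams of prompt 0, rows 3-5 the beams of prompt 1
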
@@ -891,13 +890,9 @@ class GenerationTesterMixin:
             self.assertListEqual(output_generate.tolist(), output_beam_search.tolist())

-            # check `generate()` and `beam_search()` are equal for `num_return_sequences`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0], max_length, num_return_sequences=num_return_sequences
-            )
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_generate, output_beam_search = self._beam_search_generate(
                 model=model,
@@ -1036,21 +1031,15 @@ class GenerationTesterMixin:
             model = model_class(config).to(torch_device).eval()

             # check `generate()` and `beam_search()` are equal
-            # change `num_return_sequences = 2` but not for `beam_scorer`
-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_generate, output_beam_sample = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
@@ -1074,20 +1063,15 @@ class GenerationTesterMixin:
             model = model_class(config).to(torch_device).eval()
             logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=1)

-            num_return_sequences = 2
             if model.config.is_encoder_decoder:
                 max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
+            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(input_ids.shape[0], max_length)

             output_beam_sample, output_generate = self._beam_sample_generate(
                 model=model,
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_length=max_length,
-                num_return_sequences=num_return_sequences,
                 beam_scorer=beam_scorer,
                 beam_kwargs=beam_kwargs,
                 logits_warper=logits_warper,
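
Both _beam_sample_generate call sites now build the scorer exactly as the beam-search tests do. As a rough, hypothetical sketch of what a helper like _get_beam_scorer_and_kwargs plausibly returns under the new scheme (the real helper lives in the test suite and may differ):

import torch
from transformers import BeamSearchScorer

def get_beam_scorer_and_kwargs(batch_size, max_length, num_beams=2, num_return_sequences=1):
    # Hypothetical mirror of the test helper: the scorer is sized per prompt,
    # and hypothesis selection is delegated to num_beam_hyps_to_keep.
    beam_kwargs = {
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "max_length": max_length,
    }
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=torch.device("cpu"),
        num_beam_hyps_to_keep=num_return_sequences,
        max_length=max_length,
    )
    return beam_kwargs, beam_scorer
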
@@ -1113,9 +1097,7 @@ class GenerationTesterMixin:
             self.assertTrue((output_generate["sequences_scores"] < 0).all().item())

             for output in (output_beam_sample, output_generate):
-                self._check_outputs(
-                    output, input_ids, model.config, num_return_sequences=num_return_sequences * beam_scorer.num_beams
-                )
+                self._check_outputs(output, input_ids, model.config, num_return_sequences=beam_scorer.num_beams)

     def test_generate_without_input_ids(self):
         config, _, _, max_length = self._get_input_ids_and_config()
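
The last hunk adjusts the output check accordingly: internal per-step tensors (scores, attentions) scale with num_beams per prompt, while generate() still returns batch_size * num_return_sequences sequences. A hedged shape check in the same spirit, reusing model, tokenizer, and inputs from the first sketch above:

# reusing model/tokenizer/inputs from the earlier sketch (illustrative only)
out = model.generate(**inputs, do_sample=True, num_beams=4, num_return_sequences=2, max_length=20)
batch_size = inputs["input_ids"].shape[0]
assert out.shape[0] == batch_size * 2  # one row per returned hypothesis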