Unverified Commit 08cd694e authored by Max Baak, committed by GitHub

ENH: added new output_logits option to generate function (#28667)

The output_logits option behaves like output_scores, but returns the raw, unprocessed prediction logits,
i.e. the values before they undergo logit processing and/or warping. The regular output scores, by contrast,
are processed by default.
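
For illustration, a minimal sketch of the difference (the tiny test checkpoint and prompt are borrowed from the
unit tests below; the repetition_penalty processor is just one way to make the two outputs diverge):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("generate yes or no: ", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=1,
    do_sample=False,
    repetition_penalty=1.3,  # one example of a logits processor
    output_scores=True,      # processed scores (after the repetition penalty)
    output_logits=True,      # raw, unprocessed logits (this PR)
    return_dict_in_generate=True,
)
# Both are tuples with one (batch_size, vocab_size) tensor per generated token;
# the raw logits and the processed scores typically differ here.
print(torch.equal(outputs.logits[0], outputs.scores[0]))  # False
```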

It is useful to have the unprocessed logit scores in certain circumstances. For example, unprocessed logit scores
are very useful with CausalLM models when one wants to determine the probability of a certain answer, e.g.
when asking a question with a yes/no answer. In that case, the next-token probabilities of both "yes" and
"no" (and/or their relative ratio) are of interest for classification. The reason for getting these _before_ logit
processing and/or warping is that processing (a) can change the probabilities, or (b) can reject the tokens of
interest outright and reduce the number of candidate tokens to just one.

For an example use case, see the paper "TabLLM: Few-shot Classification of Tabular Data with Large Language Models"
by Stefan Hegselmann, Alejandro Buendia, Hunter Lang, Monica Agrawal, Xiaoyi Jiang, and David Sontag:
https://arxiv.org/abs/2210.10723
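
A sketch of the yes/no use case with the new option (again on the tiny test checkpoint used in the unit tests
below; a real classifier would use a full-size CausalLM checkpoint, and for BPE tokenizers the leading-space
variants of "yes"/"no" may be the relevant tokens):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer("generate yes or no: ", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=1,
    output_logits=True,
    return_dict_in_generate=True,
)
# outputs.logits holds one (batch_size, vocab_size) tensor of raw next-token
# logits per generated token; softmax them to get next-token probabilities.
probs = torch.nn.functional.softmax(outputs.logits[0][0], dim=-1)
# First sub-token of each answer word (tokenization is checkpoint-specific).
yes_id = tokenizer("yes", add_special_tokens=False).input_ids[0]
no_id = tokenizer("no", add_special_tokens=False).input_ids[0]
print(f"P(yes)={probs[yes_id]:.4f}, P(no)={probs[no_id]:.4f}, ratio={probs[yes_id] / probs[no_id]:.2f}")
```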



In addition:
- Added a dedicated unit test, tests/generation/test_utils.py::test_return_unprocessed_logit_scores,
  which checks that logits are returned when generating with output_logits=True.
- Set output_logits=True in all other generation unit tests that also set output_scores=True.

Implemented @gante's and @amyeroberts's review feedback.
Co-authored-by: kx79wq <max.baak@ing.com>
parent 07e3454f
@@ -216,6 +216,9 @@ class GenerationConfig(PushToHubMixin):
more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
output_logits (`bool`, *optional*):
Whether or not to return the unprocessed prediction logit scores. See `logits` under returned tensors for
more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
@@ -315,6 +318,7 @@ class GenerationConfig(PushToHubMixin):
self.output_attentions = kwargs.pop("output_attentions", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_scores = kwargs.pop("output_scores", False)
self.output_logits = kwargs.pop("output_logits", None)
self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
# Special tokens that can be used at generation time
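
The new attribute can also be enabled through GenerationConfig directly; a minimal sketch using the kwargs handled above:

```python
from transformers import GenerationConfig

# `output_logits` defaults to None; set it explicitly to get the raw logits back.
gen_config = GenerationConfig(
    output_logits=True,          # return the raw, unprocessed logits
    output_scores=True,          # keep the processed scores as well
    return_dict_in_generate=True,
)
# model.generate(..., generation_config=gen_config) then returns a ModelOutput
# carrying both a `scores` and a `logits` tuple.
```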
This diff is collapsed.
@@ -269,6 +269,7 @@ class GenerationTesterMixin:
attention_mask,
max_length,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -293,6 +294,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**logits_process_kwargs,
**model_kwargs,
@@ -317,6 +319,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
@@ -335,6 +338,7 @@ class GenerationTesterMixin:
logits_warper_kwargs,
process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -348,6 +352,7 @@ class GenerationTesterMixin:
max_length=max_length,
num_return_sequences=num_return_sequences,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -379,6 +384,7 @@ class GenerationTesterMixin:
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -399,6 +405,7 @@ class GenerationTesterMixin:
logits_processor,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -409,6 +416,7 @@ class GenerationTesterMixin:
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -440,6 +448,7 @@ class GenerationTesterMixin:
max_length=max_length,
logits_processor=logits_processor,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -459,6 +468,7 @@ class GenerationTesterMixin:
logits_warper,
logits_warper_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -470,6 +480,7 @@ class GenerationTesterMixin:
do_sample=True,
max_length=max_length,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -506,6 +517,7 @@ class GenerationTesterMixin:
logits_warper=logits_warper,
logits_processor=logits_processor,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -526,6 +538,7 @@ class GenerationTesterMixin:
logits_processor,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -536,6 +549,7 @@ class GenerationTesterMixin:
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -567,6 +581,7 @@ class GenerationTesterMixin:
max_length=max_length,
logits_processor=logits_processor,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -587,6 +602,7 @@ class GenerationTesterMixin:
logits_processor,
logits_process_kwargs,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -597,6 +613,7 @@ class GenerationTesterMixin:
do_sample=False,
max_length=max_length,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -629,6 +646,7 @@ class GenerationTesterMixin:
max_length=max_length,
logits_processor=logits_processor,
output_scores=output_scores,
output_logits=output_logits,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict_in_generate=return_dict_in_generate,
@@ -644,6 +662,7 @@ class GenerationTesterMixin:
attention_mask,
max_length,
output_scores=False,
output_logits=False,
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
@@ -673,6 +692,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**logits_process_kwargs,
**model_kwargs,
@@ -699,6 +719,7 @@ class GenerationTesterMixin:
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
output_logits=output_logits,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
**model_kwargs,
@@ -729,6 +750,7 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
max_length=max_length,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -769,6 +791,7 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
max_length=max_length,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -853,6 +876,7 @@ class GenerationTesterMixin:
logits_warper_kwargs=logits_warper_kwargs,
process_kwargs=process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -964,6 +988,7 @@ class GenerationTesterMixin:
logits_process_kwargs=logits_process_kwargs,
logits_processor=logits_processor,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1032,6 +1057,7 @@ class GenerationTesterMixin:
logits_process_kwargs=logits_process_kwargs,
logits_processor=logits_processor,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1126,6 +1152,7 @@ class GenerationTesterMixin:
logits_warper=logits_warper,
logits_warper_kwargs=logits_warper_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1262,6 +1289,7 @@ class GenerationTesterMixin:
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1421,6 +1449,7 @@ class GenerationTesterMixin:
logits_processor=logits_processor,
logits_process_kwargs=logits_process_kwargs,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1493,6 +1522,7 @@ class GenerationTesterMixin:
attention_mask=attention_mask,
max_length=max_length,
output_scores=True,
output_logits=True,
output_hidden_states=True,
output_attentions=True,
return_dict_in_generate=True,
@@ -1628,6 +1658,7 @@ class GenerationTesterMixin:
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_logits": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
@@ -1690,6 +1721,7 @@ class GenerationTesterMixin:
"num_beams": 1,
"do_sample": False,
"output_scores": True,
"output_logits": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
@@ -1753,6 +1785,7 @@ class GenerationTesterMixin:
"do_sample": True,
"assistant_model": assistant_model,
"output_scores": True,
"output_logits": True,
"output_hidden_states": True,
"output_attentions": True,
"return_dict_in_generate": True,
@@ -2105,6 +2138,7 @@ class GenerationTesterMixin:
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
gen_len = (
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
)
@@ -2112,6 +2146,9 @@ class GenerationTesterMixin:
# scores
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
# unprocessed logits
self._check_logits(num_sequences_in_output, output.logits, config=config)
# Attentions
if config.is_encoder_decoder:
# encoder
@@ -2191,6 +2228,14 @@ class GenerationTesterMixin:
self.assertEqual(len(scores), length)
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))

    def _check_logits(self, batch_size, scores, config):
        self.assertIsInstance(scores, tuple)
        self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
        # vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
        vocab_diff = config.vocab_size - scores[0].shape[-1]
        self.assertTrue(vocab_diff in [0, 1])
        self.assertListEqual([config.vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))

def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
@@ -3536,3 +3581,60 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
model.generate(**inputs, **generation_kwargs)
# update_candidate_strategy is called once but assistant_model.generation_config.num_assistant_tokens should stay 5
self.assertEqual(assistant_model.generation_config.num_assistant_tokens, 5)

    def test_compare_unprocessed_logit_scores(self):
        # Get unprocessed logit scores back from the model's generate function.
        # Assert that the unprocessed logits from generate() are the same as those from a model forward pass.

        # Tell the model to generate text and return unprocessed/unwarped logit scores.
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        text = "generate yes or no: "
        input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

        with torch.no_grad():
            # Get logits for the next token from a forward pass
            logits_fwd = model(input_ids).logits[:, -1, :][0]

        # Get logits for the next token from the generate function
        outputs = model.generate(
            input_ids=input_ids,
            return_dict_in_generate=True,
            output_logits=True,
            max_new_tokens=1,
            do_sample=True,
        )
        logits_gen = outputs.logits[0][0]

        # Assert that the unprocessed logits from generate() are the same as those from the forward pass
        self.assertListEqual(logits_fwd.tolist(), logits_gen.tolist())

    def test_return_unprocessed_logit_scores(self):
        # Tell the model to generate text and return unprocessed/unwarped logit scores.
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        text = "generate yes or no: "
        input_ids = tokenizer([text], return_tensors="pt").input_ids.to(torch_device)
        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)

        outputs = model.generate(
            input_ids=input_ids, return_dict_in_generate=True, output_logits=True, max_new_tokens=3
        )

        # Perform a dummy check that the unprocessed logits make sense:
        # preselect the high-probability tokens and find the scores of the "y" and "n" tokens.
        probs_all = torch.nn.functional.softmax(outputs.logits[2][0], dim=-1)
        indices = torch.argwhere(probs_all > 0.001)
        indices = indices[:, -1]
        tokens_max = tokenizer.batch_decode(indices, skip_special_tokens=True)
        probs_max = probs_all[probs_all > 0.001]

        self.assertTrue(len(indices) >= 2)
        next_token_dict = {str(t): p for t, p in zip(tokens_max, probs_max)}
        self.assertTrue("n" in next_token_dict)
        self.assertTrue("y" in next_token_dict)
        y_prob = next_token_dict["y"]
        n_prob = next_token_dict["n"]

        self.assertTrue(y_prob > 0.001 and n_prob > 0.001)
        self.assertTrue(y_prob <= 1.0 and n_prob <= 1.0)