Unverified Commit dcc49d8a authored by Billy Bradley, committed by GitHub

In assisted decoding, pass model_kwargs to model's forward call (fix prepare_inputs_for_generation in all models) (#25242)

* In assisted decoding, pass model_kwargs to model's forward call

Previously, assisted decoding ignored any additional kwargs that it did not
explicitly handle. This was inconsistent with the other generation methods,
which pass model_kwargs through prepare_inputs_for_generation and forward
the returned dict to the model's forward call.
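
For context, the flow the other decoding loops already follow looks roughly like this (a minimal sketch assuming a decoder-only model; `greedy_step` is a made-up name, not a library function):

```python
# Minimal sketch of how generation loops route extra kwargs (illustrative
# only; this is not the actual transformers implementation).
def greedy_step(model, input_ids, **model_kwargs):
    # prepare_inputs_for_generation folds model_kwargs (e.g. token_type_ids)
    # into a dict, which is then unpacked straight into forward().
    model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
    outputs = model(**model_inputs)
    # Greedy choice of the next token from the final position's logits.
    return outputs.logits[:, -1, :].argmax(dim=-1)
```

Assisted decoding skipped this round-trip, which is why extra kwargs such as token_type_ids never reached the model.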

The prepare_inputs_for_generation method needs to be amended in all models:
previously it kept only the last input ID whenever past_key_values was passed,
but assisted decoding submits several candidate tokens per forward call, so
every input ID beyond the cached length must now be kept.
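
The amended trimming rule can be exercised in isolation (a sketch that mirrors the diffs below, using toy shapes):

```python
import torch

# Sketch of the new trimming rule from prepare_inputs_for_generation
# (mirrors the diffs below; shapes are toy values).
def trim_input_ids(input_ids: torch.Tensor, past_length: int) -> torch.Tensor:
    if input_ids.shape[1] > past_length:
        # Keep every token the cache has not seen yet
        remove_prefix_length = past_length
    else:
        # Default to old behavior: keep only the final ID
        remove_prefix_length = input_ids.shape[1] - 1
    return input_ids[:, remove_prefix_length:]

# Regular decoding: one token beyond the 5 cached ones survives the cut.
assert trim_input_ids(torch.zeros(1, 6, dtype=torch.long), past_length=5).shape[1] == 1
# Assisted decoding: all four candidate tokens beyond the cache survive.
assert trim_input_ids(torch.zeros(1, 9, dtype=torch.long), past_length=5).shape[1] == 4
```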

* Improve variable names in _extend_attention_mask

* Refactor extending token_type_ids into a function (see the sketch below)

* Replace deepcopy with copy to optimize performance

* Update new persimmon model with llama changes for assisted generation

* Update new mistral model for assisted generation with prepare_inputs_for_generation

* Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation
parent 1e3c9dda
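
On the token_type_ids refactor mentioned above: extending the ids to cover the assistant's candidate tokens works along these lines (a hedged sketch; the helper name and exact behavior are illustrative, not the library's function):

```python
import torch

# Illustrative sketch of extending token_type_ids to cover candidate tokens
# appended by assisted decoding (hypothetical helper, not the library's).
def extend_token_type_ids(token_type_ids: torch.Tensor, new_length: int) -> torch.Tensor:
    type_length_diff = new_length - token_type_ids.shape[1]
    if type_length_diff <= 0:
        return token_type_ids
    # Assume candidate tokens belong to the same segment as the final token.
    final_type = token_type_ids[:, -1].unsqueeze(-1)
    extension = final_type.repeat(1, type_length_diff)
    return torch.cat([token_type_ids, extension], dim=-1)

ids = torch.zeros(1, 3, dtype=torch.long)
assert extend_token_type_ids(ids, new_length=7).shape == (1, 7)
```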
```diff
@@ -970,9 +970,18 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
 
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
```
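The past_length lookup in the hunk above relies on the cache layout these models use: past_key_values is a tuple of per-layer (key, value) pairs, each shaped (batch, num_heads, seq_len, head_dim), so the cached sequence length sits at index 2 (a sketch with toy sizes):

```python
import torch

# Toy past_key_values: two layers of (key, value) pairs, each with
# batch=1, num_heads=4, seq_len=5, head_dim=8.
past_key_values = tuple(
    (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8)) for _ in range(2)
)
past_length = past_key_values[0][0].shape[2]  # cached sequence length
assert past_length == 5
```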
```diff
@@ -1118,9 +1118,18 @@ class XmodForCausalLM(XmodPreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_shape)
 
-        # cut decoder_input_ids if past is used
+        # cut decoder_input_ids if past_key_values is used
         if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
 
         return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
```
```diff
@@ -2906,3 +2906,89 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
             model.generation_config.max_length = 10
             model.generate(input_ids)
             self.assertEqual(len(warning_list), 0)
+
+    def test_model_kwarg_assisted_decoding_decoder_only(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        model.config.pad_token_id = tokenizer.eos_token_id
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_normal = model.generate(input_ids)
+        self.assertEqual(outputs_normal.shape, (1, 20))
+
+        # Should be different with token_type_ids
+        outputs_tti = model.generate(
+            input_ids,
+            token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+        )
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist())
+
+        # Assistant model
+        assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        assistant.config.pad_token_id = tokenizer.eos_token_id
+
+        # If assisted generation passes model_kwargs correctly, should be same as previous
+        outputs_assisted = model.generate(
+            input_ids,
+            token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device),
+            assistant_model=assistant,
+        )
+        self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist())
+
+    def test_model_kwarg_assisted_decoding_encoder_decoder(self):
+        # PT-only test: TF doesn't support assisted decoding yet.
+        # Bart subclass with a kwarg that distorts the output
+        class FakeBart(BartForConditionalGeneration):
+            def forward(self, input_ids, foo=False, **kwargs):
+                outs = super().forward(input_ids, **kwargs)
+                if foo:
+                    outs["logits"][:, :, :] = 0.0
+                return outs
+
+            def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
+                inputs = super().prepare_inputs_for_generation(*args, **kwargs)
+                inputs["foo"] = foo
+                return inputs
+
+        model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to(
+            torch_device
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration")
+
+        text = "Hello world"
+        tokenized_inputs = tokenizer([text], return_tensors="pt")
+        input_ids = tokenized_inputs.input_ids.to(torch_device)
+
+        # Traditional way of generating text
+        outputs_normal = model.generate(input_ids)
+        self.assertEqual(outputs_normal.shape, (1, 20))
+
+        # Should be different with foo
+        outputs_foo = model.generate(
+            input_ids,
+            foo=True,
+        )
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist())
+
+        # Assistant model
+        assistant = AutoModelForSeq2SeqLM.from_pretrained(
+            "hf-internal-testing/tiny-random-BartForConditionalGeneration"
+        ).to(torch_device)
+
+        # If assisted generation passes model_kwargs correctly, should be same as previous
+        outputs_assisted = model.generate(
+            input_ids,
+            foo=True,
+            assistant_model=assistant,
+        )
+        self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())
```