Unverified Commit 3560ae6d authored by Pavel Denisov, committed by GitHub

Add `inputs_embeds` support for `.generate()` with BLOOM models (#21430)

Accept `.generate()` calls with `inputs_embeds` on BLOOM models.
parent f21af262
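
For context, here is a minimal usage sketch (not part of this commit) of what the change enables, assuming the `bigscience/bloom-560m` checkpoint and a transformers version that includes this patch:

```python
import torch
from transformers import AutoTokenizer, BloomForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

prompt = "Translate to French: Hello, world!"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Build the prompt embeddings explicitly, e.g. to prepend soft prompts or
# otherwise manipulate them before generation.
inputs_embeds = model.get_input_embeddings()(input_ids)

# Before this commit, passing `inputs_embeds` to `.generate()` raised an error
# for BLOOM; with the patch, the first generation step consumes the embeddings.
with torch.no_grad():
    generated = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)

# When only `inputs_embeds` is passed, the returned ids contain just the
# newly generated tokens.
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```
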
@@ -842,6 +842,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         input_ids: torch.LongTensor,
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs
     ) -> dict:
         # only last token for input_ids if past is not None
@@ -852,12 +853,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)

-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
......
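
To make the new branch concrete, here is a hypothetical illustration (not in the commit) of how `prepare_inputs_for_generation` behaves across generation steps with this patch applied, assuming `bigscience/bloom-560m` and the transformers version the commit targets: on the first step (no cache yet) the prepared inputs carry `inputs_embeds`; once a cache exists, they fall back to `input_ids`.

```python
import torch
from transformers import BloomForCausalLM

model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
input_ids = torch.tensor([[1, 2, 3]])
inputs_embeds = model.get_input_embeddings()(input_ids)

# First step: no cache yet, so the new `if` branch keeps the embeddings.
first_step = model.prepare_inputs_for_generation(
    input_ids, past_key_values=None, inputs_embeds=inputs_embeds
)
print("inputs_embeds" in first_step)  # True

# Run one forward pass to obtain a cache, as `.generate()` would internally.
with torch.no_grad():
    outputs = model(inputs_embeds=inputs_embeds, use_cache=True)

# Later steps: a cache is present, so only the last token id is prepared.
later_step = model.prepare_inputs_for_generation(
    input_ids, past_key_values=outputs.past_key_values, inputs_embeds=inputs_embeds
)
print("input_ids" in later_step)  # True
print(later_step["input_ids"].shape)  # torch.Size([1, 1])
```
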