Unverified commit 3560ae6d, authored by Pavel Denisov and committed by GitHub
Browse files

Add `inputs_embeds` support for `.generate()` with BLOOM models (#21430)

Accept `.generate()` calls with `inputs_embeds` on BLOOM models
parent f21af262
...@@ -842,6 +842,7 @@ class BloomForCausalLM(BloomPreTrainedModel): ...@@ -842,6 +842,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs **kwargs
) -> dict: ) -> dict:
# only last token for input_ids if past is not None # only last token for input_ids if past is not None
...@@ -852,12 +853,20 @@ class BloomForCausalLM(BloomPreTrainedModel): ...@@ -852,12 +853,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
if past_key_values[0][0].shape[0] == input_ids.shape[0]: if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values) past_key_values = self._convert_to_bloom_cache(past_key_values)
return { # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
"input_ids": input_ids, if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask, "attention_mask": attention_mask,
} }
)
return model_inputs
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@add_code_sample_docstrings( @add_code_sample_docstrings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment