Unverified Commit 92ce53aa authored by Joao Gante, committed by GitHub

Generate: decoder-only models can generate with `inputs_embeds` (#21405)

parent e5db7051
@@ -519,47 +519,40 @@ class GenerationMixin:
         inputs_kwarg = model_kwargs.pop(input_name, None)
         if inputs_kwarg is not None and inputs is not None:
             raise ValueError(
-                f"`inputs`: {inputs}` were passed alongside "
-                f"{input_name} which is not allowed."
+                f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."
                 f"Make sure to either pass {inputs} or {input_name}=..."
             )
         elif inputs_kwarg is not None:
             inputs = inputs_kwarg

-        # 3. models with `input_ids` can also make use of `inputs_embeds`
-        if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
-            inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
-
-        # 4. Only encoder-decoder models can have non `input_ids` input format
-        if not self.config.is_encoder_decoder and input_name != "input_ids":
-            raise ValueError(
-                f"If {input_name} is passed as model-specific keyword "
-                "input then model has to be an encoder-decoder and not a "
-                f"{self.__class__.__name__}."
-            )
+        # 3. In the presence of `inputs_embeds` for text models:
+        # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+        # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+        # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+        # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
+        # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
+        if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
+            if not self.config.is_encoder_decoder:
+                has_inputs_embeds_forwarding = "inputs_embeds" in set(
+                    inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+                )
+                if not has_inputs_embeds_forwarding:
+                    raise ValueError(
+                        f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+                        "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+                        "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+                    )
+            else:
+                if inputs is not None:
+                    raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

-        # 5. if `inputs` is still None, try to create `input_ids` from BOS token
+        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
         if inputs is None:
             inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))

         return inputs, input_name, model_kwargs

-    def _can_retrieve_inputs_from_name(
-        self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
-    ) -> torch.Tensor:
-        """
-        If `inputs` is None and `name` is in both forward function and keyword arguments, then inputs can be retrieved
-        from name
-        """
-        can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
-            inspect.signature(self.forward).parameters.keys()
-        )
-        if can_retrieve_inputs and inputs is not None:
-            raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
-        return can_retrieve_inputs
-
     def adjust_logits_during_generation(self, logits: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         """
         Implement in subclasses of [`PreTrainedModel`] for custom behavior to adjust the logits in the generate method.
...
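For context, a minimal sketch (not part of the diff) of how the new check above decides whether a decoder-only model can accept `inputs_embeds` through `.generate()`: it simply inspects the signature of `prepare_inputs_for_generation`. The checkpoint is the tiny GPT-2 model already used in the tests below; the variable name `supports_embeds` is only illustrative.

import inspect

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

# Mirrors the new `has_inputs_embeds_forwarding` check in `_prepare_model_inputs`
supports_embeds = "inputs_embeds" in set(
    inspect.signature(model.prepare_inputs_for_generation).parameters.keys()
)
print(supports_embeds)  # True once `prepare_inputs_for_generation` accepts `inputs_embeds`, as in this PR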
@@ -981,7 +981,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
@@ -1000,14 +1000,23 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
                 position_ids = position_ids[:, -1].unsqueeze(-1)
         else:
             position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
...
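As a sanity check, here is a sketch (assuming the changes in this PR are installed) of the routing added above: on the first generation step there is no cache yet, so `inputs_embeds` is forwarded and `input_ids` is dropped; later steps have `past_key_values` set and fall back to `input_ids`.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# 1st step: no `past_key_values`, so the embeddings are selected and `input_ids` is not forwarded
step_inputs = model.prepare_inputs_for_generation(input_ids, inputs_embeds=inputs_embeds)
print("inputs_embeds" in step_inputs, "input_ids" in step_inputs)  # True False
# On subsequent steps `past_key_values` is populated, and the new branch returns `input_ids` instead.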
@@ -2359,17 +2359,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         self.assertTrue(diff < 1e-4)

-    def test_decoder_generate_with_inputs_embeds(self):
-        article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=5).to(torch_device)
-        input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
-        inputs_embeds = model.get_input_embeddings()(input_ids)
-
-        # cannot generate from `inputs_embeds` for decoder only
-        with self.assertRaises(ValueError):
-            model.generate(inputs_embeds=inputs_embeds)
-
     def test_generate_input_ids_as_kwarg(self):
         article = """I need input_ids to generate"""
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -2417,8 +2406,10 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
     def test_generate_too_many_encoder_kwargs(self):
         article = """I need input_ids to generate"""
-        tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
-        model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
+        tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+        model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=10).to(
+            torch_device
+        )
         input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
         with self.assertRaises(ValueError):
             model.generate(input_ids=input_ids, inputs_embeds=input_ids)
@@ -3128,3 +3119,26 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
         eos_token_id = [873]
         generated_tokens = model.generate(**tokens, eos_token_id=eos_token_id, **generation_kwargs)
         self.assertTrue(expectation == len(generated_tokens[0]))
+
+    def test_generate_from_input_embeds_decoder_only(self):
+        # Note: the model must support generation from input embeddings
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+
+        text = "Hello world"
+        input_ids = tokenizer.encode(text, return_tensors="pt")
+
+        # Traditional way of generating text
+        outputs_from_ids = model.generate(input_ids)
+
+        # Same thing, but from input embeddings
+        inputs_embeds = model.transformer.wte(input_ids)
+        outputs_from_embeds = model.generate(input_ids, inputs_embeds=inputs_embeds)
+        self.assertListEqual(outputs_from_ids.tolist(), outputs_from_embeds.tolist())
+
+        # But if we pass different inputs_embeds, we should get different outputs
+        torch.manual_seed(0)
+        random_embeds = torch.rand_like(inputs_embeds)
+        outputs_from_rand_embeds = model.generate(input_ids, inputs_embeds=random_embeds)
+        with self.assertRaises(AssertionError):
+            self.assertListEqual(outputs_from_rand_embeds.tolist(), outputs_from_embeds.tolist())
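At the user level, the new test boils down to the usage sketch below (assuming a post-#21405 install; the tiny checkpoint is the same one used in the test, and with a real checkpoint the embeddings would typically come from the model's own embedding layer).

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# `inputs_embeds` drives the first forward pass; `input_ids` still anchors the returned sequence
generated = model.generate(input_ids, inputs_embeds=inputs_embeds, max_length=10)
print(tokenizer.decode(generated[0]))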