Unverified Commit bd891aed authored by Raushan Turganbay, committed by GitHub

Fix max length for BLIP generation (#29296)



* fix max_length for blip

* update also min length

* fixes

* add a comment

* Update src/transformers/models/instructblip/modeling_instructblip.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/blip_2/modeling_blip_2.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* make fixup

* fix length when user passed

* remove else

* remove brackets

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 4fc708f9
@@ -1453,6 +1453,7 @@ class GenerationMixin:
    and not self.config.is_encoder_decoder
):
    generation_config.max_length -= inputs_tensor.shape[1]
    generation_config.min_length = max(generation_config.min_length - inputs_tensor.shape[1], 0)
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
    if generation_config.cache_implementation == "static":
...
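With this change, when a decoder-only model is called with `inputs_embeds` only, `generate()` rescales both length bounds so they count newly generated tokens rather than the prepended prompt embeddings. A minimal standalone sketch of that arithmetic (plain Python; the namespace and numbers are illustrative, not the actual GenerationMixin code):

from types import SimpleNamespace

# Illustrative stand-ins for GenerationConfig fields and the prompt-embedding length.
generation_config = SimpleNamespace(max_length=50, min_length=20)
prompt_embeds_len = 32  # e.g. the query-token embeddings BLIP-2 prepends

# Same arithmetic as the patched branch above: max_length shrinks by the prompt
# length, and min_length is now shifted too (clamped at 0) instead of being left as-is.
generation_config.max_length -= prompt_embeds_len
generation_config.min_length = max(generation_config.min_length - prompt_embeds_len, 0)

print(generation_config.max_length, generation_config.min_length)  # 18 0

Before the added line, min_length kept its original value, so it could end up larger than the adjusted max_length.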
@@ -1827,6 +1827,11 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
if not self.language_model.config.is_encoder_decoder:
    generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
    generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
...
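The BLIP-2 wrapper now adds the number of prepended image/query embeddings to the user-supplied max_length and min_length, and GenerationMixin subtracts the same amount again, so the values a user passes count only token embeddings. A hedged usage sketch, assuming a standard BLIP-2 checkpoint with a decoder-only language model (the checkpoint name, image URL and prompt are illustrative, not part of this commit):

import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to("cuda")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt").to("cuda", torch.float16)

# With this fix, max_length=30 counts only token embeddings (prompt text plus
# generated tokens); the 32 query-token embeddings prepended by the Q-Former
# no longer eat into the budget.
out = model.generate(**inputs, max_length=30)
print(processor.batch_decode(out, skip_special_tokens=True)[0].strip())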
@@ -1537,6 +1537,11 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
inputs_embeds = self.get_input_embeddings()(input_ids)
inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

# add image_embeds length to max_length, so that the final max_length is counted only on token embeds
if not self.language_model.config.is_encoder_decoder:
    generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
    generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

outputs = self.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
...
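InstructBLIP receives the identical adjustment. The two halves of the fix cancel out, which is easiest to see in a tiny standalone sketch of the combined bookkeeping (plain Python; names and numbers are illustrative):

# User asks for at most 30 and at least 5 tokens, counted on token embeddings.
user_max_length, user_min_length = 30, 5
num_image_embeds = 32  # query-token embeddings prepended by the model

# Step 1 (model wrapper above): offset the bounds by the prepended embeddings.
max_length = user_max_length + num_image_embeds
min_length = user_min_length + num_image_embeds

# Step 2 (GenerationMixin.generate): subtract the inputs_embeds length again.
max_length -= num_image_embeds
min_length = max(min_length - num_image_embeds, 0)

assert (max_length, min_length) == (user_max_length, user_min_length)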