"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "852e032ca6505f8ddd9881a7ed67ea0dd9fc7603"
Unverified Commit b469ebc5 authored by Raushan Turganbay, committed by GitHub

Prepend `bos token` to Blip generations (#29642)



* prepend "bos" to blip generation

* minor changes

* Update src/transformers/models/blip_2/modeling_blip_2.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/instructblip/modeling_instructblip.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add generation tester mixin

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent ee38fc31
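For context before the diff: the user-visible effect of this commit is that decoder-only BLIP-2/InstructBLIP generations now start with the text model's BOS token, and the default `max_length` bookkeeping changes, so callers are encouraged to pass `max_new_tokens`. A minimal usage sketch follows; the checkpoint name, image URL, device, and dtype are illustrative choices, not part of this commit.

```python
# Minimal sketch, not part of the commit: checkpoint, image URL, device and dtype
# are illustrative; any decoder-only BLIP-2 checkpoint behaves the same way.
import requests
import torch
from PIL import Image

from transformers import AutoProcessor, Blip2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
).to("cuda")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    images=image, text="Question: which city is this? Answer:", return_tensors="pt"
).to("cuda", dtype=torch.float16)

# After this change the returned ids start with the BOS token, mirroring other
# generation models, and max_new_tokens counts only newly generated text tokens.
out = model.generate(**inputs, max_new_tokens=11)
print(out[0, 0].item() == model.config.text_config.bos_token_id)  # True
print(processor.batch_decode(out, skip_special_tokens=True)[0].strip())
```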
src/transformers/models/blip_2/modeling_blip_2.py

@@ -1828,8 +1828,10 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

         # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
+        # -1 is to account for the prepended BOS after `generate.`
+        # TODO (joao, raushan): refactor `generate` to avoid these operations with VLMs
         if not self.language_model.config.is_encoder_decoder:
-            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
             generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

         outputs = self.language_model.generate(
@@ -1838,4 +1840,16 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel):
             **generate_kwargs,
         )

+        # this is a temporary workaround to be consistent with other generation models and
+        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        if not self.language_model.config.is_encoder_decoder:
+            bos_tokens = (
+                torch.LongTensor([[self.config.text_config.bos_token_id]])
+                .repeat(batch_size, 1)
+                .to(image_embeds.device)
+            )
+            if not isinstance(outputs, torch.Tensor):
+                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
+            else:
+                outputs = torch.cat([bos_tokens, outputs], dim=-1)
         return outputs
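The added block above is pure tensor bookkeeping, so it can be replayed in isolation. A standalone sketch with dummy values (the batch size, BOS id, and fake generated ids are made up for illustration; `outputs` stands in for what `self.language_model.generate(...)` would return):

```python
# Replays the BOS-prepending step from the diff above with dummy values.
import torch

batch_size, gen_len, bos_token_id = 2, 5, 2
outputs = torch.randint(low=3, high=100, size=(batch_size, gen_len))  # fake generated ids

bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(outputs.device)
outputs = torch.cat([bos_tokens, outputs], dim=-1)

print(outputs.shape)                                  # torch.Size([2, 6]) -- one extra column
print(bool((outputs[:, 0] == bos_token_id).all()))    # True: every row now starts with BOS
```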
src/transformers/models/instructblip/modeling_instructblip.py

@@ -1538,8 +1538,9 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
         inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

         # add image_embeds length to max_length, so that the final max_length in counted only on token embeds
+        # -1 is to account for the prepended BOS after `generate.`
         if not self.language_model.config.is_encoder_decoder:
-            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1]
+            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
             generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]

         outputs = self.language_model.generate(
@@ -1548,13 +1549,21 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
             **generate_kwargs,
         )

-        # the InstructBLIP authors used inconsistent tokenizer/model files during training,
-        # with the tokenizer's bos token being set to </s> which has ID=2,
-        # whereas the model's text config has bos token id = 0
-        if self.config.text_config.architectures[0] == "LLaMAForCausalLM":
-            if isinstance(outputs, torch.Tensor):
-                outputs[outputs == 0] = 2
-            else:
-                outputs.sequences[outputs.sequences == 0] = 2
+        # this is a temporary workaround to be consistent with other generation models and
+        # have BOS as the first token, even though under the hood we are calling LM with embeds
+        if not self.language_model.config.is_encoder_decoder:
+            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
+            # with the tokenizer's bos token being set to </s> which has ID=2,
+            # whereas the model's text config has bos token id = 0
+            bos_token_id = (
+                2
+                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
+                else self.config.text_config.bos_token_id
+            )
+            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
+            if not isinstance(outputs, torch.Tensor):
+                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
+            else:
+                outputs = torch.cat([bos_tokens, outputs], dim=-1)

         return outputs
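The InstructBLIP-specific twist is which BOS id gets prepended: per the comments above, the LLaMA-based checkpoints were trained with a tokenizer whose BOS is `</s>` (id 2) while the text config says 0, so the LLaMA branch hard-codes 2 instead of the old post-hoc remapping of 0 to 2. A small sketch of just that selection; the `SimpleNamespace` stand-ins and the non-LLaMA architecture name are made up for illustration:

```python
# Stand-in for `self.config.text_config`; only the two attributes read by the diff are mimicked.
from types import SimpleNamespace


def pick_bos_token_id(text_config):
    # Same expression as in the diff: hard-code 2 for LLaMA, otherwise trust the config.
    return 2 if text_config.architectures[0] == "LLaMAForCausalLM" else text_config.bos_token_id


print(pick_bos_token_id(SimpleNamespace(architectures=["LLaMAForCausalLM"], bos_token_id=0)))   # 2
print(pick_bos_token_id(SimpleNamespace(architectures=["SomeOtherCausalLM"], bos_token_id=1)))  # 1
```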
tests/models/blip_2/test_modeling_blip_2.py

@@ -32,6 +32,7 @@ from transformers.testing_utils import (
 )
 from transformers.utils import is_torch_available, is_vision_available

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -434,7 +435,7 @@ class Blip2ForConditionalGenerationDecoderOnlyModelTester:
 @require_torch
-class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
+class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (Blip2ForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
@@ -683,7 +684,7 @@ class Blip2ModelTester:
 @require_torch
-class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (Blip2ForConditionalGeneration, Blip2Model) if is_torch_available() else ()
     pipeline_model_mapping = (
         {
@@ -869,7 +870,8 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
         prompt = "Question: which city is this? Answer:"
         inputs = processor(images=image, text=prompt, return_tensors="pt").to(torch_device, dtype=torch.float16)

-        predictions = model.generate(**inputs)
+        # max_length for BLIP includes prompt length from now on, use max_new_tokens
+        predictions = model.generate(**inputs, max_new_tokens=11)
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()

         # Test output
...
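The switch to `max_new_tokens` in the integration test follows from the length bookkeeping in the modeling change: the query-token length is added to the caller's `max_length`, minus one for the BOS prepended after `generate`, and the test's own comment notes that `max_length` now also covers the prompt. A sketch of just that arithmetic; the 32 query tokens are an illustrative stand-in for `language_model_inputs.shape[1]`:

```python
# Replays the max_length/min_length adjustment from the modeling diff with made-up numbers.
generate_kwargs = {}        # the caller passed no length arguments
num_query_tokens = 32       # stand-in for language_model_inputs.shape[1]

generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + num_query_tokens - 1
generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + num_query_tokens

print(generate_kwargs)  # {'max_length': 51, 'min_length': 32}
# max_new_tokens takes precedence over max_length inside `generate`, which is why the
# integration test pins it after this change.
```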
tests/models/instructblip/test_modeling_instructblip.py

@@ -39,6 +39,7 @@ from transformers.testing_utils import (
 )
 from transformers.utils import is_torch_available, is_vision_available

+from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import (
     ModelTesterMixin,
@@ -452,7 +453,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyModelTester:
 @require_torch
-class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, unittest.TestCase):
+class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_model_classes = (InstructBlipForConditionalGeneration,) if is_torch_available() else ()
     fx_compatible = False
     test_head_masking = False
...