Unverified Commit 0f2f0c63 authored by Victor SANH, committed by GitHub

Fix `_merge_input_ids_with_image_features` for llava model (#28333)



* fix `_merge_input_ids_with_image_features` for llava model

* Update src/transformers/models/llava/modeling_llava.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* address comments

* style and tests

* ooops

* test the backward too

* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update tests/models/vipllava/test_modeling_vipllava.py

* style and quality

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent 976189a6
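In short, `_merge_input_ids_with_image_features` now takes `labels` (instead of `position_ids`, which it already recomputes internally from the final attention mask) and returns a `final_labels` tensor aligned with the expanded image-plus-text sequence. A minimal, self-contained sketch of that label expansion (not the transformers code; shapes, positions, and the -100 ignore value are illustrative assumptions):

import torch

batch_size, num_image_patches = 2, 3
ignore_index = -100  # assumed ignore value; transformers takes it from config.ignore_index

# labels of the four non-image text tokens per example
text_labels = torch.randint(0, 100, (batch_size, 4))

# merged sequence length: 4 text tokens + 3 image patches
max_embed_dim = 4 + num_image_patches

# start from ignore_index everywhere, like the new `final_labels` buffer in the diff
final_labels = torch.full((batch_size, max_embed_dim), ignore_index, dtype=text_labels.dtype)

# positions the text tokens occupy once the image patches are spliced in after the first token
text_to_overwrite = torch.tensor([[0, 4, 5, 6], [0, 4, 5, 6]])
batch_indices = torch.arange(batch_size).unsqueeze(-1).expand_as(text_to_overwrite)

final_labels[batch_indices, text_to_overwrite] = text_labels
print(final_labels)  # image-patch positions (1, 2, 3) stay at -100 and are skipped by the loss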
src/transformers/models/llava/modeling_llava.py

@@ -276,9 +276,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
-    ):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
@@ -307,6 +305,10 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         final_attention_mask = torch.zeros(
             batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
         )
+        if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
         # set the corresponding tensors into their correct target device.
         target_device = inputs_embeds.device
@@ -321,6 +323,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+        if labels is not None:
+            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
 
         # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
@@ -335,7 +339,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
         final_attention_mask |= image_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-        return final_embedding, final_attention_mask, position_ids
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -420,8 +428,8 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
                 )
 
                 image_features = self.multi_modal_projector(selected_image_feature)
-                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, position_ids
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                 )
                 if labels is None:
                     labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
...
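Filling the image-patch and padding slots of `final_labels` with `config.ignore_index` is what keeps them out of the training loss: PyTorch's cross-entropy skips any target equal to its ignore_index, so those positions contribute neither to the loss nor to the gradient. A small standalone sketch of that behaviour (vocabulary size and shapes are assumptions for illustration):

import torch
import torch.nn.functional as F

vocab_size = 32064  # assumed vocabulary size for illustration
logits = torch.randn(2, 7, vocab_size, requires_grad=True)   # (batch, seq, vocab)
labels = torch.full((2, 7), -100)                             # -100 is the default ignore_index
labels[:, 4:] = torch.randint(0, vocab_size, (2, 3))          # only the last three positions are supervised

loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
loss.backward()

# ignored positions contribute neither to the loss nor to the gradient
assert torch.all(logits.grad[:, :4] == 0)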
src/transformers/models/vipllava/modeling_vipllava.py

@@ -284,9 +284,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, position_ids
-    ):
+    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
         num_images, num_image_patches, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
@@ -315,6 +313,10 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
         final_attention_mask = torch.zeros(
             batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
         )
+        if labels is not None:
+            final_labels = torch.full(
+                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+            )
         # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
         # set the corresponding tensors into their correct target device.
         target_device = inputs_embeds.device
@@ -329,6 +331,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
         # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+        if labels is not None:
+            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
 
         # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
@@ -343,7 +347,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
         final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
         final_attention_mask |= image_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-        return final_embedding, final_attention_mask, position_ids
+
+        if labels is None:
+            final_labels = None
+
+        return final_embedding, final_attention_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(VIPLLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=VipLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -419,8 +427,8 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
                 image_features = torch.cat(image_features, dim=-1)
 
                 image_features = self.multi_modal_projector(image_features)
-                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, position_ids
+                inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+                    image_features, inputs_embeds, input_ids, attention_mask, labels
                 )
                 if labels is None:
                     labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
...
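One of the unchanged context lines above, `left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))`, is how both models detect left padding, which the new tests below exercise. A standalone sketch of that check (the pad token id is an assumption, matching the ids used in the tests):

import torch

pad_token_id = 32001  # assumed pad token id
input_ids = torch.tensor([[32001, 32001, 1, 15043, 7084],
                          [1, 15043, 7084, 29901, 29871]])
# same check as in the context line: if no sequence ends with the pad token,
# the batch is treated as left-padded
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(pad_token_id))
print(left_padding)  # True: neither row ends with the pad token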
tests/models/llava/test_modeling_llava.py

@@ -26,7 +26,7 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -332,3 +332,41 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
 
         # Make sure that `generate` works
         _ = model.generate(**inputs, max_new_tokens=20)
+
+    @slow
+    @require_torch_gpu
+    def test_llava_merge_inputs_error_bug(self):
+        # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
+        model_id = "llava-hf/llava-1.5-7b-hf"
+        model = LlavaForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+        ).to(torch_device)
+
+        # Simulate some user inputs
+        pixel_values = torch.randn(
+            (2, 3, 336, 336),
+            dtype=torch.float,
+            device=torch_device,
+        )
+        input_ids = torch.tensor(
+            [
+                [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
+                [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
+            ],
+            dtype=torch.long,
+            device=torch_device,
+        )
+        attention_mask = torch.tensor(
+            [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
+            dtype=torch.long,
+            device=torch_device,
+        )
+
+        # Make sure that the loss is properly computed
+        loss = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=input_ids,
+        ).loss
+        loss.backward()
tests/models/vipllava/test_modeling_vipllava.py

@@ -26,7 +26,7 @@ from transformers import (
     is_torch_available,
     is_vision_available,
 )
-from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
+from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
 from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
@@ -214,3 +214,41 @@ class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
 
         EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
         self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)
+
+    @slow
+    @require_torch_gpu
+    def test_vipllava_merge_inputs_error_bug(self):
+        # This is a reproducer of https://github.com/huggingface/transformers/pull/28333 and makes sure it does not happen anymore
+        model_id = "llava-hf/vip-llava-7b-hf"
+        model = VipLlavaForConditionalGeneration.from_pretrained(
+            model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
+        ).to(torch_device)
+
+        # Simulate some user inputs
+        pixel_values = torch.randn(
+            (2, 3, 336, 336),
+            dtype=torch.float,
+            device=torch_device,
+        )
+        input_ids = torch.tensor(
+            [
+                [32001, 32001, 1, 15043, 7084, 32000, 29871, 13, 7900],
+                [1, 15043, 7084, 29901, 29871, 32000, 29871, 13, 7900],
+            ],
+            dtype=torch.long,
+            device=torch_device,
+        )
+        attention_mask = torch.tensor(
+            [[0, 0, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1]],
+            dtype=torch.long,
+            device=torch_device,
+        )
+
+        # Make sure that the loss is properly computed
+        loss = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=input_ids,
+        ).loss
+        loss.backward()
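For reference, the reproducers above deliberately left-pad the first example. A standalone sketch of how the merge step turns such a mask into `position_ids`, using the same cumsum expression that appears in the diff and the mask values from the tests:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1, 1, 1, 1]])
position_ids = (attention_mask.cumsum(-1) - 1).masked_fill_((attention_mask == 0), 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2, 3, 4, 5, 6],
#         [0, 1, 2, 3, 4, 5, 6, 7, 8]])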