Unverified Commit 13b23704 authored by fxmarty's avatar fxmarty Committed by GitHub
Browse files

Correct llava mask & fix missing setter for `vocab_size` (#29389)

* correct llava mask

* fix vipllava as wlel

* mask out embedding for padding tokens

* add test

* fix style

* add setter

* fix test on suggestion
parent aa17cf98
...@@ -147,6 +147,10 @@ class LlavaConfig(PretrainedConfig): ...@@ -147,6 +147,10 @@ class LlavaConfig(PretrainedConfig):
) )
return self._vocab_size return self._vocab_size
@vocab_size.setter
def vocab_size(self, value):
self._vocab_size = value
def to_dict(self): def to_dict(self):
output = super().to_dict() output = super().to_dict()
output.pop("_vocab_size", None) output.pop("_vocab_size", None)
......
...@@ -344,6 +344,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): ...@@ -344,6 +344,12 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
final_attention_mask |= image_to_overwrite final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None: if labels is None:
final_labels = None final_labels = None
...@@ -449,10 +455,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): ...@@ -449,10 +455,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
# Get the target length # Get the target length
target_seqlen = first_layer_past_key_value.shape[-1] + 1 target_length = input_ids.shape[1]
past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones( extended_attention_mask = torch.ones(
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), (attention_mask.shape[0], past_length),
dtype=attention_mask.dtype, dtype=attention_mask.dtype,
device=attention_mask.device, device=attention_mask.device,
) )
...@@ -467,7 +474,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel): ...@@ -467,7 +474,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
# Zero-out the places where we don't need to attend # Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
outputs = self.language_model( outputs = self.language_model(
......
...@@ -356,7 +356,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): ...@@ -356,7 +356,6 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels): def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
num_images, num_image_patches, embed_dim = image_features.shape num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
# 1. Create a mask to know where special image tokens are # 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index special_image_token_mask = input_ids == self.config.image_token_index
...@@ -418,6 +417,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): ...@@ -418,6 +417,12 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
final_attention_mask |= image_to_overwrite final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None: if labels is None:
final_labels = None final_labels = None
......
...@@ -347,6 +347,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): ...@@ -347,6 +347,12 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
final_attention_mask |= image_to_overwrite final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
if labels is None: if labels is None:
final_labels = None final_labels = None
...@@ -442,11 +448,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): ...@@ -442,11 +448,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
# Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 # Sum all dimensions of head_dim (-1) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0) batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-1) == 0)
# Get the target length target_length = input_ids.shape[1]
target_seqlen = first_layer_past_key_value.shape[-2] + 1 past_length = first_layer_past_key_value.shape[-1]
extended_attention_mask = torch.ones( extended_attention_mask = torch.ones(
(attention_mask.shape[0], target_seqlen - attention_mask.shape[1]), (attention_mask.shape[0], past_length),
dtype=attention_mask.dtype, dtype=attention_mask.dtype,
device=attention_mask.device, device=attention_mask.device,
) )
...@@ -461,7 +467,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): ...@@ -461,7 +467,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
# Zero-out the places where we don't need to attend # Zero-out the places where we don't need to attend
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1) attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
outputs = self.language_model( outputs = self.language_model(
......
...@@ -27,7 +27,14 @@ from transformers import ( ...@@ -27,7 +27,14 @@ from transformers import (
is_torch_available, is_torch_available,
is_vision_available, is_vision_available,
) )
from transformers.testing_utils import require_bitsandbytes, require_torch, require_torch_gpu, slow, torch_device from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from ...test_configuration_common import ConfigTester from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
...@@ -470,10 +477,45 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -470,10 +477,45 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20) output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this serene location, one should be cautious about the weather conditions and potential', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: When visiting this place, which appears to be a dock or pier extending over a body of water', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT) self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_torch
@require_vision
def test_batched_generation(self):
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf").to(torch_device)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
prompt1 = "<image>\n<image>\nUSER: What's the the difference of two images?\nASSISTANT:"
prompt2 = "<image>\nUSER: Describe the image.\nASSISTANT:"
prompt3 = "<image>\nUSER: Describe the image.\nASSISTANT:"
url1 = "https://images.unsplash.com/photo-1552053831-71594a27632d?q=80&w=3062&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
url2 = "https://images.unsplash.com/photo-1617258683320-61900b281ced?q=80&w=3087&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
image1 = Image.open(requests.get(url1, stream=True).raw)
image2 = Image.open(requests.get(url2, stream=True).raw)
inputs = processor(
text=[prompt1, prompt2, prompt3],
images=[image1, image2, image1, image2],
return_tensors="pt",
padding=True,
).to(torch_device)
model = model.eval()
EXPECTED_OUTPUT = [
"\n \nUSER: What's the the difference of two images?\nASSISTANT: In the two images, the primary difference is the presence of a small dog holding a flower in one",
"\nUSER: Describe the image.\nASSISTANT: The image features a small, fluffy dog sitting on a sidewalk. The dog is holding",
"\nUSER: Describe the image.\nASSISTANT: The image features a lone, adult llama standing on a grassy hill. The llama",
]
generate_ids = model.generate(**inputs, max_new_tokens=20)
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.assertEqual(outputs, EXPECTED_OUTPUT)
@slow @slow
@require_bitsandbytes @require_bitsandbytes
def test_llava_index_error_bug(self): def test_llava_index_error_bug(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment