"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a8aad0ec938778bf41df1e6842a7baab81776c64"
Unverified Commit e60491ad authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Fix Llava for 0-embeddings (#30473)

parent ad697f18
@@ -327,8 +327,11 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+        image_to_overwrite = torch.full(
+            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+        )
+        image_to_overwrite[batch_indices, text_to_overwrite] = False
         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
...
@@ -403,8 +403,11 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+        image_to_overwrite = torch.full(
+            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+        )
+        image_to_overwrite[batch_indices, text_to_overwrite] = False
         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
...
@@ -331,8 +331,11 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel):
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

-        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
-        image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
+        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+        image_to_overwrite = torch.full(
+            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+        )
+        image_to_overwrite[batch_indices, text_to_overwrite] = False
         image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
...
@@ -459,3 +459,29 @@ class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase):
        EXPECTED_DECODED_TEXT = ['[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays', '[INST] \nWhat is shown in this image? [/INST] The image shows two cats lying on a pink surface, which appears to be a couch or a cush']  # fmt: skip
        self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_small_model_integration_test_unk_token(self):
    """Regression test for #29835: a prompt containing ``<unk>`` tokens must not
    break the image-feature merge (previously the 0-embedding of ``<unk>`` was
    mistaken for an image position) and must still generate the expected text.
    """
    # related to (#29835)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf",
        load_in_4bit=True,
    )

    prompt_with_unk = "[INST] <image>\nWhat is shown in this <unk> image? [/INST]"
    inputs = self.processor(prompt_with_unk, self.image, return_tensors="pt").to(torch_device)

    # A single forward pass must complete without raising.
    with torch.no_grad():
        model(**inputs)

    # Generation output must match the reference decoding exactly.
    output = model.generate(**inputs, max_new_tokens=40)

    EXPECTED_DECODED_TEXT = '[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot that displays values for multiple quantitative variables represented on axes starting from the same point. This particular radar chart'  # fmt: skip

    decoded_text = self.processor.decode(output[0], skip_special_tokens=True)
    self.assertEqual(
        decoded_text,
        EXPECTED_DECODED_TEXT,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment