Unverified commit c2d1b075, authored by Roger Wang and committed by GitHub
Browse files

[Bugfix] Fix issues for `Pixtral-Large-Instruct-2411` (#11393)


Signed-off-by: ywang96 <ywang@example.com>
Co-authored-by: ywang96 <ywang@example.com>
parent 584f0ae4
...@@ -45,8 +45,12 @@ try: ...@@ -45,8 +45,12 @@ try:
except ImportError: except ImportError:
USE_XFORMERS_OPS = False USE_XFORMERS_OPS = False
PIXTRAL_IMAGE_BREAK_ID = 12 # These token ids cannot be retrieved from model config
PIXTRAL_IMAGE_END_ID = 13 # so we hardcode them here.
PIXTRAL_12B_IMAGE_BREAK_ID = 12
PIXTRAL_12B_IMAGE_END_ID = 13
PIXTRAL_LARGE_IMAGE_BREAK_ID = 14
PIXTRAL_LARGE_IMAGE_END_ID = 15
def get_max_pixtral_image_tokens(ctx: InputContext): def get_max_pixtral_image_tokens(ctx: InputContext):
...@@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext, ...@@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext,
for image_data in data_list: for image_data in data_list:
image = ImageChunk(image=image_data) image = ImageChunk(image=image_data)
encoding = tokenizer.instruct.mm_encoder(image) encoding = tokenizer.instruct.mm_encoder(image)
image = torch.from_numpy(encoding.image).to(device="cuda", image = torch.from_numpy(encoding.image).to(dtype=torch.float16)
dtype=torch.float16)
images.append(image) images.append(image)
image_tokens_list.append(encoding.tokens) image_tokens_list.append(encoding.tokens)
...@@ -237,8 +240,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -237,8 +240,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# NOTE: Image embeddings are split into separate tensors for each image # NOTE: Image embeddings are split into separate tensors for each image
# by the indices of `[IMG_END]` token. # by the indices of `[IMG_END]` token.
split_indices = torch.where( image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | (
image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 image_tokens == PIXTRAL_LARGE_IMAGE_END_ID)
split_indices = torch.where(image_end_condition)[0] + 1
if len(split_indices) <= 1: if len(split_indices) <= 1:
# Do not split, return as tensor of shape [1, fs, hs] # Do not split, return as tensor of shape [1, fs, hs]
return image_embeds.unsqueeze(0) return image_embeds.unsqueeze(0)
...@@ -260,8 +264,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -260,8 +264,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, [ input_ids, inputs_embeds, multimodal_embeddings, [
self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, self.vision_args.image_token_id,
PIXTRAL_IMAGE_BREAK_ID PIXTRAL_12B_IMAGE_END_ID,
PIXTRAL_12B_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_BREAK_ID,
PIXTRAL_LARGE_IMAGE_END_ID,
]) ])
return inputs_embeds return inputs_embeds
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment