Unverified commit 2f766f38, authored by bppps and committed by GitHub

[Bugfix]: distinguish processors for deepseek_vl2 and deepseek_ocr to p… (#12384)

parent 069e490b
from typing import Tuple
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from PIL import Image, ImageOps
from transformers import (
AutoProcessor,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
from sglang.srt.multimodal.customized_mm_processor_utils import (
register_customized_processor,
)
BASE_SIZE = 1024
IMAGE_SIZE = 640
@@ -18,18 +29,59 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
class ImageTransform:
class DictOutput(object):
def items(self):
return self.__dict__.items()
def keys(self):
return self.__dict__.keys()
def __getitem__(self, item):
return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
self.__dict__[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
images_crop: torch.LongTensor
pixel_values: (
torch.Tensor
)  # renamed from "images" to "pixel_values" for compatibility
images_seq_mask: torch.BoolTensor
images_spatial_crop: torch.LongTensor
def __len__(self):
return len(self.input_ids)
class ImageTransform(object):
def __init__(
self,
mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
normalize: bool = True,
):
self.mean = mean
self.std = std
self.normalize = normalize
# only load torchvision.transforms when needed
try:
import torchvision.transforms as T
# FIXME: add version check for gguf
except ImportError as err:
raise ImportError(
"Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
) from err
transform_pipelines = [T.ToTensor()]
if normalize:
@@ -42,6 +94,464 @@ class ImageTransform:
return x
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(
image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
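# Illustrative usage sketch of dynamic_preprocess tiling a large image. The
# min_num/max_num values are assumed for the example; the real defaults come
# from MIN_CROPS/MAX_CROPS.
def _example_dynamic_preprocess():
    # A 1280x960 image has aspect ratio ~1.33; among the allowed grids of 2..6
    # tiles the closest ratio is (3, 2), so the image is resized to 1920x1280
    # and split into six 640x640 crops.
    img = Image.new("RGB", (1280, 960))
    crops, ratio = dynamic_preprocess(img, min_num=2, max_num=6, image_size=640)
    assert ratio == (3, 2) and len(crops) == 6
    return crops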
class DeepseekOCRProcessor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast,
candidate_resolutions: Tuple[Tuple[int, int]],
patch_size: int,
downsample_ratio: int,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<|▁pad▁|>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
self.candidate_resolutions = candidate_resolutions
self.image_size = candidate_resolutions[0][0]
self.patch_size = patch_size
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.base_size = BASE_SIZE
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
self.tokenizer = tokenizer
# must set this; padding side will make a difference in batch inference
self.tokenizer.padding_side = "left"
# add the pad_token as a special token so that 'tokenizer.pad_token' and 'tokenizer.pad_token_id' can be used
if tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": pad_token})
# add image token
image_token_id = self.tokenizer.vocab.get(image_token)
if image_token_id is None:
special_tokens = [image_token]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
# add special tokens for SFT data
special_tokens = ["<|User|>", "<|Assistant|>"]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
images_list = []
images_seq_mask = []
images_spatial_crop = []
image_index = 0
image_token_cnt = messages.count(self.image_token)
(
input_ids,
images,
images_crop,
seq_mask,
spatial_crop,
num_image_tokens,
image_shapes,
) = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
)
image_index = image_token_cnt
images_list += images
images_seq_mask += seq_mask
images_spatial_crop = spatial_crop
return (
input_ids,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
)
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image.Image] = None,
apply_sft_format: bool = False,
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
cropping: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
if conversations is not None, then it will always apply the SFT format to conversations;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prompt = conversations or prompt
(
input_ids,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
) = self.format_messages_v2(prompt, images, max_req_input_len)
target_ids = torch.LongTensor(masked_tokenized_str)
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
else:
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
) # stack the tensor to make it a batch of 1
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
images_crop=images_crop,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
)
return prepare
def __call__(
self,
*,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image.Image] = None,
apply_sft_format: bool = False,
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
text: list[str] = None,
**kwargs,
):
assert text is None or isinstance(text, list)
if text is not None:
text = text[0]
prepare = self.process_one(
prompt=prompt or text,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
inference_mode=inference_mode,
system_prompt=system_prompt,
max_req_input_len=max_req_input_len,
)
return prepare
def find_all_indices(self, messages, target_value):
indices = []
for index, item in enumerate(messages):
if item == target_value:
indices.append(index)
return indices
def tokenize_with_images(
self,
conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
conversation = conversation
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
if cropping:
images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
"""process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(
image,
(self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
images_list.append(self.image_transform(global_view))
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
"""add image tokens"""
num_queries = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return (
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
image_shapes,
)
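# Worked example of the image-token layout above, assuming illustrative values
# base_size=1024, patch_size=16, downsample_ratio=4 (not read from any checkpoint):
# num_queries_base = ceil((1024 // 16) / 4) = 16, so the global view contributes
# (16 + 1) * 16 + 1 = 273 tokens (16 rows of 16 queries plus a per-row separator,
# plus one trailing separator). With image_size=640, num_queries = ceil((640 // 16) / 4) = 10,
# so a 3x2 crop grid adds (10 * 3 + 1) * (10 * 2) = 620 local-view tokens,
# i.e. 893 image tokens in total for that image.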
class VisionEncoderConfig(PretrainedConfig):
model_type: str = "vision"
@@ -223,6 +733,7 @@ class DeepseekV2Config(PretrainedConfig):
)
@register_customized_processor(processor_class=DeepseekOCRProcessor)
class DeepseekVLV2Config(PretrainedConfig):
# model_type = "deepseek_vl_v2"
model_type = "deepseek-ocr"
@@ -232,6 +743,7 @@ class DeepseekVLV2Config(PretrainedConfig):
tile_tag: str = "2D"
global_view_pos: str = "head"
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
customized_processor_type: type[Any] = DeepseekOCRProcessor
def __init__(
self,
@@ -258,5 +770,4 @@ class DeepseekVLV2Config(PretrainedConfig):
self.hidden_size = self.text_config.hidden_size
class DeepseekOCRConfig(DeepseekV2Config):
model_type = "DeepseekOCR"
AutoProcessor.register(DeepseekVLV2Config, DeepseekOCRProcessor)
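# Minimal usage sketch under assumptions: the checkpoint's tokenizer is assumed to
# load as LlamaTokenizerFast, and candidate_resolutions/patch_size/downsample_ratio
# below are illustrative values rather than the ones shipped with the model.
def _example_run_ocr_processor():
    tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_PATH)
    processor = DeepseekOCRProcessor(
        tokenizer=tokenizer,
        candidate_resolutions=((1024, 1024),),  # illustrative
        patch_size=16,  # illustrative
        downsample_ratio=4,  # illustrative
    )
    image = Image.new("RGB", (1280, 960))
    # PROMPT contains a single <image> tag, so exactly one image is passed.
    out = processor(text=[PROMPT], images=[image])
    return out.input_ids, out.pixel_values, out.images_spatial_crop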
@@ -11,8 +11,6 @@ from transformers import (
ProcessorMixin,
)
from sglang.srt.configs.deepseek_ocr import BASE_SIZE, IMAGE_SIZE, MAX_CROPS, MIN_CROPS
def select_best_resolution(image_size, candidate_resolutions):
# used for cropping
@@ -63,7 +61,6 @@ class DictOutput(object):
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
images_crop: torch.LongTensor
pixel_values: (
torch.Tensor
)  # renamed from "images" to "pixel_values" for compatibility
@@ -107,68 +104,6 @@ class ImageTransform(object):
return x
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(
image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
@@ -198,7 +133,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.base_size = BASE_SIZE
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
@@ -241,7 +176,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
**kwargs,
)
def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
@@ -251,34 +186,35 @@ class DeepseekVLV2Processor(ProcessorMixin):
image_index = 0
image_token_cnt = messages.count(self.image_token)
(
input_ids,
images,
images_crop,
seq_mask,
spatial_crop,
num_image_tokens,
image_shapes,
) = self.tokenize_with_images(
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
max_req_input_len=max_req_input_len,
)
image_index = image_token_cnt
tokenized_data += tokenized_str
if self.mask_prompt:
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
else:
masked_tokenized_data += tokenized_str
images_list += images
images_seq_mask += seq_mask
images_spatial_crop = spatial_crop
images_spatial_crop += spatial_crop
assert len(tokenized_data) == len(
images_seq_mask
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return (
input_ids,
tokenized_data,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
)
@property
@@ -315,7 +251,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
cropping: bool = True,
**kwargs,
):
"""
@@ -339,22 +274,47 @@ class DeepseekVLV2Processor(ProcessorMixin):
- num_image_tokens (List[int]): the number of image tokens
"""
prompt = conversations or prompt
assert (
prompt is None or conversations is None
), "prompt and conversations cannot be used at the same time."
(
input_ids,
tokenized_str,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
) = self.format_messages_v2(prompt, images, max_req_input_len)
) = self.format_messages_v2(conversations, images, max_req_input_len)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
if inference_mode:
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
else:
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
@@ -363,7 +323,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
images_crop=images_crop,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
@@ -381,14 +340,10 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
text: list[str] = None,
**kwargs,
):
assert text is None or isinstance(text, list)
if text is not None:
text = text[0]
prepare = self.process_one(
prompt=prompt or text,
prompt=prompt,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
@@ -413,83 +368,85 @@ class DeepseekVLV2Processor(ProcessorMixin):
bos: bool = True,
eos: bool = True,
cropping: bool = True,
max_req_input_len: int = -1,
):
"""Tokenize text with <image> tags."""
conversation = conversation
assert conversation.count(self.image_token) == len(images)
images_list, images_seq_mask, images_spatial_crop = [], [], []
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
"""select best resolution for anyres"""
if cropping:
images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
best_width, best_height = select_best_resolution(
image.size, self.candidate_resolutions
)
else:
crop_ratio = [1, 1]
best_width, best_height = self.image_size, self.image_size
# print(image.size, (best_width, best_height)) # check the select_best_resolutions func
"""process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(
image,
(self.base_size, self.base_size),
(self.image_size, self.image_size),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
images_list.append(self.image_transform(global_view))
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
"""process the local views"""
local_view = ImageOps.pad(
image,
(best_width, best_height),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
for i in range(0, best_height, self.image_size):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(
local_view.crop(
(j, i, j + self.image_size, i + self.image_size)
)
)
)
if num_width_tiles > 1 or num_height_tiles > 1:
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
"""record height / width crop num"""
num_width_tiles, num_height_tiles = (
best_width // self.image_size,
best_height // self.image_size,
)
images_spatial_crop.append([num_width_tiles, num_height_tiles])
"""add image tokens"""
num_queries = math.ceil(
h = w = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
# global views tokens h * (w + 1), 1 is for line separator
tokenized_image = [self.image_token_id] * h * (w + 1)
# add a separator between global and local views
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image
[self.image_token_id]
* (num_height_tiles * h)
* (num_width_tiles * w + 1)
)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
# print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
# deal with video: limit tokens to the max request input length
if max_req_input_len > -1:
if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
tokenized_str = tokenized_str[:rest]
images_seq_mask = images_seq_mask[:rest]
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
@@ -505,64 +462,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return (
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
image_shapes,
)
return tokenized_str, images_list, images_seq_mask, images_spatial_crop
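# Worked example of the deepseek_vl2 token layout above, assuming illustrative values
# image_size=384, patch_size=16, downsample_ratio=4 (not read from any checkpoint):
# h = w = ceil((384 // 16) / 4) = 6, so the global view contributes 6 * (6 + 1) = 42
# tokens plus 1 separator. A best resolution of 768x384 gives num_width_tiles=2 and
# num_height_tiles=1, adding (1 * 6) * (2 * 6 + 1) = 78 local-view tokens, for
# 121 image tokens in total.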
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
@@ -647,6 +547,7 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
from typing import Dict, Type
from transformers import PretrainedConfig, ProcessorMixin
# Useful for registering a custom processor different from Hugging Face's default.
_CUSTOMIZED_MM_PROCESSOR: Dict[str, Type[ProcessorMixin]] = dict()
def register_customized_processor(
processor_class: Type[ProcessorMixin],
):
"""Class decorator that maps a config class's model_type field to a customized processor class.
Args:
processor_class: A processor class that inherits from ProcessorMixin
Example:
```python
@register_customized_processor(MyCustomProcessor)
class MyModelConfig(PretrainedConfig):
model_type = "my_model"
```
"""
def decorator(config_class: PretrainedConfig):
if not hasattr(config_class, "model_type"):
raise ValueError(
f"Class {config_class.__name__} with register_customized_processor should "
f"have a 'model_type' class attribute."
)
_CUSTOMIZED_MM_PROCESSOR[config_class.model_type] = processor_class
return config_class
return decorator
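# Hedged usage sketch: a toy config/processor pair (hypothetical names) showing how the
# decorator populates _CUSTOMIZED_MM_PROCESSOR, mirroring the lookup that get_processor()
# performs further down in this change.
def _example_register_and_resolve():
    class _ToyProcessor(ProcessorMixin):
        attributes = []

    @register_customized_processor(processor_class=_ToyProcessor)
    class _ToyConfig(PretrainedConfig):
        model_type = "toy_mm_model"

    # Resolution is a plain dict lookup keyed by the config's model_type.
    return _CUSTOMIZED_MM_PROCESSOR["toy_mm_model"] is _ToyProcessor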
@@ -54,6 +54,7 @@ from sglang.srt.configs import (
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
@@ -172,6 +173,16 @@ def _load_deepseek_v32_model(
)
def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool:
# TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
# Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
return (
getattr(config, "auto_map", None) is not None
and config.auto_map.get("AutoModel")
== "modeling_deepseekocr.DeepseekOCRForCausalLM"
)
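# Illustrative example: a DeepSeek-OCR checkpoint's config.json carries an auto_map
# entry like the one below (other keys omitted), which is what this check keys on,
# whereas a plain DeepSeek-VL2 checkpoint maps AutoModel to a deepseek_vl_v2 class.
#   "auto_map": {"AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM", ...}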
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
@@ -235,11 +246,7 @@ def get_config(
if config.model_type in _CONFIG_REGISTRY:
model_type = config.model_type
if model_type == "deepseek_vl_v2":
if (
getattr(config, "auto_map", None) is not None
and config.auto_map.get("AutoModel")
== "modeling_deepseekocr.DeepseekOCRForCausalLM"
):
if _is_deepseek_ocr_model(config):
model_type = "deepseek-ocr"
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(model, revision=revision)
@@ -445,6 +452,10 @@ def get_processor(
**kwargs,
)
if _is_deepseek_ocr_model(config):
# Temporary hack for loading deepseek-ocr
config.model_type = "deepseek-ocr"
# fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
if "size" not in kwargs:
@@ -461,6 +472,15 @@ def get_processor(
revision=revision,
**kwargs,
)
else:
if config.model_type in _CUSTOMIZED_MM_PROCESSOR:
processor = _CUSTOMIZED_MM_PROCESSOR[config.model_type].from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
processor = AutoProcessor.from_pretrained(
tokenizer_name,