Unverified commit 2f766f38 authored by bppps, committed by GitHub

[Bugfix]: distinguish processors for deepseek_vl2 and deepseek_ocr to p… (#12384)

parent 069e490b
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from PIL import Image, ImageOps
from transformers import (
AutoProcessor,
LlamaTokenizerFast,
PretrainedConfig,
ProcessorMixin,
)
from sglang.srt.multimodal.customized_mm_processor_utils import (
register_customized_processor,
)
BASE_SIZE = 1024
IMAGE_SIZE = 640
...@@ -18,18 +29,59 @@ MODEL_PATH = "deepseek-ai/DeepSeek-OCR" # change to your model path
PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
class DictOutput(object):
def items(self):
return self.__dict__.items()
def keys(self):
return self.__dict__.keys()
def __getitem__(self, item):
return self.__dict__[item]
def __contains__(self, key):
return key in self.__dict__
def __setitem__(self, key, value):
self.__dict__[key] = value
@dataclass
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
images_crop: torch.LongTensor
pixel_values: (
torch.Tensor
) # rename from "images" to "pixel_values" for compatibility
images_seq_mask: torch.BoolTensor
images_spatial_crop: torch.LongTensor
def __len__(self):
return len(self.input_ids)
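# Illustrative sketch (assumption, not part of this commit): DictOutput lets the
# dataclass above be used like a dict as well as via attributes.
#
#   out = VLChatProcessorOutput(
#       input_ids=torch.tensor([1]), target_ids=torch.tensor([1]),
#       images_crop=torch.zeros(1), pixel_values=torch.zeros(1),
#       images_seq_mask=torch.tensor([False]), images_spatial_crop=torch.zeros(1),
#   )
#   assert "pixel_values" in out and out["pixel_values"] is out.pixel_values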
class ImageTransform(object):
def __init__(
self,
mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
normalize: bool = True,
):
self.mean = mean
self.std = std
self.normalize = normalize
# only load torchvision.transforms when needed
try:
import torchvision.transforms as T
# FIXME: add version check for gguf
except ImportError as err:
raise ImportError(
"Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
) from err
transform_pipelines = [T.ToTensor()]
if normalize:
...@@ -42,6 +94,464 @@ class ImageTransform:
return x
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(
image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
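# Illustrative usage sketch (not part of this commit): how dynamic_preprocess is
# typically called. The image path is hypothetical, and the actual tile count
# depends on the MIN_CROPS / MAX_CROPS constants defined above.
#
#   img = Image.open("page.png").convert("RGB")
#   tiles, (w_tiles, h_tiles) = dynamic_preprocess(img, image_size=IMAGE_SIZE)
#   assert len(tiles) == w_tiles * h_tiles  # each tile is IMAGE_SIZE x IMAGE_SIZE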
class DeepseekOCRProcessor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
def __init__(
self,
tokenizer: LlamaTokenizerFast,
candidate_resolutions: Tuple[Tuple[int, int]],
patch_size: int,
downsample_ratio: int,
image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
normalize: bool = True,
image_token: str = "<image>",
pad_token: str = "<|▁pad▁|>",
add_special_token: bool = False,
sft_format: str = "deepseek",
mask_prompt: bool = True,
ignore_id: int = -100,
**kwargs,
):
self.candidate_resolutions = candidate_resolutions
self.image_size = candidate_resolutions[0][0]
self.patch_size = patch_size
self.image_mean = image_mean
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.base_size = BASE_SIZE
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
self.tokenizer = tokenizer
# must set this; the padding side makes a difference in batch inference
self.tokenizer.padding_side = "left"
# add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
if tokenizer.pad_token is None:
self.tokenizer.add_special_tokens({"pad_token": pad_token})
# add image token
image_token_id = self.tokenizer.vocab.get(image_token)
if image_token_id is None:
special_tokens = [image_token]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token_id = self.tokenizer.vocab.get(image_token)
# add five special tokens for grounding-related tasks
# <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
# add special tokens for SFT data
special_tokens = ["<|User|>", "<|Assistant|>"]
special_tokens_dict = {"additional_special_tokens": special_tokens}
self.tokenizer.add_special_tokens(special_tokens_dict)
self.image_token = image_token
self.pad_token = pad_token
self.add_special_token = add_special_token
self.sft_format = sft_format
self.mask_prompt = mask_prompt
self.ignore_id = ignore_id
super().__init__(
tokenizer,
**kwargs,
)
def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = [] # labels
images_list = []
images_seq_mask = []
images_spatial_crop = []
image_index = 0
image_token_cnt = messages.count(self.image_token)
(
input_ids,
images,
images_crop,
seq_mask,
spatial_crop,
num_image_tokens,
image_shapes,
) = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
)
image_index = image_token_cnt
images_list += images
images_seq_mask += seq_mask
images_spatial_crop = spatial_crop
return (
input_ids,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
)
@property
def bos_id(self):
return self.tokenizer.bos_token_id
@property
def eos_id(self):
return self.tokenizer.eos_token_id
@property
def pad_id(self):
return self.tokenizer.pad_token_id
def encode(self, text: str, bos: bool = True, eos: bool = False):
t = self.tokenizer.encode(text, add_special_tokens=False)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int], **kwargs) -> str:
return self.tokenizer.decode(t, **kwargs)
def process_one(
self,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image.Image] = None,
apply_sft_format: bool = False,
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
cropping: bool = True,
**kwargs,
):
"""
Args:
prompt (str): the formatted prompt;
conversations (List[Dict]): conversations with a list of messages;
images (List[ImageType]): the list of images;
apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
if conversations is not None, then it will always apply the SFT format to conversations;
inference_mode (bool): if True, then remove the last eos token;
system_prompt (str): the system prompt;
**kwargs:
Returns:
outputs (BaseProcessorOutput): the output of the processor,
- input_ids (torch.LongTensor): [N + image tokens]
- target_ids (torch.LongTensor): [N + image tokens]
- images (torch.FloatTensor): [n_images, 3, H, W]
- image_id (int): the id of the image token
- num_image_tokens (List[int]): the number of image tokens
"""
prompt = conversations or prompt
(
input_ids,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
images_crop,
) = self.format_messages_v2(prompt, images, max_req_input_len)
target_ids = torch.LongTensor(masked_tokenized_str)
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
else:
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
) # stack the tensor to make it a batch of 1
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
images_crop=images_crop,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
)
return prepare
def __call__(
self,
*,
prompt: str = None,
conversations: List[Dict[str, str]] = None,
images: List[Image.Image] = None,
apply_sft_format: bool = False,
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
text: list[str] = None,
**kwargs,
):
assert text is None or isinstance(text, list)
if text is not None:
text = text[0]
prepare = self.process_one(
prompt=prompt or text,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
inference_mode=inference_mode,
system_prompt=system_prompt,
max_req_input_len=max_req_input_len,
)
return prepare
def find_all_indices(self, messages, target_value):
indices = []
for index, item in enumerate(messages):
if item == target_value:
indices.append(index)
return indices
def tokenize_with_images(
self,
conversation: str,
images: List[Image.Image],
bos: bool = True,
eos: bool = True,
cropping: bool = True,
):
"""Tokenize text with <image> tags."""
conversation = conversation
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = []
for text_sep, image in zip(text_splits, images):
"""encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
image_shapes.append(image.size)
if image.size[0] <= 640 and image.size[1] <= 640:
crop_ratio = [1, 1]
else:
if cropping:
images_crop_raw, crop_ratio = dynamic_preprocess(
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
"""process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad(
image,
(self.base_size, self.base_size),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
images_list.append(self.image_transform(global_view))
num_width_tiles, num_height_tiles = crop_ratio
images_spatial_crop.append([num_width_tiles, num_height_tiles])
if num_width_tiles > 1 or num_height_tiles > 1:
for i in range(len(images_crop_raw)):
images_crop_list.append(self.image_transform(images_crop_raw[i]))
"""add image tokens"""
num_queries = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio
)
num_queries_base = math.ceil(
(self.base_size // self.patch_size) / self.downsample_ratio
)
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image))
"""process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep)
"""add the bos and eos tokens"""
if bos:
tokenized_str = [self.bos_id] + tokenized_str
images_seq_mask = [False] + images_seq_mask
if eos:
tokenized_str = tokenized_str + [self.eos_id]
images_seq_mask = images_seq_mask + [False]
assert len(tokenized_str) == len(
images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = []
for token_index in tokenized_str:
if token_index != self.image_token_id:
masked_tokenized_str.append(token_index)
else:
masked_tokenized_str.append(self.ignore_id)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
inference_mode = True
if inference_mode:
# Remove the ending eos token
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
input_ids = input_ids.unsqueeze(0)
return (
input_ids,
pixel_values,
images_crop,
images_seq_mask,
images_spatial_crop,
num_image_tokens,
image_shapes,
)
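# Illustrative usage sketch (assumption, not part of this commit): constructing the
# processor from a local DeepSeek-OCR checkpoint and running it on one image.
# MODEL_PATH and PROMPT are the module-level constants defined above; the image
# path is hypothetical.
#
#   processor = DeepseekOCRProcessor.from_pretrained(MODEL_PATH)
#   out = processor(text=[PROMPT], images=[Image.open("page.png").convert("RGB")])
#   # out is a VLChatProcessorOutput carrying input_ids, pixel_values, images_crop,
#   # images_seq_mask, and images_spatial_crop.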
class VisionEncoderConfig(PretrainedConfig):
model_type: str = "vision"
...@@ -223,6 +733,7 @@ class DeepseekV2Config(PretrainedConfig):
)
@register_customized_processor(processor_class=DeepseekOCRProcessor)
class DeepseekVLV2Config(PretrainedConfig):
# model_type = "deepseek_vl_v2"
model_type = "deepseek-ocr"
...@@ -232,6 +743,7 @@ class DeepseekVLV2Config(PretrainedConfig):
tile_tag: str = "2D"
global_view_pos: str = "head"
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
customized_processor_type: type[Any] = DeepseekOCRProcessor
def __init__(
self,
...@@ -258,5 +770,4 @@ class DeepseekVLV2Config(PretrainedConfig):
self.hidden_size = self.text_config.hidden_size
AutoProcessor.register(DeepseekVLV2Config, DeepseekOCRProcessor)
...@@ -11,8 +11,6 @@ from transformers import (
ProcessorMixin,
)
def select_best_resolution(image_size, candidate_resolutions):
# used for cropping
...@@ -63,7 +61,6 @@ class DictOutput(object):
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
pixel_values: (
torch.Tensor
) # rename from "images" to "pixel_values" for compatibility
...@@ -107,68 +104,6 @@ class ImageTransform(object):
return x
class DeepseekVLV2Processor(ProcessorMixin):
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
attributes = ["tokenizer"]
...@@ -198,7 +133,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
self.image_std = image_std
self.normalize = normalize
self.downsample_ratio = downsample_ratio
self.image_transform = ImageTransform(
mean=image_mean, std=image_std, normalize=normalize
)
...@@ -241,7 +176,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
**kwargs,
)
def format_messages_v2(self, messages, pil_images, max_req_input_len=-1):
"""play the role of format_messages_v2 and get_images_info in the last version"""
tokenized_data = []
masked_tokenized_data = []  # labels
...@@ -251,34 +186,35 @@ class DeepseekVLV2Processor(ProcessorMixin):
image_index = 0
image_token_cnt = messages.count(self.image_token)
tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
messages,
pil_images[image_index : image_index + image_token_cnt],
bos=True,
eos=True,
cropping=len(pil_images) <= 2,
max_req_input_len=max_req_input_len,
)
image_index = image_token_cnt
tokenized_data += tokenized_str
if self.mask_prompt:
masked_tokenized_data += [self.ignore_id] * len(tokenized_str)
else:
masked_tokenized_data += tokenized_str
images_list += images
images_seq_mask += seq_mask
images_spatial_crop += spatial_crop
assert len(tokenized_data) == len(
images_seq_mask
), f"format_messages_v2: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
return (
tokenized_data,
masked_tokenized_data,
images_list,
images_seq_mask,
images_spatial_crop,
)
@property
...@@ -315,7 +251,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
**kwargs,
):
"""
...@@ -339,22 +274,47 @@ class DeepseekVLV2Processor(ProcessorMixin):
- num_image_tokens (List[int]): the number of image tokens
"""
assert (
prompt is None or conversations is None
), "prompt and conversations cannot be used at the same time."
(
tokenized_str,
masked_tokenized_str,
images_list,
images_seq_mask,
images_spatial_crop,
) = self.format_messages_v2(conversations, images, max_req_input_len)
assert (
len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
), (
f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
)
input_ids = torch.LongTensor(tokenized_str)
target_ids = torch.LongTensor(masked_tokenized_str)
images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
# set input_ids < 0 | input_ids == self.image_token_id as ignore_id
target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
self.ignore_id
)
input_ids[input_ids < 0] = self.pad_id
if inference_mode:
assert input_ids[-1] == self.eos_id
input_ids = input_ids[:-1]
target_ids = target_ids[:-1]
images_seq_mask = images_seq_mask[:-1]
if len(images_list) == 0:
images = torch.zeros((1, 3, self.image_size, self.image_size))
images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
else:
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
...@@ -363,7 +323,6 @@ class DeepseekVLV2Processor(ProcessorMixin):
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
...@@ -381,14 +340,10 @@ class DeepseekVLV2Processor(ProcessorMixin):
inference_mode: bool = True,
system_prompt: str = "",
max_req_input_len: int = -1,
**kwargs,
):
prepare = self.process_one(
prompt=prompt,
conversations=conversations,
images=images,
apply_sft_format=apply_sft_format,
...@@ -413,83 +368,85 @@ class DeepseekVLV2Processor(ProcessorMixin):
bos: bool = True,
eos: bool = True,
cropping: bool = True,
max_req_input_len: int = -1,
):
"""Tokenize text with <image> tags.""" """Tokenize text with <image> tags."""
images_list, images_seq_mask, images_spatial_crop = [], [], []
conversation = conversation
assert conversation.count(self.image_token) == len(images)
text_splits = conversation.split(self.image_token) text_splits = conversation.split(self.image_token)
images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
[],
[],
[],
[],
)
image_shapes = []
num_image_tokens = []
tokenized_str = [] tokenized_str = []
for text_sep, image in zip(text_splits, images): for text_sep, image in zip(text_splits, images):
"""encode text_sep""" """encode text_sep"""
tokenized_sep = self.encode(text_sep, bos=False, eos=False) tokenized_sep = self.encode(text_sep, bos=False, eos=False)
tokenized_str += tokenized_sep tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep) images_seq_mask += [False] * len(tokenized_sep)
image_shapes.append(image.size) """select best resolution for anyres"""
if cropping:
if image.size[0] <= 640 and image.size[1] <= 640: best_width, best_height = select_best_resolution(
crop_ratio = [1, 1] image.size, self.candidate_resolutions
)
else: else:
if cropping: best_width, best_height = self.image_size, self.image_size
images_crop_raw, crop_ratio = dynamic_preprocess( # print(image.size, (best_width, best_height)) # check the select_best_resolutions func
image, image_size=IMAGE_SIZE
)
else:
crop_ratio = [1, 1]
"""process the global view""" """process the global view"""
if self.image_size <= 640 and not cropping:
image = image.resize((self.image_size, self.image_size))
global_view = ImageOps.pad( global_view = ImageOps.pad(
image, image,
(self.base_size, self.base_size), (self.image_size, self.image_size),
color=tuple(int(x * 255) for x in self.image_transform.mean), color=tuple(int(x * 255) for x in self.image_transform.mean),
) )
images_list.append(self.image_transform(global_view)) images_list.append(self.image_transform(global_view))
num_width_tiles, num_height_tiles = crop_ratio """process the local views"""
images_spatial_crop.append([num_width_tiles, num_height_tiles]) local_view = ImageOps.pad(
image,
(best_width, best_height),
color=tuple(int(x * 255) for x in self.image_transform.mean),
)
for i in range(0, best_height, self.image_size):
for j in range(0, best_width, self.image_size):
images_list.append(
self.image_transform(
local_view.crop(
(j, i, j + self.image_size, i + self.image_size)
)
)
)
if num_width_tiles > 1 or num_height_tiles > 1: """record height / width crop num"""
for i in range(len(images_crop_raw)): num_width_tiles, num_height_tiles = (
images_crop_list.append(self.image_transform(images_crop_raw[i])) best_width // self.image_size,
best_height // self.image_size,
)
images_spatial_crop.append([num_width_tiles, num_height_tiles])
"""add image tokens""" """add image tokens"""
num_queries = math.ceil( h = w = math.ceil(
(self.image_size // self.patch_size) / self.downsample_ratio (self.image_size // self.patch_size) / self.downsample_ratio
) )
num_queries_base = math.ceil( # global views tokens h * (w + 1), 1 is for line separator
(self.base_size // self.patch_size) / self.downsample_ratio tokenized_image = [self.image_token_id] * h * (w + 1)
# add a separator between global and local views
tokenized_image += [self.image_token_id]
# local views tokens, (num_height_tiles * h) * (num_width_tiles * w + 1)
tokenized_image += (
[self.image_token_id]
* (num_height_tiles * h)
* (num_width_tiles * w + 1)
) )
tokenized_image = (
[self.image_token_id] * num_queries_base + [self.image_token_id]
) * num_queries_base
tokenized_image += [self.image_token_id]
if num_width_tiles > 1 or num_height_tiles > 1:
tokenized_image += (
[self.image_token_id] * (num_queries * num_width_tiles)
+ [self.image_token_id]
) * (num_queries * num_height_tiles)
tokenized_str += tokenized_image tokenized_str += tokenized_image
images_seq_mask += [True] * len(tokenized_image) images_seq_mask += [True] * len(tokenized_image)
num_image_tokens.append(len(tokenized_image)) # print(width_crop_num, height_crop_num, len(tokenized_image)) # test the correctness of the number of image-related tokens
"""process the last text split""" """process the last text split"""
tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False) tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
# deal with video, limit with request len
if max_req_input_len > -1:
if max_req_input_len < len(tokenized_sep) + len(tokenized_str) - 1:
rest = max_req_input_len - len(tokenized_sep) - 1 - 1024
tokenized_str = tokenized_str[:rest]
images_seq_mask = images_seq_mask[:rest]
tokenized_str += tokenized_sep tokenized_str += tokenized_sep
images_seq_mask += [False] * len(tokenized_sep) images_seq_mask += [False] * len(tokenized_sep)
...@@ -505,64 +462,7 @@ class DeepseekVLV2Processor(ProcessorMixin): ...@@ -505,64 +462,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
images_seq_mask images_seq_mask
), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}" ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
masked_tokenized_str = [] return tokenized_str, images_list, images_seq_mask, images_spatial_crop
class DeepseekVL2VisionEncoderConfig(PretrainedConfig):
...@@ -647,6 +547,7 @@ class DeepseekVL2MlpProjectorConfig(PretrainedConfig):
class DeepseekV2Config(PretrainedConfig):
model_type = "deepseek_v2"
keys_to_ignore_at_inference = ["past_key_values"]
...
from typing import Dict, Type
from transformers import PretrainedConfig, ProcessorMixin
# Useful for registering a custom processor different from Hugging Face's default.
_CUSTOMIZED_MM_PROCESSOR: Dict[str, Type[ProcessorMixin]] = dict()
def register_customized_processor(
processor_class: Type[ProcessorMixin],
):
"""Class decorator that maps a config class's model_type field to a customized processor class.
Args:
processor_class: A processor class that inherits from ProcessorMixin
Example:
```python
@register_customized_processor(MyCustomProcessor)
class MyModelConfig(PretrainedConfig):
model_type = "my_model"
```
"""
def decorator(config_class: PretrainedConfig):
if not hasattr(config_class, "model_type"):
raise ValueError(
f"Class {config_class.__name__} with register_customized_processor should "
f"have a 'model_type' class attribute."
)
_CUSTOMIZED_MM_PROCESSOR[config_class.model_type] = processor_class
return config_class
return decorator
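# Illustrative lookup sketch (assumption, not part of this commit): once a config
# class has been decorated, the registry can be queried by model_type, much like
# get_processor() does further below. The helper name and path are hypothetical.
#
#   def load_customized_processor(config: PretrainedConfig, model_path: str):
#       processor_cls = _CUSTOMIZED_MM_PROCESSOR.get(config.model_type)
#       if processor_cls is None:
#           raise KeyError(f"no customized processor registered for {config.model_type}")
#       return processor_cls.from_pretrained(model_path, trust_remote_code=True)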
...@@ -54,6 +54,7 @@ from sglang.srt.configs import (
from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset
_CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [
...@@ -172,6 +173,16 @@ def _load_deepseek_v32_model(
)
def _is_deepseek_ocr_model(config: PretrainedConfig) -> bool:
# TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
# Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
return (
getattr(config, "auto_map", None) is not None
and config.auto_map.get("AutoModel")
== "modeling_deepseekocr.DeepseekOCRForCausalLM"
)
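# Illustrative sketch (assumption, not part of this commit): the shape of the
# auto_map entry this helper keys on, as it would appear on a DeepSeek-OCR style
# config object.
#
#   config.auto_map = {
#       "AutoModel": "modeling_deepseekocr.DeepseekOCRForCausalLM",
#       # other Auto* entries omitted
#   }
#   assert _is_deepseek_ocr_model(config)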
@lru_cache_frozenset(maxsize=32)
def get_config(
model: str,
...@@ -235,11 +246,7 @@ def get_config(
if config.model_type in _CONFIG_REGISTRY:
model_type = config.model_type
if model_type == "deepseek_vl_v2":
if _is_deepseek_ocr_model(config):
model_type = "deepseek-ocr"
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(model, revision=revision)
...@@ -445,6 +452,10 @@ def get_processor(
**kwargs,
)
if _is_deepseek_ocr_model(config):
# Temporary hack to load deepseek-ocr
config.model_type = "deepseek-ocr"
# fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided.
if config.model_type in {"qwen2_vl", "sarashina2_vision"}:
if "size" not in kwargs:
...@@ -462,13 +473,22 @@ def get_processor(
**kwargs,
)
else:
if config.model_type in _CUSTOMIZED_MM_PROCESSOR:
processor = _CUSTOMIZED_MM_PROCESSOR[config.model_type].from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
else:
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
**kwargs,
)
except ValueError as e:
error_message = str(e)
...