Unverified Commit de1cb387 authored by pengyuange's avatar pengyuange Committed by GitHub
Browse files

[Model] Support Skywork-R1V (#15397)


Signed-off-by: default avatarjiacai.liu <932997367@qq.com>
Co-authored-by: default avatarjiacai.liu <932997367@qq.com>
parent c802f543
......@@ -921,6 +921,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
- * `SkyworkR1VChatModel`
* Skywork-R1V-38B
* T + I
* `Skywork/Skywork-R1V-38B`
*
* ✅︎
* ✅︎
- * `UltravoxModel`
* Ultravox
* T + A<sup>E+</sup>
......
......@@ -804,6 +804,41 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
)
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "Skywork/Skywork-R1V-38B"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=4096,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# Stop tokens for SkyworkR1V
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
stop_tokens = ["<|end▁of▁sentence|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
......@@ -834,6 +869,7 @@ model_example_map = {
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"skywork_chat": run_skyworkr1v,
}
......
......@@ -474,6 +474,20 @@ VLM_TEST_SETTINGS = {
vllm_output_post_proc=model_utils.qwen_vllm_to_hf_output,
prompt_path_encoder=model_utils.qwen_prompt_path_encoder,
),
"skywork_r1v": VLMTestInfo(
models=["Skywork/Skywork-R1V-38B"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin▁of▁sentence|><|User|>\n{img_prompt}<|Assistant|><think>\n", # noqa: E501
single_image_prompts=IMAGE_ASSETS.prompts({
"stop_sign": "<image>\nWhat's the content in the center of the image?", # noqa: E501
"cherry_blossom": "<image>\nWhat is the season?",
}),
multi_image_prompt="<image>\n<image>\nDescribe the two images in short.", # noqa: E501
max_model_len=4096,
use_tokenizer_eos=True,
patch_hf_runner=model_utils.skyworkr1v_patch_hf_runner,
marks=[large_gpu_mark(min_gb=80)],
),
### Tensor parallel / multi-gpu broadcast tests
"chameleon-broadcast": VLMTestInfo(
models=["facebook/chameleon-7b"],
......
......@@ -376,6 +376,63 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
return hf_model
def skyworkr1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for SkyworkR1V."""
class SkyworkR1VProcessor:
"""A simple processor for SkyworkR1V."""
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
def __call__(self, text: str, images: Union[Image, list[Image]],
**kwargs):
from vllm.model_executor.models.skyworkr1v import (
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_skyworkr1v)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values_skyworkr1v(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
]
pixel_values = torch.cat(pixel_values, dim=0)
for num_patches in num_patches_list:
context_tokens = IMG_CONTEXT * self.num_image_token \
* num_patches
image_tokens = IMG_START + context_tokens + IMG_END
text = text.replace('<image>', image_tokens, 1)
prompt = self.tokenizer(text, return_tensors="pt")
prompt.update({"pixel_values": pixel_values})
return prompt
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
"<IMG_CONTEXT>")
hf_model.model.img_context_token_id = img_context_token_id
hf_model.processor = SkyworkR1VProcessor(hf_model)
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
hf_model.model.generate = types.MethodType(_internvl_generate,
hf_model.model)
return hf_model
def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for InternVL."""
......
......@@ -262,22 +262,23 @@ def _test_processing_correctness_mistral(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-Llama3-V-2_5",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"allenai/Molmo-7B-D-0924",
"allenai/Molmo-7B-O-0924",
"nvidia/NVLM-D-72B",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
......
......@@ -294,6 +294,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
......
......@@ -496,7 +496,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "deepseek_vl_v2", "internvl_chat",
"NVLM_D", "h2ovl_chat"):
"skywork_chat", "NVLM_D", "h2ovl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
......
......@@ -190,6 +190,7 @@ _MULTIMODAL_MODELS = {
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
......
# SPDX-License-Identifier: Apache-2.0
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import BatchEncoding, PretrainedConfig, TensorType
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>'
IMG_END = '</img>'
IMG_CONTEXT = '<IMG_CONTEXT>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class SkyworkR1VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class SkyworkR1VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
SkyworkR1VImageEmbeddingInputs]
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
def find_closest_aspect_ratio(
aspect_ratio: float,
target_ratios: list[tuple[int, int]],
*,
width: int,
height: int,
image_size: int,
) -> tuple[int, int]:
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def resolve_skyworkr1v_min_max_num(
*,
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: bool,
use_thumbnail: bool,
) -> tuple[int, int]:
min_dynamic_patch = min_dynamic_patch if dynamic_image_size else 1
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1:
max_dynamic_patch += 1
return min_dynamic_patch, max_dynamic_patch
def get_skyworkr1v_target_ratios(
min_num: int,
max_num: int,
) -> list[tuple[int, int]]:
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1) if min_num <= i * j <= max_num}
return sorted(target_ratios, key=lambda x: x[0] * x[1])
def calculate_skyworkr1v_targets(
*,
orig_width: int,
orig_height: int,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
) -> tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio,
target_ratios,
width=orig_width,
height=orig_height,
image_size=image_size,
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks != 1
if use_thumbnail and blocks != 1:
blocks += 1
return blocks, target_width, target_height
def dynamic_preprocess_skyworkr1v(
image: Image.Image,
*,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
) -> list[Image.Image]:
orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
blocks, target_width, target_height = calculate_skyworkr1v_targets(
orig_width=orig_width,
orig_height=orig_height,
target_ratios=target_ratios,
image_size=image_size,
use_thumbnail=False,
)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B
def image_to_pixel_values_skyworkr1v(
image: Image.Image,
*,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
) -> torch.Tensor:
target_ratios = get_skyworkr1v_target_ratios(min_num, max_num)
transform = build_transform(input_size=input_size)
images = dynamic_preprocess_skyworkr1v(
image,
target_ratios=target_ratios,
image_size=input_size,
use_thumbnail=use_thumbnail,
)
pixel_values = torch.stack([transform(image) for image in images])
return pixel_values
class BaseSkyworkR1VProcessor(ABC):
"""
This model doesn't define its own HF processor,
so we implement our own one here.
The code to insert image tokens is based on:
https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/modeling_skywork_chat.py#L252
"""
def __init__(
self,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
image_size: int = config.vision_config.image_size
patch_size: int = config.vision_config.patch_size
if min_dynamic_patch is None:
min_dynamic_patch = config.min_dynamic_patch
assert isinstance(min_dynamic_patch, int)
if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch
assert isinstance(max_dynamic_patch, int)
if dynamic_image_size is None:
dynamic_image_size = config.dynamic_image_size
assert isinstance(dynamic_image_size, bool)
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.image_size = image_size
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail: bool = config.use_thumbnail
@property
@abstractmethod
def image_token_id(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
raise NotImplementedError
def resolve_min_max_num(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]:
min_dynamic_patch = (self.min_dynamic_patch if min_dynamic_patch
is None else min_dynamic_patch)
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
is None else dynamic_image_size)
use_thumbnail = (self.use_thumbnail
if use_thumbnail is None else use_thumbnail)
return resolve_skyworkr1v_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
def resolve_target_ratios(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
return get_skyworkr1v_target_ratios(min_num, max_num)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
target_ratios = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets
)
num_patches, _, _ = calculate_skyworkr1v_targets(
orig_width=image_width,
orig_height=image_height,
image_size=self.image_size,
target_ratios=target_ratios,
use_thumbnail=self.use_thumbnail,
)
return num_patches * self.num_image_token
def _images_to_pixel_values_lst(
self,
images: list[Image.Image],
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=False, # Applied in image_to_pixel_values
)
return [
image_to_pixel_values_skyworkr1v(
image,
input_size=self.image_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
image_inputs = {}
else:
pixel_values_lst = self._images_to_pixel_values_lst(
images,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
image_inputs: dict[str, NestedTensors] = {
"pixel_values_flat":
torch.cat(pixel_values_lst),
"image_num_patches":
torch.tensor([len(item) for item in pixel_values_lst]),
}
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text)
return {
**BatchEncoding(text_inputs, tensor_type=return_tensors),
**image_inputs,
}
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
@abstractmethod
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseSkyworkR1VProcessor:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[BaseSkyworkR1VProcessor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
base_size = processor.image_size
target_ratios = processor.resolve_target_ratios()
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_token_id = hf_processor.image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_kwargs:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
image_num_patches = []
def get_replacement_skyworkr1v(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return hf_processor.get_image_repl(feature_size, num_patches)
return [
PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_skyworkr1v,
)
]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> SkyworkR1VProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor,
info=SkyworkR1VProcessingInfo,
dummy_inputs=SkyworkR1VDummyInputsBuilder)
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'SkyworkLM2VEForCausalLM'
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)
self.img_context_token_id = None
self.visual_token_mask = None
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _patch_quant_config(self, config: PretrainedConfig,
quant_config: QuantizationConfig):
# the awq models from OpenGVLab missing `modules_to_not_convert`
# patch the quant_config to add `modules_to_not_convert` back
if isinstance(quant_config, AWQConfig):
text_config = config.text_config
llm_quant_config = getattr(text_config, "quantization_config",
None)
if (not quant_config.modules_to_not_convert) and \
(llm_quant_config is not None):
quant_config.modules_to_not_convert.append("vision_model")
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return get_sampler()
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
*,
is_mono: bool,
prefix: str,
):
if not is_mono:
vision_feature_layer = config.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
return InternVisionModel(
config.vision_config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
else:
return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
return nn.Sequential(
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
ReplicatedLinear(vit_hidden_size *
int(1 / self.downsample_ratio)**2,
llm_hidden_size,
return_bias=False),
nn.GELU(),
ReplicatedLinear(llm_hidden_size,
llm_hidden_size,
return_bias=False),
)
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
x = x.permute(0, 2, 1, 3).contiguous()
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
int(c / (scale_factor * scale_factor)))
if self.ps_version == 'v1':
pass
else:
x = x.permute(0, 2, 1, 3).contiguous()
return x
def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor:
vit_embeds = self.vision_model(pixel_values=pixel_values)
vit_embeds = vit_embeds[:, 1:, :]
h = w = int(vit_embeds.shape[1]**0.5)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
vit_embeds = self.pixel_shuffle(vit_embeds,
scale_factor=self.downsample_ratio)
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
vit_embeds.shape[-1])
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None:
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return SkyworkR1VImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
)
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None:
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}")
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return SkyworkR1VImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: SkyworkR1VImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
# Only one image in the current batch
if len(num_patches) == 1:
return image_embeds.view(
-1, self.config.text_config.hidden_size).unsqueeze(0)
# NOTE: Image embeddings are split into separate tensors for each image
# by the size of each embedding.
feature_size = image_embeds.shape[1]
image_embeds = image_embeds.view(-1,
self.config.text_config.hidden_size)
image_feature_sizes = [
num_patches * feature_size for num_patches in num_patches
]
return image_embeds.split(image_feature_sizes)
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
self.visual_token_mask = None
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
self.img_context_token_id,
)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
# Only required if the model is mono-architecture
if self.visual_token_mask is not None:
forward_kwargs.update(
{"visual_token_mask": self.visual_token_mask})
self.visual_token_mask = None
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
skip_prefixes = [
"action_embed", "temporal_embed", "track_embed",
"track_embed_decoder", "box_token", "cg_criterion", "cg_model",
"loc_encoder", "loc_decoder", "sam", "temporal_token",
"track_token"
]
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights)
......@@ -37,8 +37,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
Olmo2Config, RWConfig,
SolarConfig, Telechat2Config,
UltravoxConfig)
SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname
......@@ -76,6 +76,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"NVLM_D": NVLM_D_Config,
"olmo2": Olmo2Config,
"solar": SolarConfig,
"skywork_chat": SkyworkR1VChatConfig,
"telechat": Telechat2Config,
"ultravox": UltravoxConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
......
......@@ -20,6 +20,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.olmo2 import Olmo2Config
from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig
from vllm.transformers_utils.configs.solar import SolarConfig
from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
......@@ -42,6 +43,7 @@ __all__ = [
"NemotronConfig",
"NVLM_D_Config",
"Olmo2Config",
"SkyworkR1VChatConfig",
"SolarConfig",
"Telechat2Config",
"UltravoxConfig",
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/configuration_skywork_chat.py
# --------------------------------------------------------
# SkyworkR1V
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig
class SkyworkR1VChatConfig(PretrainedConfig):
model_type = 'internvl_chat'
is_composition = True
def __init__(self,
vision_config=None,
llm_config=None,
use_backbone_lora=0,
use_llm_lora=0,
select_layer=-1,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
**kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = PretrainedConfig(**llm_config)
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment