Unverified commit fd9ad817, authored by Liangsheng Yin, committed by GitHub

Organize image inputs (#1531)

parent e165a9fc
@@ -172,12 +172,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +184,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
......
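The typed per-image fields collapse into a single `image_inputs: dict`. As a rough illustration (not part of the diff), the dict uses the keys produced by `TokenizerManager._get_image_inputs` later in this commit; the values below are dummies and the exact tensor shapes are model-specific.

```python
import torch

# Dummy example of the payload now carried in TokenizedGenerateReqInput.image_inputs.
example_image_inputs = {
    "pixel_values": torch.zeros(1, 3, 336, 336),  # processor output; shape/dtype vary by model
    "image_hashes": [1234567890],                 # one hash per input image
    "image_sizes": [[512, 512]],                  # original (height, width) per image
    "modalities": ["image"],                      # e.g. "image", "multi-images", or "video"
}
```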
@@ -102,6 +102,39 @@ class FINISH_ABORT(BaseFinishReason):
         }


+@dataclass
+class ImageInputs:
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"]
+        return ret
+
+
 class Req:
     """Store all inforamtion of a request."""

@@ -147,11 +180,7 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
@@ -654,15 +683,9 @@ class ScheduleBatch:
                     self.tree_cache.cache_finished_req(req, cur_all_ids)

                     # re-applying image padding
-                    if req.pixel_values is not None:
-                        (
-                            req.origin_input_ids,
-                            req.image_offsets,
-                        ) = model_runner.model.pad_input_ids(
-                            req.origin_input_ids_unpadded,
-                            req.pad_value,
-                            req.pixel_values,
-                            req.image_sizes,
+                    if req.image_inputs is not None:
+                        req.origin_input_ids = model_runner.model.pad_input_ids(
+                            req.origin_input_ids_unpadded, req.image_inputs
                         )

                     jump_forward_reqs.append(req)
......
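A standalone sketch of the prefix-matching trick that `ImageInputs.from_dict` implements above: all per-image hashes are folded into one hash, and four slices of it (mod the vocab size) become the "fake token ids" that later fill the image placeholder region, so identical images yield identical token prefixes for the radix cache. The `vocab_size` and hash values below are arbitrary.

```python
vocab_size = 32_000          # arbitrary example value
image_hashes = [111, 222]    # one hash per image in the request

combined = hash(tuple(image_hashes))
pad_values = [
    combined % vocab_size,
    (combined >> 16) % vocab_size,
    (combined >> 32) % vocab_size,
    (combined >> 64) % vocab_size,
]
# pad_input_ids repeats pad_values across the image placeholder span, so two
# requests containing the same images produce identical padded prefixes.
```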
@@ -194,10 +194,9 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                    obj.image_data if not_use_index else obj.image_data[index]
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data if not_use_index else obj.image_data[index]
                 )
-                modalities = obj.modalities
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
                 )
@@ -248,10 +247,7 @@ class TokenizerManager:
                 sampling_params = SamplingParams(**obj.sampling_params[0])
                 sampling_params.max_new_tokens = 0
-                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                    obj.image_data[0]
-                )
-                modalities = obj.modalities
+                image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
                 return_logprob = obj.return_logprob[0]
                 logprob_start_len = obj.logprob_start_len[0]
                 top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ class TokenizerManager:
                 rid,
                 input_text,
                 input_ids,
-                pixel_values,
-                image_hashes,
-                image_sizes,
+                image_inputs,
                 sampling_params,
                 return_logprob,
                 logprob_start_len,
                 top_logprobs_num,
                 obj.stream,
-                modalities,
                 (
                     obj.lora_path[index]
                     if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = (
-                    await self._get_pixel_values(obj.image_data[index])
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data[index]
                 )
-                modalities = obj.modalities

                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
                     input_text,
                     input_ids,
-                    pixel_values,
-                    image_hashes,
-                    image_sizes,
+                    image_inputs,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
-                    modalities,
                     (
                         obj.lora_path[index]
                         if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ class TokenizerManager:
            )
        return top_logprobs

-    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
        if not image_data:
-            return None, None, None
+            return None
+
+        # TODO: move this into a processor for each vision architecture
        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ class TokenizerManager:
        else:
            raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hashes, image_sizes
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": obj.modalities,
+        }

    async def _process_single_image(
        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
......
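A minimal, hypothetical stand-in (not the real `TokenizerManager` method) for the new `_get_image_inputs` contract: one call per request that returns `None` for text-only input, or a single dict bundling the image tensors with `obj.modalities`, instead of the old three-value tuple plus a separately threaded `modalities` variable.

```python
from typing import List, Optional, Union

async def get_image_inputs_sketch(
    modalities: Optional[List[str]],
    image_data: Optional[List[Union[str, bytes]]],
) -> Optional[dict]:
    if not image_data:
        return None  # text-only request: nothing to attach
    # Stand-ins for the real per-architecture image processing.
    pixel_values, image_hashes, image_sizes = [], [], []
    return {
        "pixel_values": pixel_values,
        "image_hashes": image_hashes,
        "image_sizes": image_sizes,
        "modalities": modalities,
    }
```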
@@ -49,6 +49,7 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
+    ImageInputs,
     Req,
     ScheduleBatch,
 )
@@ -340,29 +341,16 @@ class ModelTpServer:
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            # Use image hash as fake token_ids, which is then used
-            # for prefix matching
-            image_hash = hash(tuple(recv_req.image_hashes))
-            req.pad_value = [
-                (image_hash) % self.model_config.vocab_size,
-                (image_hash >> 16) % self.model_config.vocab_size,
-                (image_hash >> 32) % self.model_config.vocab_size,
-                (image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_sizes = recv_req.image_sizes
-            (
-                req.origin_input_ids,
-                req.image_offsets,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values,
-                req.image_sizes,
+
+        # Image inputs
+        if recv_req.image_inputs is not None:
+            req.image_inputs = ImageInputs.from_dict(
+                recv_req.image_inputs, self.model_config.vocab_size
             )
-            # Only when pixel values is not None we have modalities
-            req.modalities = recv_req.modalites
+            req.origin_input_ids = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded, req.image_inputs
+            )
+
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
......
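The other half of the refactor is the `pad_input_ids` contract: it now takes the `ImageInputs` object, records the image offsets on it as a side effect, and returns only the padded ids. A toy version (not SGLang code; the token id and lengths are made up) that mirrors that contract:

```python
from dataclasses import dataclass
from typing import List, Optional

IMAGE_TOKEN = -200  # made-up placeholder id marking where an image sits in the prompt

@dataclass
class ToyImageInputs:
    pad_values: List[int]
    image_offsets: Optional[List[int]] = None

def toy_pad_input_ids(input_ids: List[int], image_inputs: ToyImageInputs, image_len: int = 4) -> List[int]:
    offset = input_ids.index(IMAGE_TOKEN)
    pad_ids = (image_inputs.pad_values * image_len)[:image_len]
    image_inputs.image_offsets = [offset]  # side effect replaces the old second return value
    return input_ids[:offset] + pad_ids + input_ids[offset + 1 :]

inputs = ToyImageInputs(pad_values=[7, 8])
ids = toy_pad_input_ids([1, 2, IMAGE_TOKEN, 3], inputs)
# ids == [1, 2, 7, 8, 7, 8, 3]; inputs.image_offsets == [2]
```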
@@ -25,7 +25,7 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -84,17 +84,10 @@ class InputMetadata:
     extend_logprob_start_lens_cpu: List[int] = None

     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
+    image_inputs: List[ImageInputs] = None

     def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
+        self.image_inputs = [r.image_inputs for r in batch.reqs]

     def compute_positions(self, batch: ScheduleBatch):
         if self.forward_mode.is_decode():
......
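The four parallel per-request lists on `InputMetadata` become one list of `ImageInputs` (with `None` entries for text-only requests), which removes the risk of the lists drifting out of alignment. In effect, `init_multimuldal_info` now reduces to a gather like the hypothetical helper below:

```python
from typing import List, Optional

def gather_image_inputs(reqs: List[object]) -> List[Optional[object]]:
    # One entry per request, None for text-only requests; replaces the separate
    # pixel_values / image_sizes / image_offsets / modalities lists.
    return [getattr(r, "image_inputs", None) for r in reqs]
```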
@@ -498,23 +498,10 @@ class ModelRunner:
             get_embedding=True,
         )

-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
-
     def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
+        if batch.forward_mode.is_decode():
             return self.forward_decode(batch)
         elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
......
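With the image data travelling inside `InputMetadata.image_inputs`, the dedicated multimodal extend path disappears and `ModelRunner.forward` is left with the plain decode/extend dispatch; vision models read their inputs from the metadata inside `forward_extend`. A sketch of the resulting control flow (not the verbatim method):

```python
def forward_sketch(runner, batch):
    # Decode and extend are the only cases; multimodal extend no longer needs
    # its own branch because image tensors ride along in input_metadata.
    if batch.forward_mode.is_decode():
        return runner.forward_decode(batch)
    elif batch.forward_mode.is_extend():
        return runner.forward_extend(batch)
    raise ValueError(f"Unexpected forward mode: {batch.forward_mode}")
```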
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
             new_w = int(new_w // times)
             new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
+
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]
                 ########## Encode Image ########
......
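A self-contained NumPy sketch of the `need_vision` test used above (the numbers are invented): a request only needs its vision tower run in this extend pass if its prefill resumes at or before its last image placeholder; requests whose images are already covered by the prefix cache are skipped.

```python
import numpy as np

max_image_offset = [575, -1, 1100]        # per request: max(image_offsets), or -1 if no image
start_positions = np.array([0, 0, 2048])  # first position each request extends from
need_vision = start_positions <= np.array(max_image_offset)
print(need_vision)  # [ True False False] -> only the first request re-encodes its images
```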
@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
+
         new_image_feature_len = self.image_feature_len

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, [offset]
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -112,10 +109,8 @@ class LlavaVidForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
+
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
@@ -123,14 +118,22 @@ class LlavaVidForCausalLM(nn.Module):
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
+
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]
                 ########## Encode Image ########
......