Unverified commit 55b974f9 authored by Liangsheng Yin, committed by GitHub

Process image in parallel (#1539)
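
Move multimodal image preprocessing out of TokenizerManager into a
dedicated image_processor module with per-model processor classes.
Preprocessing now runs in a ProcessPoolExecutor, so a request's images
are processed in parallel instead of serially on the tokenizer's event
loop. TokenizerManager delegates through the new interface, e.g.:

    image_inputs = await self.image_processor.process_images_async(image_data, obj)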

parent f86c1e61
# TODO: also move pad_input_ids into this module
import asyncio
import concurrent.futures
import logging
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Union
import numpy as np
import transformers
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
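# Each pool worker builds its own HF processor via init_global_processor
# (installed as the executor's initializer), so the heavy processor object
# is constructed once per worker rather than pickled with every task.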
global global_processor
def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
class BaseImageProcessor(ABC):
@abstractmethod
async def process_images_async(self, image_data, **kwargs):
pass
class DummyImageProcessor(BaseImageProcessor):
async def process_images_async(self, *args, **kwargs):
return None
class LlavaImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _image_processor):
self.hf_config = hf_config
self._image_processor = _image_processor
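        # Run CPU-bound image preprocessing in a process pool so it does not
        # block the tokenizer's asyncio event loop.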
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
)
@staticmethod
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
image_processor=None,
):
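        # In pool workers, image_processor is None and we fall back to the
        # per-process global processor created by init_global_processor.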
image_processor = image_processor or global_processor.image_processor
try:
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = image_processor(image)["pixel_values"]
                pixel_values = [v.astype(np.float16) for v in pixel_values]
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in image_processor.image_mean),
)
pixel_values = image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None
and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, image_processor, image_grid_pinpoints
)
else:
pixel_values = image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
LlavaImageProcessor._process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return self._process_single_image_task(
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
self, image_data: List[Union[str, bytes]], request_obj
):
if not image_data:
return None
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
            and aspect_ratio is not None
            and "anyres" in aspect_ratio
            else None
        )
if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
            res = await asyncio.gather(
                *(
                    self._process_single_image(img_data, aspect_ratio, grid_pinpoints)
                    for img_data in image_data
                )
            )
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities,
}
def get_image_processor(
hf_config, server_args: ServerArgs, _image_processor
) -> BaseImageProcessor:
return LlavaImageProcessor(hf_config, server_args, _image_processor)
def get_dummy_image_processor():
return DummyImageProcessor()
@@ -16,17 +16,13 @@ limitations under the License.
"""TokenizerManager is a process that tokenizes the text."""
import asyncio
import concurrent.futures
import dataclasses
import json
import logging
import multiprocessing as mp
import os
from typing import Dict, List, Optional, Tuple, Union
import fastapi
import numpy as np
import transformers
import uvloop
import zmq
import zmq.asyncio
@@ -38,6 +34,10 @@ from sglang.srt.hf_transformers_utils import (
get_processor,
get_tokenizer,
)
from sglang.srt.managers.image_processor import (
get_dummy_image_processor,
get_image_processor,
)
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -53,11 +53,9 @@ from sglang.srt.managers.io_struct import (
UpdateWeightReqInput,
UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback
from sglang.srt.utils import is_generation_model, is_multimodal_model
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -105,6 +103,8 @@ class TokenizerManager:
self.context_len = server_args.context_length or get_context_length(
self.hf_config
)
# Create image processor placeholder
self.image_processor = get_dummy_image_processor()
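        # Replaced below with a real processor when the model is multimodal.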
# Create tokenizer
if server_args.skip_tokenizer_init:
@@ -119,13 +119,9 @@
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# We want to parallelize the image pre-processing so we
# create an executor for it
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
max_workers=os.environ.get("SGLANG_CPU_COUNT", os.cpu_count()),
# We want to parallelize the image pre-processing so we create an executor for it
self.image_processor = get_image_processor(
self.hf_config, server_args, self.processor.image_processor
)
else:
self.tokenizer = get_tokenizer(
@@ -194,8 +190,8 @@
)
if self.is_generation:
image_inputs = await self._get_image_inputs(
obj, obj.image_data if not_use_index else obj.image_data[index]
image_inputs = await self.image_processor.process_images_async(
obj.image_data if not_use_index else obj.image_data[index], obj
)
return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -247,7 +243,9 @@
sampling_params = SamplingParams(**obj.sampling_params[0])
sampling_params.max_new_tokens = 0
image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
image_inputs = await self.image_processor.process_images_async(
obj.image_data[0], obj
)
return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0]
@@ -362,8 +360,8 @@
sampling_params = self._get_sampling_params(obj.sampling_params[index])
if self.is_generation:
image_inputs = await self._get_image_inputs(
obj, obj.image_data[index]
image_inputs = await self.image_processor.process_images_async(
obj.image_data[index], obj
)
tokenized_obj = TokenizedGenerateReqInput(
@@ -686,131 +684,3 @@ class TokenizerManager:
token_top_logprobs, decode_to_text
)
return top_logprobs
async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
if not image_data:
return None
# TODO: move this into a processor for each vision architecture
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = (
self.hf_config.image_grid_pinpoints
if hasattr(self.hf_config, "image_grid_pinpoints")
and "anyres" in aspect_ratio
else None
)
if isinstance(image_data, list) and len(image_data) > 0:
# Multiple images
if len(image_data) > 1:
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
for img_data in image_data:
pixel_v, image_h, image_s = await self._process_single_image(
img_data, aspect_ratio, grid_pinpoints
)
pixel_values.append(pixel_v)
image_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
pixel_values = np.stack(pixel_values, axis=0)
else:
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
elif isinstance(image_data, str):
# A single image
pixel_values, image_hash, image_size = await self._process_single_image(
image_data, aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"image_sizes": image_sizes,
"modalities": obj.modalities,
}
async def _process_single_image(
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
):
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
_process_single_image_task,
image_data,
aspect_ratio,
grid_pinpoints,
)
else:
return _process_single_image_task(
image_data, aspect_ratio, grid_pinpoints, self.processor
)
global global_processor
def init_global_processor(server_args: ServerArgs):
"""Init the global processor for multi modal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
def _process_single_image_task(
image_data: Union[str, bytes],
image_aspect_ratio: Optional[str] = None,
image_grid_pinpoints: Optional[str] = None,
processor=None,
):
try:
processor = processor or global_processor
image, image_size = load_image(image_data)
if image_size is not None:
# It is a video with multiple images
image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)):
pixel_values[_] = pixel_values[_].astype(np.float16)
pixel_values = np.stack(pixel_values, axis=0)
return pixel_values, image_hash, image_size
else:
# It is an image
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image.convert("RGB"))[
"pixel_values"
][0]
elif image_aspect_ratio == "anyres" or (
image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
):
pixel_values = process_anyres_image(
image, processor.image_processor, image_grid_pinpoints
)
else:
pixel_values = processor.image_processor(image)["pixel_values"][0]
if isinstance(pixel_values, np.ndarray):
pixel_values = pixel_values.astype(np.float16)
return pixel_values, image_hash, image.size
except Exception:
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())