Unverified Commit bf53bf51 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Fix] Fix llava on multi images (#1247)

parent b1a540ec
......@@ -26,7 +26,7 @@ import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import psutil
......@@ -193,35 +193,16 @@ def allocate_init_ports(
return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size
if tokenizer == None:
return [-1e5] * vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
return logit_bias
def is_multimodal_model(model):
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type")
def is_multimodal_model(model_architectures):
if (
"LlavaLlamaForCausalLM" in model_architectures
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
):
return True
else:
return False
def is_generation_model(model_architectures, is_embedding: bool = False):
......@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
) # Return an empty array and size tuple if no frames were found
def load_image(image_file):
def load_image(image_file: Union[str, bytes]):
from PIL import Image
image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"):
if isinstance(image_file, bytes):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content))
......@@ -334,8 +317,10 @@ def load_image(image_file):
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
else:
elif isinstance(image_file, str):
image = Image.open(BytesIO(base64.b64decode(image_file)))
else:
raise ValueError(f"Invalid image: {image}")
return image, image_size
......
......@@ -32,8 +32,6 @@ class TestOpenAIVisionServer(unittest.TestCase):
other_args=[
"--chat-template",
"chatml-llava",
"--chunked-prefill-size",
"16384",
# "--log-requests",
],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment