Unverified Commit bf53bf51 authored by Lianmin Zheng, committed by GitHub

[Fix] Fix llava on multi images (#1247)

parent b1a540ec
@@ -240,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - Qwen / Qwen 2 / Qwen 2 MoE
 - DeepSeek / DeepSeek 2
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
-  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384`
+  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
   - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
 - LLaVA 1.5 / 1.6 / NeXT
   - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
...
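The commit title says this change makes LLaVA work on multiple images per request. As a hedged illustration only (the payload shape follows the OpenAI Vision API linked above; the port, URLs, and prompt are placeholders, not part of this commit), a two-image request against a server launched as in the README might look like:

```python
# Hypothetical client sketch: send two images in one request to a local
# sglang server through its OpenAI-compatible Vision API endpoint.
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "https://example.com/a.jpg"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/b.jpg"}},
                {"type": "text", "text": "Compare the two images."},
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)
```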
@@ -184,13 +184,9 @@ if __name__ == "__main__":
     # Parse the arguments
     args = parser.parse_args()
     cur_port = args.port
     cur_chunk = args.chunk_idx
     num_chunks = args.num_chunks
     num_frames = args.num_frames
     if "34b" in args.model_path.lower():
@@ -202,7 +198,6 @@ if __name__ == "__main__":
         exit()
     model_overide_args = {}
     model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
     model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
     model_overide_args["num_frames"] = args.num_frames
@@ -235,7 +230,6 @@ if __name__ == "__main__":
     print(f"chat template: {runtime.endpoint.chat_template.name}")

     # Run a single request
-    # try:
     print("\n========== single ==========\n")
     root = args.video_dir
     if os.path.isfile(root):
@@ -257,13 +251,10 @@ if __name__ == "__main__":
     )  # Calculate the average processing time
     print(f"Average processing time per video: {average_time:.2f} seconds")
     runtime.shutdown()
-    # except Exception as e:
-    #     print(e)
-    runtime.shutdown()

-    # # # Run a batch of requests
+    # # Run a batch of requests
     # print("\n========== batch ==========\n")
     # if not os.path.exists(args.save_dir):
     #     os.makedirs(args.save_dir)
-    # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks)
+    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
     # runtime.shutdown()
...
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
launch_server(server_args, model_overide_args, None)
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
     """Gets a tokenizer for the given model name via Huggingface."""
-
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer
-
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
         **kwargs,
     )
     return processor
-
-
-class TiktokenTokenizer:
-    def __init__(self, tokenizer_path):
-        import tiktoken
-        from jinja2 import Template
-
-        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
-
-        # Read JSON
-        name = "tmp-json"
-        with open(tokenizer_path, "rb") as fin:
-            tok_dict = json.load(fin)
-
-        mergeable_ranks = {
-            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
-        }
-        special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
-            for item in tok_dict["special_tokens"]
-        }
-        assert tok_dict["word_split"] == "V1"
-
-        default_allowed_special = None
-        kwargs = {
-            "name": name,
-            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
-            "mergeable_ranks": mergeable_ranks,
-            "special_tokens": special_tokens,
-        }
-        if "default_allowed_special" in tok_dict:
-            default_allowed_special = set(
-                [
-                    bytes(bytes_list).decode()
-                    for bytes_list in tok_dict["default_allowed_special"]
-                ]
-            )
-        if "vocab_size" in tok_dict:
-            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
-
-        PAD = "<|pad|>"
-        EOS = "<|eos|>"
-        SEP = "<|separator|>"
-        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
-
-        tokenizer = tiktoken.Encoding(**kwargs)
-        tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
-
-        def encode_patched(
-            self,
-            text: str,
-            *,
-            allowed_special: Union[
-                Literal["all"], AbstractSet[str]
-            ] = set(),  # noqa: B006
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> List[int]:
-            if isinstance(allowed_special, set):
-                allowed_special |= self._default_allowed_special
-            return tiktoken.Encoding.encode(
-                self,
-                text,
-                allowed_special=allowed_special,
-                disallowed_special=(),
-            )
-
-        tokenizer.encode = functools.partial(encode_patched, tokenizer)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[EOS]
-        self.vocab_size = tokenizer.n_vocab
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode_batch(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
-
-
-class SentencePieceTokenizer:
-    def __init__(self, tokenizer_path):
-        import sentencepiece as spm
-        from jinja2 import Template
-
-        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_id()
-        self.vocab_size = tokenizer.vocab_size()
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
@@ -55,6 +55,7 @@ class GenerateReqInput:
                 self.text is not None and self.input_ids is not None
             ):
                 raise ValueError("Either text or input_ids should be provided.")
+
             if (
                 isinstance(self.sampling_params, dict)
                 and self.sampling_params.get("n", 1) != 1
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     # The pixel values for input images
     pixel_values: List[float]
-    # The hash of input images
-    image_hash: int
-    # The image size
-    image_size: List[int]
+    # The hash values of input images
+    image_hashes: List[int]
+    # The image sizes
+    image_sizes: List[List[int]]
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
...
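To make the new schema concrete, here is a hedged sketch of the per-image layout these fields now carry: each multimodal field holds one entry per input image. The arrays and byte strings below are toy stand-ins, not real preprocessed tensors.

```python
# Toy illustration of the list-valued multimodal fields above.
import numpy as np

img0_bytes, img1_bytes = b"image-0", b"image-1"
pixel_values = np.stack(
    [np.zeros((3, 336, 336), np.float16), np.zeros((3, 336, 336), np.float16)]
)
image_hashes = [hash(img0_bytes), hash(img1_bytes)]  # List[int], one hash per image
image_sizes = [[336, 336], [336, 336]]               # List[List[int]], one size per image
assert len(image_hashes) == len(image_sizes) == pixel_values.shape[0]
```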
@@ -121,8 +121,8 @@ class Req:
         # For vision input
         self.pixel_values = None
-        self.image_size = None
-        self.image_offset = None
+        self.image_sizes = None
+        self.image_offsets = None
         self.pad_value = None

         # Prefix info
@@ -600,12 +600,12 @@ class ScheduleBatch:
             if req.pixel_values is not None:
                 (
                     req.origin_input_ids,
-                    req.image_offset,
+                    req.image_offsets,
                 ) = model_runner.model.pad_input_ids(
                     req.origin_input_ids_unpadded,
                     req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
+                    req.pixel_values,
+                    req.image_sizes,
                 )

                 jump_forward_reqs.append(req)
...
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union

+import fastapi
 import numpy as np
 import transformers
 import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
         )
-        if server_args.context_length is not None:
-            self.context_len = server_args.context_length
-        else:
-            self.context_len = get_context_length(self.hf_config)
+        self.context_len = server_args.context_length or get_context_length(
+            self.hf_config
+        )

         # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_path):
+            if is_multimodal_model(self.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
                 )
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+                # We want to parallelize the image pre-processing so we
+                # create an executor for it
                 self.executor = concurrent.futures.ProcessPoolExecutor(
                     initializer=init_global_processor,
                     mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

-        # for update model weights
+        # For update model weights
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None

     async def generate_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
     async def _handle_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request,
+        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
@@ -182,8 +185,8 @@ class TokenizerManager:
             )

             if self.is_generation:
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
-                    obj.image_data
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+                    obj.image_data if not_use_index else obj.image_data[index]
                 )
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
                 )
                 if return_logprob and logprob_start_len == -1:
                     logprob_start_len = len(input_ids) - 1
-
                 top_logprobs_num = (
                     obj.top_logprobs_num
                     if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
                 sampling_params = SamplingParams(**obj.sampling_params[0])
                 sampling_params.max_new_tokens = 0
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                     obj.image_data[0]
                 )
                 return_logprob = obj.return_logprob[0]
                 logprob_start_len = obj.logprob_start_len[0]
                 top_logprobs_num = obj.top_logprobs_num[0]

+        # Send to the controller
         if self.is_generation:
             if return_logprob and logprob_start_len == -1:
                 logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 pixel_values,
-                image_hash,
-                image_size,
+                image_hashes,
+                image_sizes,
                 sampling_params,
                 return_logprob,
                 logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
         self.send_to_router.send_pyobj(tokenized_obj)

+        # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
         if not is_cache_for_prefill:
-            async for response in self._wait_for_response(
-                event, state, obj, rid, request
-            ):
+            async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
             assert self.is_generation
-            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            await self._wait_for_cache_prefill_response(state, obj, rid, request)
             yield input_ids

     async def _handle_batch_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         batch_size = obj.batch_size
         if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
                 if self.is_generation:
                     if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                         obj.logprob_start_len[index] = len(input_ids) - 1
-                    pixel_values, image_hash, image_size = await self._get_pixel_values(
-                        obj.image_data[index]
+                    pixel_values, image_hashes, image_sizes = (
+                        await self._get_pixel_values(obj.image_data[index])
                     )

                     tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
                         input_text,
                         input_ids,
                         pixel_values,
-                        image_hash,
-                        image_size,
+                        image_hashes,
+                        image_sizes,
                         sampling_params,
                         obj.return_logprob[index],
                         obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
                 generators.append(
                     self._wait_for_response(
-                        event,
                         state,
                         obj,
                         rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
             tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
             output_list = [None] * len(tasks)

+            # Recv results
             while tasks:
                 done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
@@ -426,25 +429,18 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params

-    async def _get_pixel_values(self, image_data):
-        if image_data is None:
-            return None, None, None
-        else:
-            return await self._get_pixel_values_internal(image_data)
-
     async def _wait_for_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
-        request,
-        index: int = None,
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
         response_index: int = 0,
     ):
         while True:
             try:
-                await asyncio.wait_for(event.wait(), timeout=4)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
                 yield out
                 break

-            event.clear()
+            state.event.clear()
             yield out

     async def _wait_for_cache_prefill_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: GenerateReqInput,
         rid: str,
-        request,
+        request: Optional[fastapi.Request] = None,
     ):
         while True:
             try:
@@ -514,7 +509,9 @@ class TokenizerManager:
             req = AbortReq(rid)
             self.send_to_router.send_pyobj(req)

-    async def update_weights(self, obj: UpdateWeightReqInput, request):
+    async def update_weights(
+        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -659,12 +656,11 @@ class TokenizerManager:
             )
         return top_logprobs

-    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
+    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+        if not image_data:
+            return None, None, None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
             if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
             else None
         )

         if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
+            # Multiple images
             if len(image_data) > 1:
                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
                 for img_data in image_data:
                     pixel_v, image_h, image_s = await self._process_single_image(
                         img_data, aspect_ratio, grid_pinpoints
                     )
                     pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
                     pixel_values = np.stack(pixel_values, axis=0)
             else:
+                # A single image
                 pixel_values, image_hash, image_size = await self._process_single_image(
                     image_data[0], aspect_ratio, grid_pinpoints
                 )
-                image_hash = [image_hash]
-                image_size = [image_size]
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
         elif isinstance(image_data, str):
+            # A single image
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data, aspect_ratio, grid_pinpoints
             )
-            image_hash = [image_hash]
-            image_size = [image_size]
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
         else:
-            pixel_values, image_hash, image_size = None, None, None
+            raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hash, image_size
+        return pixel_values, image_hashes, image_sizes

-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):

 def _process_single_image_task(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+    image_data: Union[str, bytes],
+    image_aspect_ratio: Optional[str] = None,
+    image_grid_pinpoints: Optional[str] = None,
+    processor=None,
 ):
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
         if image_size is not None:
+            # It is a video with multiple images
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
             pixel_values = np.stack(pixel_values, axis=0)
             return pixel_values, image_hash, image_size
         else:
+            # It is an image
             image_hash = hash(image_data)
             if image_aspect_ratio == "pad":
                 image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
                 pixel_values = processor.image_processor(image.convert("RGB"))[
                     "pixel_values"
                 ][0]
-            elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+            elif image_aspect_ratio == "anyres" or (
+                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+            ):
                 pixel_values = process_anyres_image(
                     image, processor.image_processor, image_grid_pinpoints
                 )
             else:
                 pixel_values = processor.image_processor(image)["pixel_values"][0]

+            if isinstance(pixel_values, np.ndarray):
                 pixel_values = pixel_values.astype(np.float16)
             return pixel_values, image_hash, image.size
     except Exception:
         logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
...
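A condensed sketch (not the real implementation) of the normalization rule `_get_pixel_values` now applies: more than one image disables anyres and falls back to `"pad"` (interleaved-image or video mode), and single-image results are wrapped into lists so downstream code always sees `List[int]` hashes and `List[List[int]]` sizes. The function name and inputs below are illustrative.

```python
# Hypothetical helper mirroring the aspect-ratio rule in _get_pixel_values.
from typing import List, Union

def choose_aspect_ratio(image_data: Union[str, List[Union[str, bytes]]], default: str) -> str:
    if isinstance(image_data, list) and len(image_data) > 1:
        return "pad"  # multi-image / video mode: anyres is not used
    return default

print(choose_aspect_ratio(["a.jpg", "b.jpg"], "anyres"))  # -> "pad"
print(choose_aspect_ratio("single.jpg", "anyres"))        # -> "anyres"
```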
@@ -108,7 +108,7 @@ class ModelTpServer:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(server_args.model_path):
+            if is_multimodal_model(self.model_config.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -333,26 +333,24 @@ class ModelTpServer:
             if self.model_runner.is_generation:
                 req.pixel_values = recv_req.pixel_values
                 if req.pixel_values is not None:
-                    image_hash = (
-                        hash(tuple(recv_req.image_hash))
-                        if isinstance(recv_req.image_hash, list)
-                        else recv_req.image_hash
-                    )
+                    # Use image hash as fake token_ids, which is then used
+                    # for prefix matching
+                    image_hash = hash(tuple(recv_req.image_hashes))
                     req.pad_value = [
                         (image_hash) % self.model_config.vocab_size,
                         (image_hash >> 16) % self.model_config.vocab_size,
                         (image_hash >> 32) % self.model_config.vocab_size,
                         (image_hash >> 64) % self.model_config.vocab_size,
                     ]
-                    req.image_size = recv_req.image_size
+                    req.image_sizes = recv_req.image_sizes
                     (
                         req.origin_input_ids,
-                        req.image_offset,
+                        req.image_offsets,
                     ) = self.model_runner.model.pad_input_ids(
                         req.origin_input_ids_unpadded,
                         req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
+                        req.pixel_values,
+                        req.image_sizes,
                     )
                 req.return_logprob = recv_req.return_logprob
                 req.logprob_start_len = recv_req.logprob_start_len
@@ -368,6 +366,7 @@ class ModelTpServer:
                         req.jump_forward_map = self.jump_forward_cache.query(
                             computed_regex_string
                         )
+                    # Init regex fsm
                     elif req.sampling_params.regex is not None:
                         req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
...
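To make the fake-token trick concrete, here is a runnable sketch of the `pad_value` derivation above: all image hashes of a request are combined into one hash, and four slices of that hash (mod vocabulary size) become the pad token ids used for radix-cache prefix matching. The vocabulary size and hashes are invented values.

```python
# Sketch of pad_value derivation from per-image hashes (toy numbers).
vocab_size = 32000                      # illustrative, not a real config value
image_hashes = [123456789, 987654321]   # one hash per image, as in the diff

image_hash = hash(tuple(image_hashes))  # combine all images into one hash
pad_value = [
    image_hash % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)
```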
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List

 import numpy as np
 import torch
@@ -58,6 +58,7 @@ class InputMetadata:
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
@@ -69,8 +70,8 @@ class InputMetadata:
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None

     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -87,20 +88,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]

     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +142,7 @@ class InputMetadata:
             for i, r in enumerate(batch.reqs)
         ]
         self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -238,10 +228,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens

         update_flashinfer_indices(
             self.forward_mode,
...
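With `extend_prefix_lens` now carried in `InputMetadata`, the per-request offset adjustment moves out of `init_multimuldal_info` and into each model's feature-fill loop (see the llava changes further down). A hedged sketch of the arithmetic with invented numbers:

```python
# Toy walk-through of the new offset bookkeeping: offsets are stored raw,
# and the model subtracts the cached prefix length at fill time.
image_offsets = [5, 700]   # raw offsets of two images within one request
prefix_len = 600           # tokens of this request already in the radix cache
start_idx = 0              # start of this request's extend segment

for image_offset in image_offsets:
    if image_offset < prefix_len:
        continue  # image sits fully inside the cached prefix; nothing to fill
    left_idx = start_idx + (image_offset - prefix_len)
    print(f"fill image features starting at embedding row {left_idx}")
```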
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
             }
         )

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -507,9 +516,9 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable torch compile by not using --enable-torch-compile\n"
-                "2. disable cuda graph by --disable-cuda-graph\n"
-                "3. set --mem-fraction-static to a smaller value\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
...
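The multimodal auto-adjustment above also explains the README change at the top of this commit: `--chunked-prefill-size=16384` was dropped from the LLaVA-OneVision launch command because the runtime now disables chunked prefill for multimodal models on its own. A hedged sketch of the adjustment with illustrative values:

```python
# Toy restatement of the multimodal startup adjustment above.
chunked_prefill_size = 16384
mem_fraction_static = 0.88
is_multimodal = True

if is_multimodal:
    chunked_prefill_size = None   # chunked prefill is turned off
    mem_fraction_static *= 0.95   # leave extra memory headroom
print(chunked_prefill_size, mem_fraction_static)
```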
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
...
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)

         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
         return hidden_states

-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors

-EntryClass = Grok1ModelForCausalLM
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
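The alias keeps checkpoints whose configs declare `Grok1ModelForCausalLM` in `architectures` loadable under the new class name. A toy sketch of the idea, assuming (this is an assumption, not sglang's actual registry code) that the loader resolves architecture strings by matching them against `EntryClass` class names:

```python
# Hypothetical name-based registry resolution for the alias pattern above.
class Grok1ForCausalLM: ...

class Grok1ModelForCausalLM(Grok1ForCausalLM):
    """An alias for backward compatibility."""

EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
registry = {cls.__name__: cls for cls in EntryClass}  # assumed lookup scheme
assert registry["Grok1ModelForCausalLM"] is Grok1ModelForCausalLM
```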
@@ -357,6 +357,9 @@ class LlamaForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -364,8 +367,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -374,8 +375,6 @@ class LlamaForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
...
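Hoisting the vision-tower skip above the stacked-parameter loop matters: a vision-tower weight whose suffix matches a stacked mapping could otherwise be rewritten to a name that is looked up in `params_dict`, raising a `KeyError`. A toy sketch of the corrected control flow (all names invented):

```python
# Toy demonstration: skip non-text weights *before* the stacked mappings.
params_dict = {"model.layers.0.qkv_proj": "placeholder-param"}
stacked_params_mapping = [("qkv_proj", "q_proj", "q")]

for name in ["model.vision_tower.blocks.0.q_proj", "model.layers.0.q_proj"]:
    if name.startswith("model.vision_tower") and name not in params_dict:
        continue  # handled by the multimodal wrapper's own loader
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        stacked_name = name.replace(weight_name, param_name)
        print(f"load {name} -> {stacked_name} (shard {shard_id})")
```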
@@ -103,8 +103,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -113,8 +111,6 @@ class LlamaForClassification(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
...
@@ -57,6 +57,9 @@ class LlamaEmbeddingModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 return
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                return
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -64,8 +67,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -74,8 +75,6 @@ class LlamaEmbeddingModel(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     return
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
...
@@ -28,7 +28,6 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
-    SiglipVisionConfig,
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
@@ -66,13 +65,18 @@ class LlavaLlamaForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
-        for image_s in image_size:
-            if len(image_size) > 16:
+        for image_s in image_sizes:
+            if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
                     math.ceil(self.image_size / self.patch_size / 2) ** 2
@@ -153,17 +157,15 @@ class LlavaLlamaForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
-            # Embed text input
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
-            )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
+            )
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset

             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
@@ -332,31 +334,35 @@ class LlavaLlamaForCausalLM(nn.Module):
                     new_image_features.append(image_feature)
                 image_features = new_image_features

+                # Fill in the placeholder for the image
                 extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    pad_dim = image_features[pt].shape[-1]  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        for j, image_off in enumerate(image_offsets[i]):
-                            # print("actual image_features length: ", image_features[pt][j].shape[0])
-                            pad_len = image_features[pt][j].shape[0]
-                            input_embeds[
-                                start_idx + image_off : start_idx + image_off + pad_len
-                            ] = image_features[pt][j]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(image_features[pt].shape)
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for j, image_offset in enumerate(image_offsets[i]):
+                        if image_offset < prefix_len:
+                            continue
+
+                        tmp_image_feature = image_features[pt][j]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(
+                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                            )
                     pt += 1

             return self.language_model(
@@ -366,8 +372,9 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
+        # Load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         if "clip" in vision_path:
             self.vision_tower = CLIPVisionModel.from_pretrained(
@@ -422,8 +429,6 @@ class LlavaLlamaForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
@@ -495,36 +500,4 @@ class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
         )

-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
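The heart of the multi-image fix is visible in the two llava hunks above: `pad_input_ids` now returns one offset per image, and the forward pass fills each image's feature span at `start_idx + (image_offset - prefix_len)`. A toy sketch of the padding step, assuming a fixed per-image feature length (the real code derives lengths from image sizes and the merge type):

```python
# Toy sketch of per-image placeholder padding with invented token ids.
IMAGE_TOKEN = -1                        # stand-in for the image token index
pad_value = [9001, 9002, 9003, 9004]    # fake tokens derived from image hashes
feature_len = 6                         # per-image feature length (illustrative)

input_ids = [1, 2, IMAGE_TOKEN, 3, IMAGE_TOKEN, 4]
new_input_ids, offsets = [], []
for tok in input_ids:
    if tok == IMAGE_TOKEN:
        offsets.append(len(new_input_ids))  # record one offset per image
        pad = pad_value * ((feature_len + len(pad_value)) // len(pad_value))
        new_input_ids.extend(pad[:feature_len])
    else:
        new_input_ids.append(tok)
print(offsets)             # -> [2, 9]
print(len(new_input_ids))  # -> 16
```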
...@@ -26,11 +26,6 @@ from vllm.config import CacheConfig ...@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM from sglang.srt.models.llama2 import LlamaForCausalLM
...@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
torch.empty(config.text_config.hidden_size, dtype=torch.float16) torch.empty(config.text_config.hidden_size, dtype=torch.float16)
) )
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): def pad_input_ids(
self,
input_ids: List[int],
pad_value: List[int],
pixel_values: List,
image_sizes: List[List[int]],
):
new_image_feature_len = self.image_feature_len new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
# if self.mm_patch_merge_type.startswith("spatial"):
# height = width = self.num_patches_per_side
# if pt_shape[0] > 1:
# if self.image_aspect_ratio == "anyres":
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
# image_size,
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# if "unpad" in self.mm_patch_merge_type:
# h = num_patch_height * height
# w = num_patch_width * width
# new_h, new_w = unpad_image_shape(h, w, image_size)
# new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * ( pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value) (new_image_feature_len + len(pad_value)) // len(pad_value)
...@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
+ pad_ids[:new_image_feature_len] + pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :] + input_ids[offset + 1 :]
) )
return new_input_ids, offset return new_input_ids, [offset]
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
...@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module): ...@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
if input_metadata.forward_mode == ForwardMode.EXTEND: if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size bs = input_metadata.batch_size
# Embed text input # Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids) input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input # Whether the requests need vision inputs
need_vision = ( max_image_offset = np.array(
(positions[input_metadata.extend_start_loc] < self.image_feature_len) [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
.cpu()
.numpy()
) )
# FIXME: We need to substract the length of the system prompt start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) need_vision = start_positions <= max_image_offset
need_vision = need_vision & has_pixel
if need_vision.any(): if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ######## ########## Encode Image ########
@@ -183,30 +165,35 @@ class LlavaVidForCausalLM(nn.Module):
                     new_image_features.append(image_feature.flatten(0, 1))
                 image_features = new_image_features

+                # Fill in the placeholder for the image
                 extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        input_embeds[
-                            start_idx
-                            + image_offsets[i] : start_idx
-                            + image_offsets[i]
-                            + pad_len
-                        ] = image_features[pt]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
-                    pt += 1
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for image_offset in image_offsets[i]:
+                        if image_offset < prefix_len:
+                            continue
+
+                        tmp_image_feature = image_features[pt]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(
+                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                            )
+                        pt += 1

             return self.language_model(
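The indexing above can be checked with toy tensors (assumed shapes, not the real metadata): image offsets are positions in the full prompt, and subtracting `prefix_len` maps them into the current extend segment of the flattened batch.

```python
# Toy illustration of the placeholder-fill indexing.
import torch

hidden = 4
input_embeds = torch.zeros(12, hidden)  # flattened extend tokens of the batch
start_idx = 3                           # where this request's segment begins
prefix_len = 2                          # tokens already served from the prefix cache
image_offsets_i = [4]                   # placeholder offset in the full prompt
image_feature = torch.ones(5, hidden)   # 5 image tokens for this image

for image_offset in image_offsets_i:
    if image_offset < prefix_len:
        continue  # this image lies entirely inside the cached prefix
    pad_len = image_feature.shape[0]
    left = start_idx + (image_offset - prefix_len)
    right = left + pad_len
    input_embeds[left:right] = image_feature

print(input_embeds.sum(dim=1))  # nonzero exactly at rows 5..9
```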
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
+        # Load clip vision model by cfg['mm_vision_tower']:
         #   huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         self.vision_tower = CLIPVisionModel.from_pretrained(
             vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()

     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size

-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )

 EntryClass = LlavaVidForCausalLM
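As a side note, the `576` in the deleted `# 576, 4096` comment follows directly from this property for a typical CLIP ViT-L/14 tower at 336x336 input (illustrative values, not read from the config here):

```python
# Worked example of num_patches_per_side for a common LLaVA setup.
image_size, patch_size = 336, 14
num_patches_per_side = image_size // patch_size  # 24
num_patches = num_patches_per_side ** 2          # 576 image tokens per frame
print(num_patches_per_side, num_patches)
```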
@@ -312,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -319,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -329,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
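A minimal sketch of why the check moves (toy names, not the real stacked-parameter mapping): running the vision-tower filter once, before any branch-specific handling, presumably keeps multimodal checkpoints from reaching `params_dict[name]` with a key that only exists in the vision tower.

```python
# Sketch of the weight-loading filter with a hypothetical params_dict.
params_dict = {"model.layers.0.qkv_proj.weight": object()}

def should_skip(name: str) -> bool:
    if "rotary_emb.inv_freq" in name:
        return True  # buffers recomputed at load time
    if name.startswith("model.vision_tower") and name not in params_dict:
        return True  # vision weights belong to the vision tower, not the LM
    return False

for name in ["model.vision_tower.embeddings.weight",
             "model.layers.0.qkv_proj.weight"]:
    print(name, "->", "skip" if should_skip(name) else "load")
```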
...
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.models.llava import (
-    LlavaLlamaForCausalLM,
-    monkey_path_clip_vision_embed_forward,
-)
+from sglang.srt.models.llava import LlavaLlamaForCausalLM

 class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
-        ).cuda()
+        ).to("cuda")
         self.vision_tower.eval()
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         # load language model
         self.language_model.load_weights(weights)

-        monkey_path_clip_vision_embed_forward()

 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
...
@@ -335,12 +335,12 @@ def launch_server(
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

     if server_args.dp_size == 1:
-        start_process = start_controller_process_single
+        start_controller_process = start_controller_process_single
     else:
-        start_process = start_controller_process_multi
+        start_controller_process = start_controller_process_multi
     proc_controller = mp.Process(
-        target=start_process,
+        target=start_controller_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
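The rename is cosmetic; the dispatch pattern itself looks like this sketch (stub targets standing in for SGLang's real entry points):

```python
# Pick the controller entry point by data-parallel size, then hand it
# to multiprocessing. The two targets below are stand-ins.
import multiprocessing as mp

def start_controller_process_single(msg):
    print("single-controller:", msg)

def start_controller_process_multi(msg):
    print("multi-controller:", msg)

def launch(dp_size: int):
    if dp_size == 1:
        start_controller_process = start_controller_process_single
    else:
        start_controller_process = start_controller_process_multi
    proc = mp.Process(target=start_controller_process, args=("hello",))
    proc.start()
    proc.join()

if __name__ == "__main__":
    launch(dp_size=1)
```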
@@ -421,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.6",
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
...