Unverified Commit 690d162d authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Format code (#441)

parent 664287b2
...@@ -31,7 +31,7 @@ if __name__ == "__main__": ...@@ -31,7 +31,7 @@ if __name__ == "__main__":
url + "/generate", url + "/generate",
json={ json={
"text": f"{a}, ", "text": f"{a}, ",
#"input_ids": [[2] * 256] * 196, # "input_ids": [[2] * 256] * 196,
"sampling_params": { "sampling_params": {
"temperature": 0, "temperature": 0,
"max_new_tokens": max_new_tokens, "max_new_tokens": max_new_tokens,
......
...@@ -30,7 +30,11 @@ from sglang.lang.ir import ( ...@@ -30,7 +30,11 @@ from sglang.lang.ir import (
SglVarScopeEnd, SglVarScopeEnd,
SglVideo, SglVideo,
) )
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync): def run_internal(state, program, func_args, func_kwargs, sync):
......
...@@ -32,7 +32,9 @@ class GenerateReqInput: ...@@ -32,7 +32,9 @@ class GenerateReqInput:
def post_init(self): def post_init(self):
if self.text is None: if self.text is None:
assert self.input_ids is not None, "Either text or input_ids should be provided" assert (
self.input_ids is not None
), "Either text or input_ids should be provided"
else: else:
assert self.input_ids is None, "Either text or input_ids should be provided" assert self.input_ids is None, "Either text or input_ids should be provided"
......
...@@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq, FlushCacheReq,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
) )
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler from sglang.srt.managers.router.scheduler import Scheduler
...@@ -37,7 +37,6 @@ from sglang.srt.utils import ( ...@@ -37,7 +37,6 @@ from sglang.srt.utils import (
set_random_seed, set_random_seed,
) )
logger = logging.getLogger("model_rpc") logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN) vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN)
...@@ -238,7 +237,9 @@ class ModelRpcServer: ...@@ -238,7 +237,9 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size() self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size() + self.tree_cache.evictable_size()
) )
throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic) throuhgput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.last_stats_tic = time.time() self.last_stats_tic = time.time()
logger.info( logger.info(
...@@ -401,12 +402,12 @@ class ModelRpcServer: ...@@ -401,12 +402,12 @@ class ModelRpcServer:
f"#running_req: {running_req}. " f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%." f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
) )
#logger.debug( # logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. " # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. " # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. " # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. " # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
#) # )
new_batch = Batch.init_new( new_batch = Batch.init_new(
can_run_list, can_run_list,
...@@ -440,11 +441,10 @@ class ModelRpcServer: ...@@ -440,11 +441,10 @@ class ModelRpcServer:
# Only transfer the selected logprobs of the next token to CPU to reduce overhead. # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None: if last_logprobs is not None:
last_token_logprobs = ( last_token_logprobs = last_logprobs[
last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device), torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids].tolist() next_token_ids,
) ].tolist()
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
else: else:
......
...@@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype ...@@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
QUANTIZATION_CONFIG_MAPPING = { QUANTIZATION_CONFIG_MAPPING = {
"awq": AWQConfig, "awq": AWQConfig,
......
...@@ -72,7 +72,8 @@ def get_pixel_values( ...@@ -72,7 +72,8 @@ def get_pixel_values(
image_hash = hash(image_data) image_hash = hash(image_data)
if image_aspect_ratio == "pad": if image_aspect_ratio == "pad":
image = expand2square( image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean) image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
) )
pixel_values = processor.image_processor(image)["pixel_values"][0] pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres": elif image_aspect_ratio == "anyres":
...@@ -208,10 +209,12 @@ class TokenizerManager: ...@@ -208,10 +209,12 @@ class TokenizerManager:
while True: while True:
await event.wait() await event.wait()
out = self.convert_logprob_style(state.out_list[-1], out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob, obj.return_logprob,
obj.top_logprobs_num, obj.top_logprobs_num,
obj.return_text_in_logprobs) obj.return_text_in_logprobs,
)
if self.server_args.log_requests and state.finished: if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}") logger.info(f"in={obj.text}, out={out}")
...@@ -275,10 +278,13 @@ class TokenizerManager: ...@@ -275,10 +278,13 @@ class TokenizerManager:
state = self.rid_to_state[rid] state = self.rid_to_state[rid]
await state.event.wait() await state.event.wait()
output_list.append( output_list.append(
self.convert_logprob_style(state.out_list[-1], self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob[i], obj.return_logprob[i],
obj.top_logprobs_num[i], obj.top_logprobs_num[i],
obj.return_text_in_logprobs)) obj.return_text_in_logprobs,
)
)
assert state.finished assert state.finished
del self.rid_to_state[rid] del self.rid_to_state[rid]
...@@ -311,7 +317,9 @@ class TokenizerManager: ...@@ -311,7 +317,9 @@ class TokenizerManager:
else: else:
raise ValueError(f"Invalid object: {recv_obj}") raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs): def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
):
if return_logprob: if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
...@@ -320,12 +328,16 @@ class TokenizerManager: ...@@ -320,12 +328,16 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
) )
if top_logprobs_num > 0: if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens( ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
) )
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens( )
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
) )
)
return ret return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
......
...@@ -5,13 +5,9 @@ from typing import List, Optional ...@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import ( ...@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaMistralForCausalLM(nn.Module): class LlavaMistralForCausalLM(nn.Module):
......
...@@ -5,13 +5,9 @@ from typing import List, Optional ...@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata from sglang.srt.managers.router.model_runner import InputMetadata
...@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import ( ...@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape, unpad_image_shape,
) )
from sglang.srt.models.qwen2 import Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaQwenForCausalLM(nn.Module): class LlavaQwenForCausalLM(nn.Module):
......
"""Conversion between OpenAI APIs and native SRT APIs""" """Conversion between OpenAI APIs and native SRT APIs"""
import json import json
import os import os
...@@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import ( ...@@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import (
) )
from sglang.srt.utils import jsonify_pydantic_model from sglang.srt.utils import jsonify_pydantic_model
chat_template_name = None chat_template_name = None
def load_chat_template_for_openai_api(chat_template_arg): def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name global chat_template_name
......
"""pydantic models for OpenAI API protocol""" """pydantic models for OpenAI API protocol"""
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
......
...@@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput ...@@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.router.manager import start_router_process from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import ( from sglang.srt.openai_api_adapter import (
v1_completions, v1_chat_completions, load_chat_template_for_openai_api) load_chat_template_for_openai_api,
v1_chat_completions,
v1_completions,
)
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
allocate_init_ports, allocate_init_ports,
assert_pkg_version, assert_pkg_version,
enable_show_time_cost, enable_show_time_cost,
get_exception_traceback, get_exception_traceback,
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware
) )
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
...@@ -275,7 +275,9 @@ def is_multimodal_model(model): ...@@ -275,7 +275,9 @@ def is_multimodal_model(model):
if isinstance(model, ModelConfig): if isinstance(model, ModelConfig):
model_path = model.path.lower() model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type") raise ValueError("unrecognized type")
......
...@@ -138,6 +138,7 @@ def encode_frame(frame): ...@@ -138,6 +138,7 @@ def encode_frame(frame):
def encode_video_base64(video_path, num_frames=16): def encode_video_base64(video_path, num_frames=16):
import cv2 import cv2
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
if not cap.isOpened(): if not cap.isOpened():
raise IOError(f"Could not open video file:{video_path}") raise IOError(f"Could not open video file:{video_path}")
......
...@@ -9,7 +9,6 @@ import argparse ...@@ -9,7 +9,6 @@ import argparse
import requests import requests
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--host", type=str, default="http://127.0.0.1")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment