Unverified Commit 690d162d authored by Liangsheng Yin, committed by GitHub

Format code (#441)

parent 664287b2
......@@ -31,7 +31,7 @@ if __name__ == "__main__":
url + "/generate",
json={
"text": f"{a}, ",
#"input_ids": [[2] * 256] * 196,
# "input_ids": [[2] * 256] * 196,
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
......
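For reference, the hunk above touches a test script that posts to the SRT server's `/generate` endpoint. A minimal sketch of such a request, with the host/port and prompt chosen purely for illustration (the commented-out `input_ids` line mirrors the diff):

```python
import requests

# Assumed local server address; adjust to wherever the SRT server runs.
url = "http://127.0.0.1:30000"

response = requests.post(
    url + "/generate",
    json={
        "text": "The capital of France is",
        # Pre-tokenized input can be supplied instead of raw text:
        # "input_ids": [[2] * 256] * 196,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 16,
        },
    },
)
print(response.json())
```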
......@@ -74,4 +74,4 @@ class Anthropic(BaseBackend):
**sampling_params.to_anthropic_kwargs(),
) as stream:
for text in stream.text_stream:
yield text, {}
\ No newline at end of file
yield text, {}
......@@ -30,7 +30,11 @@ from sglang.lang.ir import (
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync):
......
......@@ -13,4 +13,4 @@ if __name__ == "__main__":
args = parser.parse_args()
response = requests.get(args.url + "/flush_cache")
assert response.status_code == 200
\ No newline at end of file
assert response.status_code == 200
......@@ -32,7 +32,9 @@ class GenerateReqInput:
def post_init(self):
if self.text is None:
assert self.input_ids is not None, "Either text or input_ids should be provided"
assert (
self.input_ids is not None
), "Either text or input_ids should be provided"
else:
assert self.input_ids is None, "Either text or input_ids should be provided"
......
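The reformatting above wraps an assertion in `GenerateReqInput.post_init` that enforces exactly one of `text` / `input_ids` being set. A self-contained sketch of that either/or check (field names and messages follow the hunk; the dataclass shape and field types are assumptions):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class GenerateReqInput:
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None  # assumed type; not shown in the diff

    def post_init(self):
        # Exactly one of text / input_ids must be provided.
        if self.text is None:
            assert (
                self.input_ids is not None
            ), "Either text or input_ids should be provided"
        else:
            assert self.input_ids is None, "Either text or input_ids should be provided"
```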
......@@ -24,7 +24,7 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
from sglang.srt.managers.router.model_runner import ModelRunner
from sglang.srt.managers.router.radix_cache import RadixCache
from sglang.srt.managers.router.scheduler import Scheduler
......@@ -37,7 +37,6 @@ from sglang.srt.utils import (
set_random_seed,
)
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
......@@ -238,7 +237,9 @@ class ModelRpcServer:
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
throughput = self.num_generated_tokens / (
time.time() - self.last_stats_tic
)
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
logger.info(
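The hunk above computes decode throughput as tokens generated since the last report divided by elapsed time, then resets both counters. A standalone sketch of that windowed tokens-per-second bookkeeping (the class wrapper is an assumption; the variable names follow the diff):

```python
import time


class ThroughputMeter:
    """Tokens-per-second over the window since the last report."""

    def __init__(self):
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

    def count(self, new_tokens: int):
        self.num_generated_tokens += new_tokens

    def report(self) -> float:
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        # Reset the window after each report, as in the hunk above.
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        return throughput
```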
......@@ -401,12 +402,12 @@ class ModelRpcServer:
f"#running_req: {running_req}. "
f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
)
#logger.debug(
# logger.debug(
# f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
# f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
# f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
# f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
#)
# )
new_batch = Batch.init_new(
can_run_list,
......@@ -440,11 +441,10 @@ class ModelRpcServer:
# Only transfer the selected logprobs of the next token to CPU to reduce overhead.
if last_logprobs is not None:
last_token_logprobs = (
last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids].tolist()
)
last_token_logprobs = last_logprobs[
torch.arange(len(batch.reqs), device=next_token_ids.device),
next_token_ids,
].tolist()
next_token_ids = next_token_ids.tolist()
else:
......
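The rewrapped expression above is a standard PyTorch advanced-indexing gather: for each row `i`, it selects `last_logprobs[i, next_token_ids[i]]`, so only one scalar per sequence is transferred to the CPU rather than the full logprob matrix. A self-contained sketch of the pattern (shapes are assumptions):

```python
import torch

batch_size, vocab_size = 4, 32000
last_logprobs = torch.randn(batch_size, vocab_size)          # one row per request
next_token_ids = torch.randint(0, vocab_size, (batch_size,))

# Pick last_logprobs[i, next_token_ids[i]] for every i, then move
# just those batch_size scalars to the CPU as a Python list.
last_token_logprobs = last_logprobs[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
].tolist()
print(last_token_logprobs)
```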
......@@ -17,8 +17,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
QUANTIZATION_CONFIG_MAPPING = {
"awq": AWQConfig,
......
......@@ -72,7 +72,8 @@ def get_pixel_values(
image_hash = hash(image_data)
if image_aspect_ratio == "pad":
image = expand2square(
image, tuple(int(x * 255) for x in processor.image_processor.image_mean)
image,
tuple(int(x * 255) for x in processor.image_processor.image_mean),
)
pixel_values = processor.image_processor(image)["pixel_values"][0]
elif image_aspect_ratio == "anyres":
......@@ -208,10 +209,12 @@ class TokenizerManager:
while True:
await event.wait()
out = self.convert_logprob_style(state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs)
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
if self.server_args.log_requests and state.finished:
logger.info(f"in={obj.text}, out={out}")
......@@ -275,10 +278,13 @@ class TokenizerManager:
state = self.rid_to_state[rid]
await state.event.wait()
output_list.append(
self.convert_logprob_style(state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs))
self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob[i],
obj.top_logprobs_num[i],
obj.return_text_in_logprobs,
)
)
assert state.finished
del self.rid_to_state[rid]
......@@ -311,7 +317,9 @@ class TokenizerManager:
else:
raise ValueError(f"Invalid object: {recv_obj}")
def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
def convert_logprob_style(
self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs
):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
......@@ -320,11 +328,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
)
ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
)
return ret
......
......@@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward():
)
EntryClass = LlavaLlamaForCausalLM
\ No newline at end of file
EntryClass = LlavaLlamaForCausalLM
......@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
......@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaMistralForCausalLM(nn.Module):
......
......@@ -5,13 +5,9 @@ from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config
from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
......@@ -21,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaQwenForCausalLM(nn.Module):
......
"""Conversion between OpenAI APIs and native SRT APIs"""
import json
import os
......@@ -31,9 +32,9 @@ from sglang.srt.openai_protocol import (
)
from sglang.srt.utils import jsonify_pydantic_model
chat_template_name = None
def load_chat_template_for_openai_api(chat_template_arg):
global chat_template_name
......@@ -353,4 +354,4 @@ def to_openai_style_logprobs(
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
return ret_logprobs
\ No newline at end of file
return ret_logprobs
"""pydantic models for OpenAI API protocol"""
import time
from typing import Dict, List, Optional, Union
......@@ -178,4 +179,4 @@ class ChatCompletionStreamResponse(BaseModel):
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
\ No newline at end of file
choices: List[ChatCompletionResponseStreamChoice]
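The hunk above only adds a trailing newline to the file; the visible lines follow the usual Pydantic pattern for an OpenAI-style streaming response. A runnable sketch for context (the stream-choice fields are elided in the diff, so that stub is an assumption):

```python
import time
from typing import List, Optional

from pydantic import BaseModel, Field


class ChatCompletionResponseStreamChoice(BaseModel):
    # Stub: the real field set is not shown in the diff above.
    index: int
    delta: Optional[dict] = None


class ChatCompletionStreamResponse(BaseModel):
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
```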
......@@ -30,15 +30,18 @@ from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
v1_completions, v1_chat_completions, load_chat_template_for_openai_api)
load_chat_template_for_openai_api,
v1_chat_completions,
v1_completions,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware,
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
get_exception_traceback,
API_KEY_HEADER_NAME,
APIKeyValidatorMiddleware
)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......
......@@ -275,7 +275,9 @@ def is_multimodal_model(model):
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type")
......
......@@ -138,6 +138,7 @@ def encode_frame(frame):
def encode_video_base64(video_path, num_frames=16):
import cv2
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Could not open video file:{video_path}")
......
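The hunk above adds an early `cap.isOpened()` guard to `encode_video_base64`, failing fast when OpenCV cannot open the file. A minimal illustrative helper built around the same guard (everything beyond the guard and its error message is an assumption, not the function from the diff):

```python
import base64

import cv2  # opencv-python


def encode_first_frame_base64(video_path: str) -> str:
    """Open a video, grab one frame, and return it as a base64-encoded JPEG."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Fail fast instead of silently yielding zero frames.
        raise IOError(f"Could not open video file: {video_path}")
    ok, frame = cap.read()
    cap.release()
    if not ok:
        raise IOError(f"Could not read a frame from: {video_path}")
    ok, buf = cv2.imencode(".jpg", frame)
    if not ok:
        raise IOError("JPEG encoding failed")
    return base64.b64encode(buf.tobytes()).decode("ascii")
```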
......@@ -9,7 +9,6 @@ import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
......