Unverified commit fb9296f0, authored by Ying Sheng, committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)

parent 1374334d
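
The diff below is almost entirely mechanical formatting (isort-style import ordering, black-style wrapping). The functional change named in the title is the precedence rule for `max_prefill_tokens`: a value passed explicitly by the user should win over the server-derived default. A minimal sketch of that rule under assumed names (this is not the actual sglang `ServerArgs` code, which does not appear in the hunks below):

```python
# Illustrative only: the field and function names are assumptions, not sglang's API.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PrefillConfigSketch:
    context_length: int
    max_prefill_tokens: Optional[int] = None  # None means "not set by the user"


def resolve_max_prefill_tokens(cfg: PrefillConfigSketch, derived_default: int) -> int:
    # User input takes higher priority: only fall back to the derived value
    # when the user did not pass max_prefill_tokens at all.
    if cfg.max_prefill_tokens is not None:
        return cfg.max_prefill_tokens
    return derived_default


print(resolve_max_prefill_tokens(PrefillConfigSketch(context_length=4096), 16384))              # 16384
print(resolve_max_prefill_tokens(PrefillConfigSketch(4096, max_prefill_tokens=8192), 16384))    # 8192
```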
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
from typing import Any, Dict, Optional, Iterable, Tuple
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
......
# Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Any, Dict, Optional, Tuple, Iterable
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
......
......@@ -2,7 +2,7 @@
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
"""Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
model compatible with HuggingFace weights."""
from typing import Optional, Tuple, Iterable
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
......
"""Inference-only Yi-VL model."""
from typing import Tuple, Iterable, Optional
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
monkey_path_clip_vision_embed_forward,
......
......@@ -6,7 +6,7 @@ import os
from http import HTTPStatus
from fastapi import Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from sglang.srt.conversation import (
Conversation,
......@@ -40,21 +40,18 @@ chat_template_name = None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
return JSONResponse(content=error.model_dump(),
status_code=error.code)
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code)
def create_streaming_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
error = ErrorResponse(message=message,
type=err_type,
code=status_code.value)
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
json_str = json.dumps({"error": error.model_dump()})
return json_str
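
Beyond the trailing-comma reformat, the two helpers above differ only in how the error leaves the server: `create_error_response` returns a `JSONResponse` directly, while `create_streaming_error_response` returns a JSON string meant to be pushed down an already-open event stream. A rough usage sketch, assuming the `ErrorResponse` model from the surrounding file (the handler itself is hypothetical):

```python
from fastapi.responses import StreamingResponse


# Hypothetical handler, for illustration only.
async def handle_sketch(streaming: bool, reason: str):
    if not streaming:
        # Non-streaming: the error becomes the whole HTTP response body.
        return create_error_response(reason)

    async def error_stream():
        # Streaming: the error rides on the SSE channel before the stream closes,
        # matching the `yield f"data: {error}\n\n"` pattern used in the handlers below.
        yield f"data: {create_streaming_error_response(reason)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(error_stream(), media_type="text/event-stream")
```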
......@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
n_prev_token = 0
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request):
adapted_request, raw_request
):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
......@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
n_prev_token:
],
decode_top_logprobs=content["meta_info"][
"decode_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
n_prev_token = len(
content["meta_info"]["decode_token_logprobs"]
)
else:
logprobs = None
......@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
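
The `StreamingResponse` reformat above is behavior-preserving: the async generator is sent as server-sent events, and the `background=` argument (a Starlette background task produced by `create_abort_task`) runs once the response finishes, which is how a client disconnect can abort the in-flight generation. The same pattern in a self-contained form, independent of sglang:

```python
import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()


async def fake_token_stream():
    # Stand-in for tokenizer_manager.generate_request(...): emit a few SSE chunks.
    for tok in ("Hel", "lo", "!"):
        await asyncio.sleep(0.05)
        yield f"data: {tok}\n\n"
    yield "data: [DONE]\n\n"


def on_finished():
    # Stand-in for the abort task: runs after the stream ends, including when the
    # client disconnects early, so pending work can be cancelled at this point.
    print("stream finished, cleaning up")


@app.get("/stream")
async def stream():
    return StreamingResponse(
        fake_token_stream(),
        media_type="text/event-stream",
        background=BackgroundTask(on_finished),
    )
```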
......@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
stream_buffer = ""
try:
async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
if is_first:
# First chunk with role
is_first = False
......@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request))
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)
# Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request).__anext__()
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
......
......@@ -13,7 +13,7 @@ import sys
import threading
import time
from http import HTTPStatus
from typing import Optional, Dict
from typing import Dict, Optional
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
......@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.constrained import disable_cache
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.manager_multi import (
start_controller_process as start_controller_process_multi,
)
from sglang.srt.managers.controller.manager_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api_adapter import (
load_chat_template_for_openai_api,
......@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(stream_results(), media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj))
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(obj),
)
else:
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
......
"""Common utilities."""
import base64
import multiprocessing
import logging
import multiprocessing
import os
import random
import socket
......@@ -17,12 +17,11 @@ import requests
import rpyc
import torch
import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
......@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
protocol_config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
"sync_request_timeout": 3600,
},
)
t.logger.setLevel(logging.WARN)
......@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
config={
"allow_public_attrs": True,
"allow_pickle": True,
"sync_request_timeout": 3600
"sync_request_timeout": 3600,
},
)
break
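
The two rpyc hunks only add trailing commas, but the repeated config dict is doing real work: `allow_public_attrs` and `allow_pickle` relax rpyc's default restrictions so richer objects can cross the connection, and `sync_request_timeout=3600` stretches the synchronous-call timeout to an hour for long-running requests. A minimal server/client pair using the same settings (the service and port here are arbitrary):

```python
import rpyc
from rpyc.utils.server import ThreadedServer

RPYC_CONFIG = {
    "allow_public_attrs": True,
    "allow_pickle": True,
    "sync_request_timeout": 3600,
}


class EchoService(rpyc.Service):
    def exposed_echo(self, value):
        # Anything picklable can round-trip because allow_pickle is enabled.
        return value


def serve(port: int = 18861):
    # Server side, mirroring init_rpyc_service.
    ThreadedServer(EchoService, port=port, protocol_config=RPYC_CONFIG).start()


def connect(port: int = 18861, host: str = "localhost"):
    # Client side, mirroring connect_to_rpyc_service.
    conn = rpyc.connect(host, port, config=RPYC_CONFIG)
    return conn.root  # conn.root.echo("hi") invokes exposed_echo on the server
```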
......@@ -423,7 +422,9 @@ def suppress_other_loggers():
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
......@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
device_name = torch.cuda.get_device_name(gpu_id)
if "RTX 40" not in device_name:
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
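
The blank line added above is cosmetic; the patch itself is a plain attribute swap: when the device name does not contain "RTX 40", vLLM's `gpu_p2p_access_check` is replaced with a stub that always reports peer access. The same `setattr` pattern in isolation (the module and function names below are stand-ins, not vLLM's):

```python
import types


def slow_probe(src: int, dst: int) -> bool:
    # Stand-in for a capability check that is expensive or unreliable.
    raise RuntimeError("probe unavailable in this environment")


# Stand-in module holding the probe, analogous to the vLLM utils module.
probe_module = types.ModuleType("probe_module")
probe_module.slow_probe = slow_probe

# The monkey patch: swap the attribute for a cheap stub with the same call shape.
setattr(probe_module, "slow_probe", lambda *args, **kwargs: True)

assert probe_module.slow_probe(0, 1) is True
```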
......@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
)
response = await call_next(request)
return response
......@@ -356,16 +356,25 @@ def test_completion_speculative():
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
@sgl.function
def gen_character_no_spec(s):
s += "Construct a character within the following format:\n"
s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
s += "\nPlease generate new Name, Birthday and Job.\n"
s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
s += (
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
)
s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"
token_usage = sgl.global_config.default_backend.token_usage
......@@ -378,7 +387,9 @@ def test_completion_speculative():
gen_character_no_spec().sync()
usage_with_no_spec = token_usage.prompt_tokens
assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
assert (
usage_with_spec < usage_with_no_spec
), f"{usage_with_spec} vs {usage_with_no_spec}"
def test_chat_completion_speculative():
......@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
def gen_character_spec(s):
s += sgl.system("You are a helpful assistant.")
s += sgl.user("Construct a character within the following format:")
s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
s += sgl.assistant(
"Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
)
s += sgl.user("Please generate new Name, Birthday and Job.\n")
s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
s += sgl.assistant(
"Name:"
+ sgl.gen("name", stop="\n")
+ "\nBirthday:"
+ sgl.gen("birthday", stop="\n")
+ "\nJob:"
+ sgl.gen("job", stop="\n")
)
gen_character_spec().sync()
\ No newline at end of file
gen_character_spec().sync()
......@@ -15,7 +15,6 @@ from json import dumps
import numpy as np
import requests
logger = logging.getLogger(__name__)
......@@ -255,8 +254,10 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
def graceful_registry(sub_module_name):
def graceful_shutdown(signum, frame):
logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
)
if signum == signal.SIGTERM:
logger.info(f"{sub_module_name} recive sigterm")
signal.signal(signal.SIGTERM, graceful_shutdown)
\ No newline at end of file
signal.signal(signal.SIGTERM, graceful_shutdown)
......@@ -2,6 +2,8 @@ import unittest
from sglang import OpenAI, set_default_backend
from sglang.test.test_programs import (
test_chat_completion_speculative,
test_completion_speculative,
test_decode_int,
test_decode_json,
test_expert_answer,
......@@ -14,8 +16,6 @@ from sglang.test.test_programs import (
test_select,
test_stream,
test_tool_use,
test_completion_speculative,
test_chat_completion_speculative
)
......@@ -97,4 +97,4 @@ if __name__ == "__main__":
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUp()
# t.test_chat_completion_speculative()
\ No newline at end of file
# t.test_chat_completion_speculative()