"src/turbomind/utils/nccl_utils.cc" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Unverified Commit fb9296f0, authored by Ying Sheng, committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)

parent 1374334d
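Note on the title: the hunk that actually gives a user-supplied max_prefill_tokens precedence over the server's derived default is not part of the excerpt shown below; the visible hunks are mostly isort/black formatting. As a rough, hypothetical sketch of the pattern the title refers to (all names here are placeholders, not code from this commit):

    # Hypothetical sketch only -- illustrates "user input wins over the derived default".
    from typing import Optional

    def resolve_max_prefill_tokens(user_value: Optional[int], derived_default: int) -> int:
        """Prefer an explicitly provided limit; otherwise fall back to the derived one."""
        if user_value is not None:
            return user_value      # user input has higher priority
        return derived_default     # e.g., computed from context length / memory budget

    assert resolve_max_prefill_tokens(8192, 4096) == 8192  # explicit value wins
    assert resolve_max_prefill_tokens(None, 4096) == 4096  # fallback to the default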
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
-from typing import Any, Dict, Optional, Iterable, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
...
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
...
@@ -2,7 +2,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple, Iterable
+from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
...
"""Inference-only Yi-VL model.""" """Inference-only Yi-VL model."""
from typing import Tuple, Iterable, Optional from typing import Iterable, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig from transformers import CLIPVisionModel, LlavaConfig
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.llava import ( from sglang.srt.models.llava import (
LlavaLlamaForCausalLM, LlavaLlamaForCausalLM,
monkey_path_clip_vision_embed_forward, monkey_path_clip_vision_embed_forward,
......
@@ -6,7 +6,7 @@ import os
 from http import HTTPStatus
 from fastapi import Request
-from fastapi.responses import StreamingResponse, JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from sglang.srt.conversation import (
     Conversation,
@@ -40,21 +40,18 @@ chat_template_name = None
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST):
-    error = ErrorResponse(message=message,
-                          type=err_type,
-                          code=status_code.value)
-    return JSONResponse(content=error.model_dump(),
-                        status_code=error.code)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)


 def create_streaming_error_response(
     message: str,
     err_type: str = "BadRequestError",
-    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
-    error = ErrorResponse(message=message,
-                          type=err_type,
-                          code=status_code.value)
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
     json_str = json.dumps({"error": error.model_dump()})
     return json_str
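For orientation (not part of the diff): a minimal usage sketch of the two helpers reformatted above, based only on the signatures visible in this hunk. The non-streaming helper returns a FastAPI JSONResponse directly, while the streaming helper returns a JSON string that can be placed on an SSE data: line.

    # Illustrative only -- assumes the two helpers defined above are in scope.
    from http import HTTPStatus

    # Non-streaming path: hand the JSONResponse straight back to FastAPI.
    resp = create_error_response("prompt is too long", status_code=HTTPStatus.BAD_REQUEST)

    # Streaming path: serialize the error and emit it as a server-sent event.
    err_json = create_streaming_error_response("prompt is too long")
    sse_chunk = f"data: {err_json}\n\n"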
@@ -125,7 +122,8 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             n_prev_token = 0
             try:
                 async for content in tokenizer_manager.generate_request(
-                    adapted_request, raw_request):
+                    adapted_request, raw_request
+                ):
                     text = content["text"]
                     prompt_tokens = content["meta_info"]["prompt_tokens"]
                     completion_tokens = content["meta_info"]["completion_tokens"]
@@ -154,12 +152,14 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                             decode_token_logprobs=content["meta_info"][
                                 "decode_token_logprobs"
                             ][n_prev_token:],
-                            decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
-                                n_prev_token:
-                            ],
+                            decode_top_logprobs=content["meta_info"][
+                                "decode_top_logprobs"
+                            ][n_prev_token:],
                         )
-                        n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
+                        n_prev_token = len(
+                            content["meta_info"]["decode_token_logprobs"]
+                        )
                     else:
                         logprobs = None
@@ -188,13 +188,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(adapted_request))
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
     try:
         ret = await tokenizer_manager.generate_request(
-            adapted_request, raw_request).__anext__()
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
@@ -299,7 +303,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             stream_buffer = ""
             try:
-                async for content in tokenizer_manager.generate_request(adapted_request, raw_request):
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
                     if is_first:
                         # First chunk with role
                         is_first = False
@@ -334,13 +340,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 yield f"data: {error}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(adapted_request))
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )

     # Non-streaming response.
     try:
         ret = await tokenizer_manager.generate_request(
-            adapted_request, raw_request).__anext__()
+            adapted_request, raw_request
+        ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
...
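The StreamingResponse calls reformatted above all follow the same pattern: an async generator yields SSE data: chunks, and a background task returned by tokenizer_manager.create_abort_task(...) runs when the response finishes or the client disconnects. A stripped-down, generic FastAPI sketch of that pattern (illustration only, not sglang code):

    # Generic FastAPI/Starlette sketch of the streaming pattern above.
    import asyncio
    import json

    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse
    from starlette.background import BackgroundTask

    app = FastAPI()

    @app.get("/stream")
    async def stream_demo():
        async def generate():
            for i in range(3):
                yield f"data: {json.dumps({'chunk': i})}\n\n"
                await asyncio.sleep(0.1)
            yield "data: [DONE]\n\n"

        def cleanup():
            # Stand-in for tokenizer_manager.create_abort_task(adapted_request):
            # release per-request resources once streaming ends or the client drops.
            pass

        return StreamingResponse(
            generate(),
            media_type="text/event-stream",
            background=BackgroundTask(cleanup),
        )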
@@ -13,7 +13,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Optional, Dict
+from typing import Dict, Optional

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -29,10 +29,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.manager_multi import (
+    start_controller_process as start_controller_process_multi,
+)
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
+)
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
-from sglang.srt.managers.controller.manager_single import start_controller_process as start_controller_process_single
-from sglang.srt.managers.controller.manager_multi import start_controller_process as start_controller_process_multi
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api_adapter import (
     load_chat_template_for_openai_api,
@@ -97,8 +101,11 @@ async def generate_request(obj: GenerateReqInput, request: Request):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"

-        return StreamingResponse(stream_results(), media_type="text/event-stream",
-                                 background=tokenizer_manager.create_abort_task(obj))
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
     else:
         try:
             ret = await tokenizer_manager.generate_request(obj, request).__anext__()
...
"""Common utilities.""" """Common utilities."""
import base64 import base64
import multiprocessing
import logging import logging
import multiprocessing
import os import os
import random import random
import socket import socket
...@@ -17,12 +17,11 @@ import requests ...@@ -17,12 +17,11 @@ import requests
import rpyc import rpyc
import torch import torch
import triton import triton
from rpyc.utils.server import ThreadedServer
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from rpyc.utils.server import ThreadedServer
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int): ...@@ -377,7 +376,7 @@ def init_rpyc_service(service: rpyc.Service, port: int):
protocol_config={ protocol_config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 3600 "sync_request_timeout": 3600,
}, },
) )
t.logger.setLevel(logging.WARN) t.logger.setLevel(logging.WARN)
...@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"): ...@@ -396,7 +395,7 @@ def connect_to_rpyc_service(port, host="localhost"):
config={ config={
"allow_public_attrs": True, "allow_public_attrs": True,
"allow_pickle": True, "allow_pickle": True,
"sync_request_timeout": 3600 "sync_request_timeout": 3600,
}, },
) )
break break
...@@ -423,7 +422,9 @@ def suppress_other_loggers(): ...@@ -423,7 +422,9 @@ def suppress_other_loggers():
vllm_default_logger.setLevel(logging.WARN) vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.config").setLevel(logging.ERROR) logging.getLogger("vllm.config").setLevel(logging.ERROR)
logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(logging.WARN) logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN)
...@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int): ...@@ -464,6 +465,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
device_name = torch.cuda.get_device_name(gpu_id) device_name = torch.cuda.get_device_name(gpu_id)
if "RTX 40" not in device_name: if "RTX 40" not in device_name:
import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
...@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): ...@@ -485,4 +487,3 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
) )
response = await call_next(request) response = await call_next(request)
return response return response
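The monkey_patch_vllm_p2p_access_check hunk above only gains a blank line, but the technique it relies on is worth noting: the imported module object is mutated with setattr, so every later caller sees the replacement (here, presumably to skip vllm's costly peer-to-peer probe). A generic, self-contained sketch of that technique (illustration only; the namespace below stands in for the real vllm module):

    # Generic monkey-patching sketch mirroring the setattr call above.
    import types

    # Stand-in for the imported vllm module (`tgt` in the hunk).
    tgt = types.SimpleNamespace(gpu_p2p_access_check=lambda src, dst: False)

    # Replace the check with an always-True stub, exactly as the hunk does.
    setattr(tgt, "gpu_p2p_access_check", lambda *args, **kwargs: True)

    assert tgt.gpu_p2p_access_check(0, 1) is True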
@@ -356,16 +356,25 @@ def test_completion_speculative():
         s += "Construct a character within the following format:\n"
         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
         s += "\nPlease generate new Name, Birthday and Job.\n"
-        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

     @sgl.function
     def gen_character_no_spec(s):
         s += "Construct a character within the following format:\n"
         s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
         s += "\nPlease generate new Name, Birthday and Job.\n"
-        s += "Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n")
+        s += (
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+        )
         s += "\nJob:" + sgl.gen("job", stop="\n") + "\n"

     token_usage = sgl.global_config.default_backend.token_usage
@@ -378,7 +387,9 @@ def test_completion_speculative():
     gen_character_no_spec().sync()
     usage_with_no_spec = token_usage.prompt_tokens

-    assert usage_with_spec < usage_with_no_spec, f"{usage_with_spec} vs {usage_with_no_spec}"
+    assert (
+        usage_with_spec < usage_with_no_spec
+    ), f"{usage_with_spec} vs {usage_with_no_spec}"


 def test_chat_completion_speculative():
@@ -386,8 +397,17 @@ def test_chat_completion_speculative():
     def gen_character_spec(s):
         s += sgl.system("You are a helpful assistant.")
         s += sgl.user("Construct a character within the following format:")
-        s += sgl.assistant("Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n")
+        s += sgl.assistant(
+            "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
+        )
         s += sgl.user("Please generate new Name, Birthday and Job.\n")
-        s += sgl.assistant("Name:" + sgl.gen("name", stop="\n") + "\nBirthday:" + sgl.gen("birthday", stop="\n") + "\nJob:" + sgl.gen("job", stop="\n"))
+        s += sgl.assistant(
+            "Name:"
+            + sgl.gen("name", stop="\n")
+            + "\nBirthday:"
+            + sgl.gen("birthday", stop="\n")
+            + "\nJob:"
+            + sgl.gen("job", stop="\n")
+        )

     gen_character_spec().sync()
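The two speculative tests reformatted above are registered in the OpenAI backend test module at the end of this diff; the assertion checks that speculative execution consumes fewer prompt tokens than the non-speculative variant. A minimal sketch of how they would be driven (the backend choice and model name are assumptions, not part of this commit):

    # Illustrative driver for the tests above; requires a valid OPENAI_API_KEY.
    from sglang import OpenAI, set_default_backend
    from sglang.test.test_programs import (
        test_chat_completion_speculative,
        test_completion_speculative,
    )

    set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))  # model name is an assumption
    test_completion_speculative()        # asserts speculative mode uses fewer prompt tokens
    test_chat_completion_speculative()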
@@ -15,7 +15,6 @@ from json import dumps
 import numpy as np
 import requests

 logger = logging.getLogger(__name__)
@@ -255,7 +254,9 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
 def graceful_registry(sub_module_name):
     def graceful_shutdown(signum, frame):
-        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        logger.info(
+            f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
+        )
         if signum == signal.SIGTERM:
             logger.info(f"{sub_module_name} recive sigterm")
...
@@ -2,6 +2,8 @@ import unittest
 from sglang import OpenAI, set_default_backend
 from sglang.test.test_programs import (
+    test_chat_completion_speculative,
+    test_completion_speculative,
     test_decode_int,
     test_decode_json,
     test_expert_answer,
@@ -14,8 +16,6 @@ from sglang.test.test_programs import (
     test_select,
     test_stream,
     test_tool_use,
-    test_completion_speculative,
-    test_chat_completion_speculative
 )
...