Unverified commit 13662fd5, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix RuntimeEndpoint (#279)

parent d5ae2eba
......@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend):
    """Record ``backend`` as the default backend.

    The value is stored on ``global_config.default_backend``, replacing
    whatever was stored there before.

    Args:
        backend: The backend object to store as the default.
    """
    global_config.default_backend = backend
def flush_cache(backend: BaseBackend = None):
    """Flush the cache of ``backend``, or of the default backend if omitted.

    Args:
        backend: Backend whose cache should be flushed. A falsy value
            (including the default ``None``) selects
            ``global_config.default_backend`` instead.

    Returns:
        ``False`` when neither an explicit nor a default backend is
        available; otherwise whatever the selected backend's
        ``flush_cache()`` returns.
    """
    target = backend if backend else global_config.default_backend
    if target is not None:
        return target.flush_cache()
    # No explicit backend and no default configured: nothing to flush.
    return False
def get_server_args(backend: BaseBackend = None):
    """Return the server arguments reported by ``backend``.

    Args:
        backend: Backend to query. A falsy value (including the default
            ``None``) selects ``global_config.default_backend`` instead.

    Returns:
        ``None`` when neither an explicit nor a default backend is
        available; otherwise whatever the selected backend's
        ``get_server_args()`` returns.
    """
    target = backend if backend else global_config.default_backend
    if target is not None:
        return target.get_server_args()
    # No explicit backend and no default configured: nothing to report.
    return None
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
......
......@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
api_key: Optional[str] = None,
verify: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True
......@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend):
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key
api_key=self.api_key,
verify=self.verify,
)
pos = 0
......
......@@ -20,8 +20,6 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
......@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import handle_port_init
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
......@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request)
return response
app = FastAPI()
tokenizer_manager = None
chat_template_name = None
......
......@@ -88,7 +88,9 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
def http_request(
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment