"src/vscode:/vscode.git/clone" did not exist on "83f8a5ff70d9305735b23131ed2015d3db0e7422"
Unverified Commit 13662fd5 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix RuntimeEndpoint (#279)

parent d5ae2eba
...@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs): ...@@ -43,18 +43,21 @@ def Runtime(*args, **kwargs):
def set_default_backend(backend: BaseBackend): def set_default_backend(backend: BaseBackend):
global_config.default_backend = backend global_config.default_backend = backend
def flush_cache(backend: BaseBackend = None): def flush_cache(backend: BaseBackend = None):
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
if backend is None: if backend is None:
return False return False
return backend.flush_cache() return backend.flush_cache()
def get_server_args(backend: BaseBackend = None): def get_server_args(backend: BaseBackend = None):
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
if backend is None: if backend is None:
return None return None
return backend.get_server_args() return backend.get_server_args()
def gen( def gen(
name: Optional[str] = None, name: Optional[str] = None,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
......
...@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request ...@@ -12,7 +12,13 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend): class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None): def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
api_key: Optional[str] = None,
verify: Optional[str] = None,
):
super().__init__() super().__init__()
self.support_concate_and_append = True self.support_concate_and_append = True
...@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -61,7 +67,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -71,7 +77,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -159,7 +165,7 @@ class RuntimeEndpoint(BaseBackend):
json=data, json=data,
stream=True, stream=True,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
pos = 0 pos = 0
......
...@@ -20,8 +20,6 @@ import requests ...@@ -20,8 +20,6 @@ import requests
import uvicorn import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
...@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process ...@@ -56,11 +54,14 @@ from sglang.srt.managers.router.manager import start_router_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import handle_port_init from sglang.srt.utils import handle_port_init
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key" API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware): class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str): def __init__(self, app, api_key: str):
super().__init__(app) super().__init__(app)
...@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware): ...@@ -77,6 +78,7 @@ class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
response = await call_next(request) response = await call_next(request)
return response return response
app = FastAPI() app = FastAPI()
tokenizer_manager = None tokenizer_manager = None
chat_template_name = None chat_template_name = None
......
...@@ -88,7 +88,9 @@ class HttpResponse: ...@@ -88,7 +88,9 @@ class HttpResponse:
return self.resp.status return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None): def http_request(
url, json=None, stream=False, auth_token=None, api_key=None, verify=None
):
"""A faster version of requests.post with low-level urllib API.""" """A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"} headers = {"Content-Type": "application/json; charset=utf-8"}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment