Unverified commit 0d4f3a9f authored by Ying Sheng, committed by GitHub

Make API Key OpenAI-compatible (#917)

parent afd411d0
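
With this change, the server drops the custom X-API-Key header and validates a standard OpenAI-style bearer token on every endpoint. A minimal sketch of how a client is now expected to authenticate, assuming a server already running on localhost:30000 launched with --api-key sk-123456 (the host, port, and key mirror the test setup further down; they are illustrative, not fixed defaults):

import requests

# Illustrative values: the base URL and key match the ones used in the tests below.
base_url = "http://localhost:30000"
api_key = "sk-123456"

# The new middleware checks for an OpenAI-style "Authorization: Bearer <key>" header.
headers = {"Authorization": f"Bearer {api_key}"}

# Requests without a valid header now get 401 {"error": "Unauthorized"}.
resp = requests.get(f"{base_url}/v1/models", headers=headers)
print(resp.status_code, resp.json())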
@@ -15,7 +15,6 @@ class RuntimeEndpoint(BaseBackend):
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +22,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True
         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify

         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -67,7 +64,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +75,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +86,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +133,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +186,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -225,7 +217,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -243,7 +234,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -267,7 +257,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
...
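The auth_token argument is removed from RuntimeEndpoint, so api_key is the single credential threaded through every http_request call above. A short sketch of the updated frontend usage, assuming sglang exposes RuntimeEndpoint and set_default_backend at the package level and using the same illustrative URL and key as above:

import sglang as sgl

# api_key replaces the removed auth_token argument; URL and key are illustrative.
backend = sgl.RuntimeEndpoint("http://localhost:30000", api_key="sk-123456")
sgl.set_default_backend(backend)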
@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    API_KEY_HEADER_NAME,
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)

+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)

-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
     """Launch an HTTP server."""
     global tokenizer_manager
@@ -258,16 +210,8 @@ def launch_server(
         format="%(message)s",
     )

-    if not server_args.disable_flashinfer:
-        assert_pkg_version(
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)

     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +228,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

-    # Handle multi-node tensor parallelism
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +293,9 @@ def launch_server(
         sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

-    if server_args.api_key and server_args.api_key != "":
-        app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)

     # Send a warmup request
     t = threading.Thread(
@@ -372,15 +317,58 @@ def launch_server(
     t.join()

+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[API_KEY_HEADER_NAME] = server_args.api_key
+        headers["Authorization"] = f"Bearer {server_args.api_key}"

     # Wait until the server is launched
     for _ in range(120):
-        time.sleep(0.5)
+        time.sleep(1)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break
...
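Because the protected surface is now the OpenAI-compatible API itself, a stock OpenAI client can talk to a server launched with --api-key without any custom headers. A small sketch under the same illustrative host, port, and key, using the new /v1/models route to list the served model:

import openai

# Illustrative values, matching the test setup at the end of this diff.
client = openai.Client(api_key="sk-123456", base_url="http://localhost:30000/v1")

# The new /v1/models endpoint reports the single served model.
for model in client.models.list().data:
    print(model.id)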
@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False

     # Other
-    api_key: str = ""
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"

     # Data parallelism
@@ -307,7 +307,7 @@ class ServerArgs:
            "--api-key",
            type=str,
            default=ServerArgs.api_key,
-           help="Set API key of the server.",
+           help="Set API key of the server. It is also used in the OpenAI API compatible server.",
        )
        parser.add_argument(
            "--file-storage-pth",
...
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")

-API_KEY_HEADER_NAME = "X-API-Key"
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()

+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
@@ -391,7 +391,11 @@ def get_call_select(args: argparse.Namespace):
 def popen_launch_server(
-    model: str, base_url: str, timeout: float, other_args: tuple = ()
+    model: str,
+    base_url: str,
+    timeout: float,
+    api_key: Optional[str] = None,
+    other_args: tuple = (),
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -408,12 +412,19 @@ def popen_launch_server(
         port,
         *other_args,
     ]
+    if api_key:
+        command += ["--api-key", api_key]
+
     process = subprocess.Popen(command, stdout=None, stderr=None)

     start_time = time.time()
     while time.time() - start_time < timeout:
         try:
-            response = requests.get(f"{base_url}/v1/models")
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {api_key}",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
             if response.status_code == 200:
                 return process
         except requests.RequestException:
...
@@ -76,19 +76,13 @@ class HttpResponse:
         return self.resp.status

-def http_request(
-    url, json=None, stream=False, auth_token=None, api_key=None, verify=None
-):
+def http_request(url, json=None, stream=False, api_key=None, verify=None):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}

-    # add the Authorization header if an auth token is provided
-    if auth_token is not None:
-        headers["Authorization"] = f"Bearer {auth_token}"
-
-    # add the API Key header if an API key is provided
+    # add the Authorization header if an api key is provided
     if api_key is not None:
-        headers["X-API-Key"] = api_key
+        headers["Authorization"] = f"Bearer {api_key}"

     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
...
@@ -13,7 +13,10 @@ class TestOpenAIServer(unittest.TestCase):
     def setUpClass(cls):
         cls.model = MODEL_NAME_FOR_TEST
         cls.base_url = f"http://localhost:30000"
-        cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
+        )
         cls.base_url += "/v1"

     @classmethod
@@ -21,7 +24,7 @@ class TestOpenAIServer(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_completion(self, echo, logprobs, use_list_input):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"

         if use_list_input:
@@ -63,7 +66,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_completion_stream(self, echo, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         generator = client.completions.create(
             model=self.model,
@@ -102,7 +105,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -135,7 +138,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.total_tokens > 0

     def run_chat_completion_stream(self, logprobs):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
             messages=[
@@ -186,7 +189,7 @@ class TestOpenAIServer(unittest.TestCase):
         self.run_chat_completion_stream(logprobs)

     def test_regex(self):
-        client = openai.Client(api_key="EMPTY", base_url=self.base_url)
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         regex = (
             r"""\{\n"""
...