Unverified Commit f403cde6 authored by Atream's avatar Atream Committed by GitHub
Browse files

Merge pull request #650 from ceerRep/main

feat: basic api key support
parents 1d5d5fae f639fbc1
...@@ -25,6 +25,9 @@ async def chat_completion(request:Request,create:ChatCompletionCreate): ...@@ -25,6 +25,9 @@ async def chat_completion(request:Request,create:ChatCompletionCreate):
input_message = [json.loads(m.model_dump_json()) for m in create.messages] input_message = [json.loads(m.model_dump_json()) for m in create.messages]
if Config().api_key != '':
assert request.headers.get('Authorization', '').split()[-1] == Config().api_key
if create.stream: if create.stream:
async def inner(): async def inner():
chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time())) chunk = ChatCompletionChunk(id=id,object='chat.completion.chunk',created=int(time()))
......
...@@ -10,6 +10,7 @@ class ArgumentParser: ...@@ -10,6 +10,7 @@ class ArgumentParser:
parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers") parser = argparse.ArgumentParser(prog="kvcache.ai", description="Ktransformers")
parser.add_argument("--host", type=str, default=self.cfg.server_ip) parser.add_argument("--host", type=str, default=self.cfg.server_ip)
parser.add_argument("--port", type=int, default=self.cfg.server_port) parser.add_argument("--port", type=int, default=self.cfg.server_port)
parser.add_argument("--api_key", type=str, default=self.cfg.api_key)
parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_keyfile", type=str)
parser.add_argument("--ssl_certfile", type=str) parser.add_argument("--ssl_certfile", type=str)
parser.add_argument("--web", type=bool, default=self.cfg.mount_web) parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
......
...@@ -69,6 +69,7 @@ class Config(metaclass=Singleton): ...@@ -69,6 +69,7 @@ class Config(metaclass=Singleton):
self.server: dict = cfg.get("server", {}) self.server: dict = cfg.get("server", {})
self.server_ip = self.server.get("ip", "0.0.0.0") self.server_ip = self.server.get("ip", "0.0.0.0")
self.server_port = self.server.get("port", 9016) self.server_port = self.server.get("port", 9016)
self.api_key = self.server.get("api_key", "")
# db configs # db configs
self.db_configs: dict = cfg.get("db", {}) self.db_configs: dict = cfg.get("db", {})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment