Unverified Commit 4013a4e1 authored by 任嘉, committed by GitHub

Implement served_model_name to customize the model id when using local mode… (#749)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 60340a36
@@ -79,6 +79,7 @@ class TokenizerManager:
         self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
         self.model_path = server_args.model_path
+        self.served_model_name = server_args.served_model_name
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
...
...@@ -190,10 +190,10 @@ async def retrieve_file_content(file_id: str): ...@@ -190,10 +190,10 @@ async def retrieve_file_content(file_id: str):
@app.get("/v1/models") @app.get("/v1/models")
def available_models(): def available_models():
"""Show available models.""" """Show available models."""
model_names = [tokenizer_manager.model_path] served_model_names = [tokenizer_manager.served_model_name]
model_cards = [] model_cards = []
for model_name in model_names: for served_model_name in served_model_names:
model_cards.append(ModelCard(id=model_name, root=model_name)) model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
return ModelList(data=model_cards) return ModelList(data=model_cards)
......
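With this change, the id reported by /v1/models follows served_model_name instead of the local path. A minimal client-side sketch of the effect; the port, model path, and served name below are hypothetical, assuming a locally running server:

import requests  # third-party HTTP client, assumed installed

# Hypothetical setup: a local server launched with
#   --model-path /models/llama-2-7b --served-model-name llama-2-7b
# and listening on port 30000.
resp = requests.get("http://127.0.0.1:30000/v1/models")
for card in resp.json()["data"]:
    # The id now reflects served_model_name rather than the local path.
    print(card["id"])  # -> llama-2-7b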
@@ -32,6 +32,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     # Port
@@ -90,6 +91,10 @@ class ServerArgs:
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
+        if self.served_model_name is None:
+            self.served_model_name = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
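The default mirrors the existing tokenizer_path fallback: when the flag is omitted, the served name degrades to the model path, so /v1/models keeps its old behavior. A self-contained sketch of that rule (the dataclass here is an illustrative stand-in for ServerArgs):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:  # illustrative stand-in, not the real ServerArgs
    model_path: str
    served_model_name: Optional[str] = None

    def __post_init__(self):
        # Same defaulting rule as the hunk above.
        if self.served_model_name is None:
            self.served_model_name = self.model_path

print(Args("/models/llama-2-7b").served_model_name)                # /models/llama-2-7b
print(Args("/models/llama-2-7b", "llama-2-7b").served_model_name)  # llama-2-7b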
@@ -202,6 +207,12 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--served-model-name",
+            type=str,
+            default=ServerArgs.served_model_name,
+            help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
+        )
         parser.add_argument(
             "--chat-template",
             type=str,
...
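For reference, a minimal reproduction of the new flag's argparse wiring; the parser setup here is illustrative, and only the --served-model-name argument itself comes from the diff:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, required=True)
parser.add_argument(
    "--served-model-name",
    type=str,
    default=None,  # ServerArgs.served_model_name in the real code
    help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
)

args = parser.parse_args(["--model-path", "/models/llama-2-7b"])
# __post_init__ then falls back to the model path when the flag is absent.
print(args.served_model_name or args.model_path)  # /models/llama-2-7b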