Unverified Commit 3d64cf01 authored by akxxsb's avatar akxxsb Committed by GitHub
Browse files

[Server] use fastchat.model.model_adapter.get_conversation_template method to...

[Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357)
parent 98fe8cb5
...@@ -13,8 +13,9 @@ from fastapi import BackgroundTasks, Request ...@@ -13,8 +13,9 @@ from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastchat.conversation import (Conversation, SeparatorStyle, from fastchat.conversation import Conversation, SeparatorStyle
get_conv_template) from fastchat.model.model_adapter import get_conversation_template
import uvicorn import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -36,7 +37,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds ...@@ -36,7 +37,6 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__) logger = init_logger(__name__)
served_model = None served_model = None
chat_template = None
app = fastapi.FastAPI() app = fastapi.FastAPI()
...@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]: ...@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
async def get_gen_prompt(request) -> str: async def get_gen_prompt(request) -> str:
conv = get_conv_template(chat_template) conv = get_conversation_template(request.model)
conv = Conversation( conv = Conversation(
name=conv.name, name=conv.name,
system=conv.system, system=conv.system,
...@@ -560,14 +560,7 @@ if __name__ == "__main__": ...@@ -560,14 +560,7 @@ if __name__ == "__main__":
help="The model name used in the API. If not " help="The model name used in the API. If not "
"specified, the model name will be the same as " "specified, the model name will be the same as "
"the huggingface name.") "the huggingface name.")
parser.add_argument(
"--chat-template",
type=str,
default=None,
help="The chat template name used in the ChatCompletion endpoint. If "
"not specified, we use the API model name as the template name. See "
"https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
"for the list of available templates.")
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -586,11 +579,6 @@ if __name__ == "__main__": ...@@ -586,11 +579,6 @@ if __name__ == "__main__":
else: else:
served_model = args.model served_model = args.model
if args.chat_template is not None:
chat_template = args.chat_template
else:
chat_template = served_model
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config()) engine_model_config = asyncio.run(engine.get_model_config())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment