"src/vscode:/vscode.git/clone" did not exist on "b2da59b197306a49d93db2a28247de9b0f187435"
Unverified Commit 98fe8cb5 authored by Zhuohan Li, committed by GitHub

[Server] Add option to specify chat template for chat endpoint (#345)

parent ffa6d2f9
@@ -9,3 +9,4 @@ xformers >= 0.0.19
 fastapi
 uvicorn
 pydantic # Required for OpenAI server.
+fschat # Required for OpenAI ChatCompletion Endpoint.
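
The only new dependency is fschat, the PyPI package for FastChat, which supplies the named conversation templates that the server code below looks up. A hypothetical snippet (not part of this commit) showing what that module exposes:

# Illustrative only, not part of this commit: the new fschat requirement
# installs FastChat, whose conversation module holds the named chat templates
# that --chat-template selects among.
from fastchat.conversation import conv_templates, get_conv_template

# Names registered in the template dictionary (valid --chat-template values).
print(sorted(conv_templates.keys()))

# Each entry is a Conversation object carrying roles, separators, and a
# system prompt.
conv = get_conv_template("vicuna_v1.1")
print(conv.roles)  # e.g. ('USER', 'ASSISTANT') for this template
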
@@ -36,6 +36,7 @@ TIMEOUT_KEEP_ALIVE = 5 # seconds
 logger = init_logger(__name__)
 served_model = None
+chat_template = None
 app = fastapi.FastAPI()
@@ -62,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
 async def get_gen_prompt(request) -> str:
-    conv = get_conv_template(request.model)
+    conv = get_conv_template(chat_template)
     conv = Conversation(
         name=conv.name,
         system=conv.system,
@@ -553,13 +554,20 @@ if __name__ == "__main__":
                         type=json.loads,
                         default=["*"],
                         help="allowed headers")
+    parser.add_argument("--served-model-name",
+                        type=str,
+                        default=None,
+                        help="The model name used in the API. If not "
+                        "specified, the model name will be the same as "
+                        "the huggingface name.")
     parser.add_argument(
-        "--served-model-name",
+        "--chat-template",
         type=str,
         default=None,
-        help="The model name used in the API. If not specified, "
-        "the model name will be the same as the "
-        "huggingface name.")
+        help="The chat template name used in the ChatCompletion endpoint. If "
+        "not specified, we use the API model name as the template name. See "
+        "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
+        "for the list of available templates.")
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -573,7 +581,15 @@ if __name__ == "__main__":
     logger.info(f"args: {args}")
 
-    served_model = args.served_model_name or args.model
+    if args.served_model_name is not None:
+        served_model = args.served_model_name
+    else:
+        served_model = args.model
+
+    if args.chat_template is not None:
+        chat_template = args.chat_template
+    else:
+        chat_template = served_model
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
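
Taken together, the lookup in get_gen_prompt now goes through a configurable template name: --chat-template wins when given, otherwise the served model name is used, which preserves the old behavior. A standalone sketch of that resolution plus FastChat's prompt rendering (the build_prompt helper, its signature, and the message format are assumptions for illustration, not vLLM code):

# Standalone sketch of the template resolution this commit introduces; the
# build_prompt helper, its signature, and the message format are illustrative
# assumptions, not vLLM code. System messages are ignored here for brevity.
from typing import Dict, List, Optional

from fastchat.conversation import get_conv_template


def build_prompt(messages: List[Dict[str, str]],
                 served_model: str,
                 chat_template: Optional[str] = None) -> str:
    # --chat-template wins when given; otherwise the served model name doubles
    # as the FastChat template name, matching the pre-commit behavior.
    template_name = chat_template if chat_template is not None else served_model
    conv = get_conv_template(template_name)  # raises KeyError for unknown names
    for message in messages:
        if message["role"] == "user":
            conv.append_message(conv.roles[0], message["content"])
        elif message["role"] == "assistant":
            conv.append_message(conv.roles[1], message["content"])
    conv.append_message(conv.roles[1], None)  # leave the assistant turn open
    return conv.get_prompt()


# Example: serve a fine-tuned checkpoint under an arbitrary name but render
# prompts with a known template.
print(build_prompt([{"role": "user", "content": "Hello!"}],
                   served_model="my-org/vicuna-13b-finetune",
                   chat_template="vicuna_v1.1"))

In practice this means a fine-tuned checkpoint whose name FastChat does not recognize can still be served through the ChatCompletion endpoint by passing something like --chat-template vicuna_v1.1 at startup; deployments that do not pass the flag keep the previous behavior.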