Unverified Commit e8ddc08e authored by YHPeter's avatar YHPeter Committed by GitHub
Browse files

[BUG FIX] upgrade fschat version to 0.2.23 (#650)


Co-authored-by: default avatarhao.yu <hao.yu@cn-c017.server.mila.quebec>
parent 1b0bd0fe
...@@ -7,6 +7,7 @@ from http import HTTPStatus ...@@ -7,6 +7,7 @@ from http import HTTPStatus
import json import json
import time import time
from typing import AsyncGenerator, Dict, List, Optional from typing import AsyncGenerator, Dict, List, Optional
from packaging import version
import fastapi import fastapi
from fastapi import BackgroundTasks, Request from fastapi import BackgroundTasks, Request
...@@ -31,6 +32,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer ...@@ -31,6 +32,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
try: try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True _fastchat_available = True
...@@ -72,10 +74,16 @@ async def get_gen_prompt(request) -> str: ...@@ -72,10 +74,16 @@ async def get_gen_prompt(request) -> str:
"fastchat is not installed. Please install fastchat to use " "fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`" "the chat completion and conversation APIs: `$ pip install fschat`"
) )
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model) conv = get_conversation_template(request.model)
conv = Conversation( conv = Conversation(
name=conv.name, name=conv.name,
system=conv.system, system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles, roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification messages=list(conv.messages), # prevent in-place modification
offset=conv.offset, offset=conv.offset,
...@@ -92,7 +100,7 @@ async def get_gen_prompt(request) -> str: ...@@ -92,7 +100,7 @@ async def get_gen_prompt(request) -> str:
for message in request.messages: for message in request.messages:
msg_role = message["role"] msg_role = message["role"]
if msg_role == "system": if msg_role == "system":
conv.system = message["content"] conv.system_message = message["content"]
elif msg_role == "user": elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"]) conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant": elif msg_role == "assistant":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment