Unverified commit f13d86f9, authored by Lianmin Zheng, committed by GitHub

Add image_token in conversation.py (#1632)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent aba9eae4
@@ -70,6 +70,9 @@ class Conversation:
     sep2: str = None
     # Stop criteria (the default one is EOS token)
     stop_str: Union[str, List[str]] = None
+    # The string that represents an image token in the prompt
+    image_token: str = "<image>"
+
     image_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
@@ -334,6 +337,7 @@ class Conversation:
             sep=self.sep,
             sep2=self.sep2,
             stop_str=self.stop_str,
+            image_token=self.image_token,
         )
 
     def dict(self):
@@ -381,6 +385,7 @@ def generate_chat_conv(
         stop_str=conv.stop_str,
         image_data=[],
         modalities=[],
+        image_token=conv.image_token,
     )
 
     if isinstance(request.messages, str):
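Taken together, the hunks above make the image placeholder a per-template property: the dataclass declares a "<image>" default, and both Conversation.copy() and generate_chat_conv carry it onto the per-request conversation. The following is only a minimal sketch of that behavior, not code from this commit; the template name and the custom placeholder string are made up, the import path is assumed, and the remaining dataclass fields are assumed to keep their defaults.

# Sketch only: a vision-model template can override the default "<image>"
# placeholder once, and the override survives Conversation.copy().
from sglang.srt.conversation import Conversation  # import path assumed

template = Conversation(
    name="my-vlm",                        # hypothetical template name
    messages=[],                          # start with an empty history
    image_token="<|image_placeholder|>",  # assumed custom placeholder
)

per_request_conv = template.copy()
print(per_request_conv.image_token)  # "<|image_placeholder|>", not "<image>"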
@@ -412,9 +417,13 @@ def generate_chat_conv(
                 num_image_url += 1
                 conv.modalities.append(content.modalities)
         if num_image_url > 1:
-            image_token = "<image>"
+            image_token = conv.image_token
         else:
-            image_token = "<image>\n"
+            image_token = (
+                conv.image_token + "\n"
+                if conv.name != "qwen2-vl"
+                else conv.image_token
+            )
         for content in message.content:
             if content.type == "text":
                 if num_image_url > 16:
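The last hunk in conversation.py is where the new field takes effect: instead of hard-coding "<image>", the expansion of image_url content reads conv.image_token, appending a trailing newline only in the single-image case and only for templates other than "qwen2-vl". Below is a standalone sketch of that branch; the helper function name is made up for illustration and is not part of this commit.

# Illustrative only: mirrors the token-selection branch added above.
# `expand_image_token` is a hypothetical helper, not part of the codebase.
def expand_image_token(conv_name: str, image_token: str, num_image_url: int) -> str:
    if num_image_url > 1:
        # Many images: use the template's raw placeholder for each one.
        return image_token
    # Single image: append a newline, except for qwen2-vl, whose template
    # expects the placeholder without a trailing newline.
    return image_token + "\n" if conv_name != "qwen2-vl" else image_token

print(repr(expand_image_token("vicuna_v1.1", "<image>", 1)))  # '<image>\n'
print(repr(expand_image_token("qwen2-vl", "<image>", 1)))     # '<image>'
print(repr(expand_image_token("vicuna_v1.1", "<image>", 3)))  # '<image>'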
@@ -117,7 +117,9 @@ def create_streaming_error_response(
 def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
     global chat_template_name
 
-    logger.info(f"Use chat template: {chat_template_arg}")
+    logger.info(
+        f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
+    )
     if not chat_template_exists(chat_template_arg):
         if not os.path.exists(chat_template_arg):
             raise RuntimeError(