Unverified Commit 3edba9bc authored by Tanjiro's avatar Tanjiro Committed by GitHub
Browse files

[fix] added image token as prefix for deepseek-ocr (#12358)

parent e5ec9764
...@@ -101,6 +101,7 @@ class Conversation: ...@@ -101,6 +101,7 @@ class Conversation:
stop_token_ids: Optional[int] = None stop_token_ids: Optional[int] = None
audio_data: Optional[List[str]] = None audio_data: Optional[List[str]] = None
image_token_at_prefix: bool = False
def get_prompt(self) -> str: def get_prompt(self) -> str:
"""Get the prompt for generation.""" """Get the prompt for generation."""
...@@ -445,6 +446,7 @@ class Conversation: ...@@ -445,6 +446,7 @@ class Conversation:
image_token=self.image_token, image_token=self.image_token,
video_token=self.video_token, video_token=self.video_token,
audio_token=self.audio_token, audio_token=self.audio_token,
image_token_at_prefix=self.image_token_at_prefix,
) )
def dict(self): def dict(self):
...@@ -512,6 +514,7 @@ def generate_embedding_convs( ...@@ -512,6 +514,7 @@ def generate_embedding_convs(
image_token=conv_template.image_token, image_token=conv_template.image_token,
video_token=conv_template.video_token, video_token=conv_template.video_token,
audio_token=conv_template.audio_token, audio_token=conv_template.audio_token,
image_token_at_prefix=conv_template.image_token_at_prefix,
) )
real_content = "" real_content = ""
...@@ -578,6 +581,7 @@ def generate_chat_conv( ...@@ -578,6 +581,7 @@ def generate_chat_conv(
image_token=conv.image_token, image_token=conv.image_token,
audio_token=conv.audio_token, audio_token=conv.audio_token,
video_token=conv.video_token, video_token=conv.video_token,
image_token_at_prefix=conv.image_token_at_prefix,
) )
if isinstance(request.messages, str): if isinstance(request.messages, str):
...@@ -627,7 +631,7 @@ def generate_chat_conv( ...@@ -627,7 +631,7 @@ def generate_chat_conv(
real_content += content.text real_content += content.text
elif content.type == "image_url": elif content.type == "image_url":
# NOTE: works for llava and intervl2_5 # NOTE: works for llava and intervl2_5
if conv.name in ["internvl-2-5"]: if conv.image_token_at_prefix:
real_content = image_token + real_content real_content = image_token + real_content
else: else:
real_content += image_token real_content += image_token
...@@ -820,6 +824,7 @@ register_conv_template( ...@@ -820,6 +824,7 @@ register_conv_template(
sep="<|im_end|>\n", sep="<|im_end|>\n",
stop_str=["<|im_end|>", "<|action_end|>"], stop_str=["<|im_end|>", "<|action_end|>"],
image_token="<IMG_CONTEXT>", image_token="<IMG_CONTEXT>",
image_token_at_prefix=True,
) )
) )
...@@ -848,6 +853,7 @@ register_conv_template( ...@@ -848,6 +853,7 @@ register_conv_template(
sep_style=SeparatorStyle.NO_COLON_SINGLE, sep_style=SeparatorStyle.NO_COLON_SINGLE,
stop_str=["<|end▁of▁sentence|>"], stop_str=["<|end▁of▁sentence|>"],
image_token="<image>", image_token="<image>",
image_token_at_prefix=True,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment