Unverified Commit 8deb3b52 authored by He Tianyao's avatar He Tianyao Committed by GitHub
Browse files

correct server chat template

MinerU2.0's chat template during training has additional "\n" compared with chatml's. This difference may slightly affect the performance under server mode. (because sglang_server use sglang's chat template).
parent 390ddd8b
......@@ -6,10 +6,26 @@ from sglang.srt.entrypoints.http_server import app, generate_request, launch_ser
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
from sglang.srt.conversation import Conversation
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from .logit_processor import Mineru2LogitProcessor
# mineru2.0的chat_template与chatml在换行上有微小区别
def custom_get_prompt(self) -> str:
system_prompt = self.system_template.format(system_message=self.system_message)
if self.system_message == "":
ret = ""
else:
ret = system_prompt + self.sep
for role, message in self.messages:
if message:
ret += role + "\n" + message + self.sep
else:
ret += role + "\n"
return ret
_custom_logit_processor_str = Mineru2LogitProcessor().to_str()
# remote the existing /generate route
......@@ -45,6 +61,7 @@ def main():
if server_args.chat_template is None:
server_args.chat_template = "chatml"
Conversation.get_prompt = custom_get_prompt
server_args.enable_custom_logit_processor = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment