Commit 80e58670 authored by chenych

Update VLLM

parent ce4251e7
@@ -39,24 +39,19 @@ docker run -it \
-u root \
-v /opt/hyhal/:/opt/hyhal/:ro \
-v /path/your_code_data/:/path/your_code_data/ \
-image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2-ds3.2 bash
+image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.2-das1.7-py3.10-20251203 bash
```
More images are available for download from [光源](https://sourcefind.cn/#/service-list).
The specialized deep learning libraries this project requires for DCU GPUs can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
-**How to download and install vllm (only for the deepseek-v3.2 model)**
+**How to download and install vllm**
```bash
-wget http://112.11.119.99:18000/temp/vllm-0.9.2%2Bdas.opt1.rc2.51af08a.dtk25041-cp310-cp310-linux_x86_64.whl
+wget http://112.11.119.99:18000/customized/vllm/dtk25.04.2/0.9.2%2Bdas.opt1.dtk25042/0.9.2%2Bdas.opt1.dtk25042-9f9886d8/vllm-0.9.2%2Bdas.opt1.dtk25042.20251202.g9f9886d8-cp310-cp310-manylinux_2_28_x86_64.whl
# Uninstall the vllm that ships in the environment
pip uninstall vllm
# Install the new vllm
-pip install vllm-0.9.2+das.opt1.rc2.51af08a.dtk25041-cp310-cp310-linux_x86_64.whl
-# Check where vllm is installed in the environment
-pip show vllm
-# Replace part of the vllm code
-cp vllm-codes/* /path/of/env/vllm/entrypoints/
+pip install vllm-0.9.2+das.opt1.dtk25042.20251202.g9f9886d8-cp310-cp310-manylinux_2_28_x86_64.whl
```
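After installation, a quick sanity check (illustrative, not part of the original docs) confirms the new wheel is active:
```python
import vllm

# Expect a version string containing "das.opt1.dtk25042"
print(vllm.__version__)
```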
## Dataset
@@ -187,7 +182,8 @@ DCU matches GPU precision; inference framework: vllm.
## Pretrained Weights
| Model Name | Weight Size | DCU Model | Minimum Cards Required | Download |
|:-----:|:----------:|:----------:|:---------------------:|:----------:|
-| DeepSeek-V3.2 | 685B | BW1000 | 32 | [Download](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) |
+| DeepSeek-V3.2 | 685B | BW1000 | 32 | [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) |
+| DeepSeek-V3.2-Speciale | 685B | BW1000 | 32 | [Hugging Face](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Speciale) |
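The weights can also be fetched programmatically with `huggingface_hub` (an illustrative sketch; the `local_dir` path is a placeholder):
```python
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="deepseek-ai/DeepSeek-V3.2",
    local_dir="/path/your_code_data/DeepSeek-V3.2",  # placeholder path
)
```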
## Source Repository and Issue Feedback
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_cls
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
get_chat_template_fallback_path)
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import deprecate_kwargs, random_uuid
logger = init_logger(__name__)
class AudioURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class ChatCompletionContentPartAudioParam(TypedDict, total=False):
audio_url: Required[AudioURL]
type: Required[Literal["audio_url"]]
"""The type of the content part."""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
image_embeds: Required[Union[str, dict[str, str]]]
"""
The image embeddings. It can be either:
- A single base64 string.
- A dictionary where each value is a base64 string.
"""
type: Required[Literal["image_embeds"]]
"""The type of the content part."""
class VideoURL(TypedDict, total=False):
url: Required[str]
"""
Either a URL of the video or a data URL with base64 encoded video data.
"""
class ChatCompletionContentPartVideoParam(TypedDict, total=False):
video_url: Required[VideoURL]
type: Required[Literal["video_url"]]
"""The type of the content part."""
class PILImage(BaseModel):
"""
A PIL.Image.Image object.
"""
image_pil: Image.Image
model_config = ConfigDict(arbitrary_types_allowed=True)
class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a PIL image.
Example:
{
"image_pil": ImageAsset('cherry_blossom').pil_image
}
"""
image_pil: Required[PILImage]
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
    This is supported by the OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"video_url": "https://example.com/video.mp4"
}
"""
video_url: Required[str]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPILImageParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, list[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam]
# TODO: Make fields ReadOnly once mypy supports it
class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Union[Optional[str], list[dict[str, str]]]
"""The contents of the message"""
tool_call_id: Optional[str]
"""Tool call that this message is responding to."""
name: Optional[str]
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
"""The tool calls generated by the model, such as function calls."""
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
# Used internally
_ChatTemplateContentFormat = Literal["string", "openai"]
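# Illustration (not part of the original file): the same user turn in the two
# content formats distinguished above.
#   "string" format:
#     {"role": "user", "content": "What is in this image? <image>"}
#   "openai" format:
#     {"role": "user", "content": [
#         {"type": "text", "text": "What is in this image?"},
#         {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}]}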
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: Optional[str] = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return (node.node is not None
and _is_var_or_elems_access(node.node, varname, key))
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if (isinstance(node, jinja2.nodes.Getitem)
and isinstance(node.arg, jinja2.nodes.Slice)):
return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable
return (
_is_attr_access(node, varname, key) if key
else _is_var_access(node, varname)
) # yapf: enable
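# Illustrative Jinja expressions (assumed examples) that the helper above
# recognizes as accesses of `messages`:
#   messages                                  -> plain variable access
#   messages['content'] / messages.content    -> attribute access (with key)
#   messages[:2], messages | selectattr(...)  -> unwrapped and re-checked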
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname
for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def resolve_mistral_chat_template(
chat_template: Optional[str],
**kwargs: Any,
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
return None
@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
model_config: ModelConfig,
trust_remote_code: Optional[bool] = None,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=model_config.trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
hasattr(processor, 'chat_template') and \
processor.chat_template is not None:
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.", tokenizer.name_or_path)
chat_template = load_chat_template(path)
else:
logger.debug("There is no chat template fallback for %s",
tokenizer.name_or_path)
return chat_template
def _resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
tokenizer: AnyTokenizer,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
else:
hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))
detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
return detected_format
@lru_cache
def _log_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
model_config: ModelConfig,
trust_remote_code: Optional[bool] = None,
) -> _ChatTemplateContentFormat:
if given_format != "auto":
return given_format
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
tokenizer,
model_config=model_config,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
super().__init__()
self._model_config = model_config
self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[_T]](list)
@property
def model_config(self) -> ModelConfig:
return self._model_config
@cached_property
def model_cls(self):
return get_model_cls(self.model_config)
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
        Add a multi-modal item to the current prompt and return the
        placeholder string to use, if any.
"""
mm_registry = self.mm_registry
model_config = self.model_config
model_cls = cast(SupportsMultiModal, self.model_cls)
input_modality = modality.replace("_embeds", "")
if mm_registry.has_processor(model_config):
mm_processor = mm_registry.create_processor(model_config)
allowed_counts = mm_processor.info.get_allowed_mm_limits()
allowed_count = allowed_counts.get(input_modality, 0)
else:
mm_config = model_config.multimodal_config
if mm_config is None:
msg = "This model does not support multi-modal inputs"
raise ValueError(msg)
allowed_count = mm_config.get_limit_per_prompt(input_modality)
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request. You can set `--limit-mm-per-prompt` to "
"increase this limit if the model supports it.")
self._items_by_modality[modality].append(item)
return model_cls.get_placeholder_str(modality, current_count)
@abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
# multimodal placeholder_string : count
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
if placeholder:
self._placeholder_counts[placeholder] += 1
def mm_placeholder_counts(self) -> dict[str, int]:
return dict(self._placeholder_counts)
@abstractmethod
def parse_image(self, image_url: str) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_pil(self, image_pil: Image.Image) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
@abstractmethod
def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError
@abstractmethod
def parse_video(self, video_url: str) -> None:
raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
placeholder = self._tracker.add("image_embeds", embeds)
if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding)
self._add_placeholder(placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
placeholder = self._tracker.add("image", image_pil)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path
)
def parse_image(self, image_url: str) -> None:
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
future.set_result(embeds)
if isinstance(image_embeds, str):
            embedding = self._connector.fetch_image_embedding(image_embeds)
future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder(placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil)
placeholder = self._tracker.add("image", future)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None:
audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
return
elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError(
"the supplied chat template path doesn't exist")
elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n"
if not any(c in chat_template
for c in JINJA_CHARS) and not Path(chat_template).exists():
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!")
else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type")
def _load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
if chat_template is None:
return None
if is_literal:
if isinstance(chat_template, Path):
raise TypeError("chat_template is expected to be read directly "
"from its value")
return chat_template
try:
with open(chat_template) as f:
return f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e
        # If opening the file fails, fall back to treating chat_template as a
        # literal template string so escape sequences are interpreted correctly
return _load_chat_template(chat_template, is_literal=True)
_cached_load_chat_template = lru_cache(_load_chat_template)
def load_chat_template(
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
return _cached_load_chat_template(chat_template, is_literal=is_literal)
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio":
lambda part: _InputAudioParser(part).get("input_audio", None),
"refusal":
lambda part: _RefusalParser(part).get("refusal", None),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}
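# Example (illustrative):
#   MM_PARSER_MAP["image_url"]({"type": "image_url",
#                               "image_url": {"url": "https://a.io/x.png"}})
# returns "https://a.io/x.png".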
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
Returns:
A tuple (part_type, content) where:
- part_type: Type of the part (e.g., 'text', 'image_url').
- content: Parsed content (e.g., text, image URL).
Raises:
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
# We only support 'auto', which is the default
if part_type == "image_url" and part.get("detail", "auto") != "auto":
logger.warning("'image_url.detail' is currently not supported "
"and will be ignored.")
return part_type, content
# Handle missing 'type' but provided direct URL fields.
# 'type' is required field by pydantic
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
part)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
part)
return "video_url", video_params.get("video_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
if not isinstance(part_type, str):
raise ValueError("Invalid 'type' field in multimodal part.")
return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"image_embeds", "image_pil",
"audio_url", "input_audio", "video_url")
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
*,
wrap_dicts: bool,
) -> list[ConversationMessage]:
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
for part in parts:
parse_res = _parse_chat_message_content_part(
part,
mm_parser,
wrap_dicts=wrap_dicts,
)
if parse_res:
content.append(parse_res)
if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
texts = cast(list[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
    wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
{"type": "image"}, respectively. Otherwise multimodal data will be
handled by mm_parser, and texts will be returned as strings to be joined
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
return part
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
# content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning(
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part, part_type)
return None
if part_type in ("text", "refusal"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
else:
return str_content
if part_type == "image_pil":
image_content = cast(Image.Image, content)
mm_parser.parse_image_pil(image_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_url":
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None
if part_type == "video_url":
str_content = cast(str, content)
mm_parser.parse_video(str_content)
return {'type': 'video'} if wrap_dicts else None
raise NotImplementedError(f"Unknown part type: {part_type}")
# No need to validate using Pydantic again
_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam)
_ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
) -> list[ConversationMessage]:
role = message["role"]
content = message.get("content")
if content is None:
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
mm_tracker,
wrap_dicts=(content_format == "openai"),
)
for result_msg in result:
if role == 'assistant':
parsed_msg = _AssistantParser(message)
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if ("tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
result_msg["tool_call_id"] = parsed_msg["tool_call_id"]
if "name" in message and isinstance(message["name"], str):
result_msg["name"] = message["name"]
return result
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
content_format,
)
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
content_format,
)
conversation.extend(sub_messages)
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data()
@deprecate_kwargs(
"trust_remote_code",
additional_message="Please use `model_config.trust_remote_code` instead.",
)
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
model_config: ModelConfig,
tokenize: bool = False, # Different from HF's default
    # Deprecated, explicitly capture here so it doesn't slip into kwargs.
trust_remote_code: Optional[bool] = None,
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
    if hf_chat_template is None:
        # Fallback for DeepSeek-V3.2: render the prompt with the custom DSML
        # encoder and return token IDs directly (note that callers must accept
        # a token list here even though the declared return type is str).
        from .encoding_dsv32 import encode_messages
        encode_config = dict(thinking_mode="thinking",
                             drop_thinking=True,
                             add_default_bos_token=True)
        prompt = encode_messages(conversation, **encode_config)
        return tokenizer.encode(prompt)
# raise ValueError(
# "As of transformers v4.44, default chat template is no longer "
# "allowed, so you must provide a chat template if the tokenizer "
# "does not define one.")
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template")
raise ValueError(str(e)) from e
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
**kwargs: Any,
) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None,
# and we won't use it.
resolve_mistral_chat_template(
chat_template=chat_template,
**kwargs,
)
try:
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat "
"template")
raise ValueError(str(e)) from e
def random_tool_call_id() -> str:
return f"chatcmpl-tool-{random_uuid()}"
from typing import Any, Dict, List, Union, Optional, Tuple
import copy
import json
import re
TOOLS_SYSTEM_TEMPLATE = """## Tools
You have access to a set of tools you can use to answer the user's question.
You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user:
<{dsml_token}function_calls>
<{dsml_token}invoke name="$FUNCTION_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$FUNCTION_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}function_calls>
String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).
If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:
<{dsml_token}function_calls>
...
</{dsml_token}function_calls>
<function_results>
...
</function_results>
{thinking_start_token}...thinking about results{thinking_end_token}
Here are the functions available in JSONSchema format:
<functions>
{tool_schemas}
</functions>
"""
bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
dsml_token: str = "|DSML|"
system_msg_template: str = "{content}"
user_msg_template: str = "<|User|>{content}<|Assistant|>"
assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>"
thinking_template = "{reasoning_content}"
response_format_template: str = (
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
tool_call_template: str = (
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
tool_calls_template = (
"<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>"
)
tool_output_template: str = (
"\n<result>{content}</result>"
)
def to_json(value: Any) -> str:
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        # Fall back to ASCII-escaped serialization if the above fails.
        return json.dumps(value, ensure_ascii=True)
def tools_from_openai_format(tools):
return [tool["function"] for tool in tools]
def tool_calls_from_openai_format(tool_calls):
return [
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
for tool_call in tool_calls
]
def tool_calls_to_openai_format(tool_calls):
return [
{
"type": "function",
"function": {
"name": tool_call["name"],
"arguments": tool_call["arguments"],
}
}
for tool_call in tool_calls
]
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>"""
P_dsml_strs = []
arguments = json.loads(tool_call["arguments"])
for k, v in arguments.items():
p_dsml_str = p_dsml_template.format(
dsml_token=dsml_token,
key=k,
is_str="true" if isinstance(v, str) else "false",
value=v if isinstance(v, str) else to_json(v),
)
P_dsml_strs.append(p_dsml_str)
return "\n".join(P_dsml_strs)
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
def _decode_value(key: str, value: str, string: str):
if string == "true":
value = to_json(value)
return f"{to_json(key)}: {value}"
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
return dict(name=tool_name, arguments=tool_args_json)
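# Example (illustrative) round trip for the helpers above:
#   decode_dsml_to_arguments("get_weather",
#                            {"city": ("Beijing", "true"), "days": ("2", "false")})
#   -> {"name": "get_weather", "arguments": '{"city": "Beijing", "days": 2}'}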
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
tools_json = [to_json(t) for t in tools]
return TOOLS_SYSTEM_TEMPLATE.format(
tool_schemas="\n".join(tools_json),
dsml_token=dsml_token,
thinking_start_token=thinking_start_token,
thinking_end_token=thinking_end_token,
)
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
last_user_index = -1
for idx in range(len(messages)-1, -1, -1):
if messages[idx].get("role") in ["user", "developer"]:
last_user_index = idx
break
return last_user_index
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str) -> str:
assert 0 <= index < len(messages)
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
prompt = ""
msg = messages[index]
last_user_idx = find_last_user_index(messages)
role = msg.get("role")
content = msg.get("content")
tools = msg.get("tools")
response_format = msg.get("response_format")
tool_calls = msg.get("tool_calls")
reasoning_content = msg.get("reasoning_content")
if tools:
tools = tools_from_openai_format(tools)
if tool_calls:
tool_calls = tool_calls_from_openai_format(tool_calls)
if role == "system":
prompt += system_msg_template.format(content=content or "")
if tools:
prompt += "\n\n" + render_tools(tools)
if response_format:
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
elif role == "developer":
assert content, f"Invalid message for role `{role}`: {msg}"
content_developer = ""
if tools:
content_developer += "\n\n" + render_tools(tools)
if response_format:
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
content_developer += "\n\n# The user's message is: {}".format(content)
prompt += user_msg_template.format(content=content_developer)
if index == last_user_idx and thinking_mode == "thinking":
prompt += thinking_start_token
else:
prompt += thinking_end_token
elif role == "user":
prompt += user_msg_template.format(content=content)
if index == last_user_idx and thinking_mode == "thinking":
prompt += thinking_start_token
else:
prompt += thinking_end_token
elif role == "tool":
prev_assistant_idx = index - 1
assistant_msg = messages[prev_assistant_idx]
while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool":
prev_assistant_idx -= 1
assistant_msg = messages[prev_assistant_idx]
        assert index == 0 or (prev_assistant_idx >= 0 and
                              assistant_msg.get("role") == "assistant"), \
            f"Invalid messages at {index}:\n{assistant_msg}"
tool_call_order = index - prev_assistant_idx
assistant_tool_calls = assistant_msg.get("tool_calls")
assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, "No tool calls but found tool output"
if tool_call_order == 1:
prompt += "\n\n<function_results>"
prompt += tool_output_template.format(content=content)
if tool_call_order == len(assistant_tool_calls):
prompt += "\n</function_results>"
if index >= last_user_idx and thinking_mode == "thinking":
prompt += "\n\n" + thinking_start_token
else:
prompt += "\n\n" + thinking_end_token
elif role == "assistant":
prev_assistant_idx = index
thinking_part = ""
tool_calls_content = ""
if tool_calls:
tool_calls = [
tool_call_template.format(
dsml_token=dsml_token,
name=tool_call.get("name"),
arguments=encode_arguments_to_dsml(tool_call)
)
for tool_call in tool_calls
]
tool_calls_content += "\n\n" + tool_calls_template.format(
dsml_token=dsml_token,
tool_calls="\n".join(tool_calls)
)
summary_content = content or ""
if thinking_mode == "thinking" and index > last_user_idx:
assert reasoning_content or tool_calls, f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
thinking_part = thinking_template.format(reasoning_content=reasoning_content or "") + thinking_end_token
prompt += assistant_msg_template.format(
reasoning=thinking_part,
content=summary_content,
tool_calls=tool_calls_content,
)
else:
raise NotImplementedError(f"Unknown role: {role}")
return prompt
def drop_thinking_messages(messages: List[Dict[str, Any]], last_user_idx: Optional[int] = None) -> List[Dict[str, Any]]:
messages_wo_thinking: List[Dict[str, Any]] = []
last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx
for idx, msg in enumerate(messages):
role = msg.get("role")
if role in ["user", "system", "tool"] or idx >= last_user_idx:
messages_wo_thinking.append(msg)
continue
elif role == "assistant":
msg_wo_thinking = copy.copy(msg)
msg_wo_thinking.pop("reasoning_content", None)
messages_wo_thinking.append(msg_wo_thinking)
return messages_wo_thinking
def encode_messages(messages: List[Dict[str, Any]], thinking_mode: str, context: Optional[List[Dict[str, Any]]] = None, drop_thinking: bool = True, add_default_bos_token: bool = True) -> str:
context = context if context else []
full_messages = context + messages
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
if thinking_mode == "thinking" and drop_thinking:
full_messages = drop_thinking_messages(full_messages)
for idx in range(len(messages)):
prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode)
return prompt
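# Minimal usage sketch (illustrative):
#   encode_messages([{"role": "user", "content": "Hi"}], thinking_mode="thinking")
# returns:
#   '<|begin▁of▁sentence|><|User|>Hi<|Assistant|><think>'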
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
min_pos = len(text)
matched_stop = None
for s in stop:
pos = text.find(s, index)
if pos != -1 and pos < min_pos:
min_pos = pos
matched_stop = s
if matched_stop:
content = text[index:min_pos]
return min_pos + len(matched_stop), content, matched_stop
else:
content = text[index:]
return len(text), content, None
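# Example (illustrative): _read_until_stop(0, "abc</x>def", ["</x>"]) returns
# (7, "abc", "</x>"); when no stop string is found it returns
# (len(text), remaining_text, None).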
def parse_tool_calls(index: int, text: str):
tool_calls: List[Dict[str, Any]] = []
stop_token = None
tool_calls_end_token = f"</{dsml_token}function_calls>"
while index < len(text):
        index, gap, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
        assert gap == ">\n", "Tool call format error"
if stop_token == tool_calls_end_token:
break
assert stop_token is not None, "Missing special token"
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
assert len(p_tool_name) == 1, "Tool name format error"
tool_name = p_tool_name[0]
tool_args: Dict[str, Tuple[str, str]] = {}
while stop_token == f"<{dsml_token}parameter":
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
assert len(param_kv) == 1, "Parameter format error"
param_name, string, param_value = param_kv[0]
assert param_name not in tool_args, "Duplicate parameter name"
tool_args[param_name] = (param_value, string)
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
assert content == ">\n", "Parameter format error"
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
tool_calls.append(tool_call)
return index, stop_token, tool_calls
# NOTE: This function only parses correctly formatted strings; it does not
# attempt to repair malformed output that the model may generate.
def parse_message_from_completion_text(text: str, thinking_mode: str):
summary_content, reasoning_content, tool_calls = "", "", []
index, stop_token = 0, None
tool_calls_start_token = f"\n\n<{dsml_token}function_calls"
is_thinking, is_tool_calling = thinking_mode == "thinking", False
if is_thinking:
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
reasoning_content = content_delta
assert stop_token == thinking_end_token, "Invalid thinking format"
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
summary_content = content_delta
if stop_token == tool_calls_start_token:
is_tool_calling = True
else:
assert stop_token == eos_token, "Invalid summary format"
if is_tool_calling:
index, stop_token, tool_calls = parse_tool_calls(index, text)
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
assert not tool_ends_text, "Unexpected content after tool calls"
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
assert sp_token not in summary_content and sp_token not in reasoning_content, "Unexpected special token in content"
return {
"role": "assistant",
"content": summary_content,
"reasoning_content": reasoning_content,
"tool_calls": tool_calls_to_openai_format(tool_calls)
}
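# Usage sketch (illustrative):
#   parse_message_from_completion_text(
#       "reasoning here</think>The answer is 4.<|end▁of▁sentence|>",
#       thinking_mode="thinking")
#   -> {"role": "assistant", "content": "The answer is 4.",
#       "reasoning_content": "reasoning here", "tool_calls": []}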