Commit ce4251e7 authored by chenych

First commit

MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# DeepSeek-V3.2
## Paper
[DeepSeek_V3.2](./doc/paper.pdf)
## Model Overview
DeepSeek-V3.2 is a model that balances high computational efficiency with strong reasoning and agentic performance. Our approach rests on three key technical breakthroughs:
1. **DeepSeek Sparse Attention (DSA):** We introduce DSA, an efficient attention mechanism that significantly reduces computational complexity while preserving model performance, optimized in particular for long-context scenarios.
2. **Scalable reinforcement-learning framework:** By implementing a robust RL protocol and scaling up post-training compute, DeepSeek-V3.2 performs on par with GPT-5. Notably, our high-compute variant, DeepSeek-V3.2-Speciale, surpasses GPT-5 and matches Gemini-3.0-Pro in reasoning ability.
+ *Achievements:* 🥇 Gold-medal performance at the 2025 International Mathematical Olympiad (IMO) and the International Olympiad in Informatics (IOI).
3. **Large-scale agentic task synthesis pipeline:** To bring reasoning into tool-use scenarios, we developed a novel synthesis pipeline that systematically generates large-scale training data. This enables scalable agentic post-training and improves compliance and generalization in complex interactive environments.
<div align=center>
<img src="./doc/benchmark.png"/>
</div>
## Requirements
| Software | Version |
| :------: | :------: |
| DTK | 25.04.1 |
| python | 3.10.12 |
| transformers | 4.56.1 |
| torch | 2.5.1+das.opt1.dtk25041 |
Recommended Docker image:
- Adjust the `-v` mount paths below to match where your model and code actually live.
```bash
docker run -it \
--shm-size 60g \
--network=host \
--name deepseek-v32 \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--device=/dev/mkfd \
--group-add video \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
-u root \
-v /opt/hyhal/:/opt/hyhal/:ro \
-v /path/your_code_data/:/path/your_code_data/ \
image.sourcefind.cn:5000/dcu/admin/base/vllm:0.9.2-ubuntu22.04-dtk25.04.1-rc5-rocblas104381-0915-das1.6-py3.10-20250916-rc2-ds3.2 bash
```
More images are available from [光源](https://sourcefind.cn/#/service-list).
The specialized deep-learning libraries required by the DCU accelerators used in this project can be downloaded from the [光合](https://developer.sourcefind.cn/tool/) developer community.
**How to download and install vLLM (applies only to the DeepSeek-V3.2 model):**
```bash
wget http://112.11.119.99:18000/temp/vllm-0.9.2%2Bdas.opt1.rc2.51af08a.dtk25041-cp310-cp310-linux_x86_64.whl
# Uninstall the vllm that ships with the image
pip uninstall vllm
# Install the new vllm wheel
pip install vllm-0.9.2+das.opt1.rc2.51af08a.dtk25041-cp310-cp310-linux_x86_64.whl
# Find where vllm is installed in the environment
pip show vllm
# Overwrite part of the vllm code with the patched entrypoints
cp vllm-codes/* /path/of/env/vllm/entrypoints/
```
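As a quick optional check, you can confirm that the DAS build of vLLM is now the active installation (the exact version string depends on the wheel above):
```bash
# Confirm which vllm build is active in the environment
python -c "import vllm; print(vllm.__version__)"
pip show vllm | grep -E "^(Name|Version|Location)"
```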
## Dataset
## Training
Not available yet.
## Inference
### vllm
#### Multi-node inference
1. Convert the model weights to BF16 with the following command:
```bash
# Convert FP8 to BF16
python inference/fp8_cast_bf16.py --input-fp8-hf-path /path/to/DeepSeek-V3.2 --output-bf16-hf-path /path/to/DeepSeek-V3.2-bf16
```
After the conversion finishes, copy `generation_config.json`, `tokenizer_config.json`, and `tokenizer.json` from the original model directory into `/path/to/DeepSeek-V3.2-bf16`.
Then copy the provided config file:
```bash
cp inference/config.json /path/to/DeepSeek-V3.2-bf16
```
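Optionally, you can spot-check one converted shard to confirm the weights were written as BF16. A minimal sketch, using the example output path from above (adjust it to your setup):
```bash
python - <<'EOF'
# Spot-check the first converted shard: dequantized weights should report torch.bfloat16
import glob
from safetensors import safe_open

shard = sorted(glob.glob("/path/to/DeepSeek-V3.2-bf16/*.safetensors"))[0]
with safe_open(shard, framework="pt", device="cpu") as f:
    for name in list(f.keys())[:3]:
        print(name, f.get_tensor(name).dtype)
EOF
```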
2. Set the environment variables
> Notes:
> On every node, write the environment variables into a `.sh` file; after saving, `source` that file on each compute node.
>
> VLLM_HOST_IP: the IP of the node's local communication interface. Prefer the IP of an IB NIC to **avoid RCCL timeout issues**.
>
> NCCL_SOCKET_IFNAME and GLOO_SOCKET_IFNAME: the interface name corresponding to that local communication IP.
>
> To look up interfaces and IPs: ifconfig
>
> To check IB port status: ibstat. The port must be in the Active state to be usable, and all nodes must be consistent.
<div align=center>
<img src="./doc/ip_bw.png"/>
</div>
```bash
export ALLREDUCE_STREAM_WITH_COMPUTE=1
export VLLM_HOST_IP=x.x.x.x # IP of this compute node; prefer the IP bound to the IB interface named in *_SOCKET_IFNAME
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export HSA_FORCE_FINE_GRAIN_PCIE=1
export NCCL_SOCKET_IFNAME=ibxxxx
export GLOO_SOCKET_IFNAME=ibxxxx
export NCCL_IB_HCA=mlx5_0:1
unset NCCL_ALGO
export NCCL_IB_DISABLE=0
export NCCL_MAX_NCHANNELS=16
export NCCL_MIN_NCHANNELS=16
export NCCL_NET_GDR_READ=1
export NCCL_DEBUG=INFO
export NCCL_MIN_P2P_NCHANNELS=16
export NCCL_NCHANNELS_PER_PEER=16
export HIP_USE_GRAPH_QUEUE_POOL=1
export VLLM_ENABLE_MOE_FUSED_GATE=0
export VLLM_ENFORCE_EAGER_BS_THRESHOLD=44
export VLLM_RPC_TIMEOUT=1800000
export VLLM_USE_FLASH_MLA=1
# NUMA core binding for Hygon CPUs; can be omitted on Intel CPUs
export VLLM_NUMA_BIND=1
export VLLM_RANK0_NUMA=0
export VLLM_RANK1_NUMA=1
export VLLM_RANK2_NUMA=2
export VLLM_RANK3_NUMA=3
export VLLM_RANK4_NUMA=4
export VLLM_RANK5_NUMA=5
export VLLM_RANK6_NUMA=6
export VLLM_RANK7_NUMA=7
# Additional environment variables required on BW clusters
export NCCL_NET_GDR_LEVEL=7
export NCCL_SDMA_COPY_ENABLE=0
```
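A minimal helper for the notes above, assuming standard Linux and InfiniBand tooling (interface names such as `ibxxxx` are placeholders):
```bash
# Check that the IB port is Active (state must be consistent across all nodes)
ibstat | grep -E "State|Rate"
# List interfaces and their IPs; pick the IB interface for VLLM_HOST_IP and *_SOCKET_IFNAME
ip -br addr show
# On every node: save the exports above as env.sh, then load them
source env.sh
```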
3. Start the Ray cluster
> x.x.x.x is the VLLM_HOST_IP of the head (master) node set in step 2
```bash
# Run on the head node
ray start --head --node-ip-address=x.x.x.x --port=6379 --num-gpus=8 --num-cpus=32
# Run on each worker node
ray start --address='x.x.x.x:6379' --num-gpus=8 --num-cpus=32
```
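Before moving on, you can confirm from the head node that every worker has joined and that all 32 GPUs are visible to the cluster:
```bash
# All nodes should be listed and 32 GPUs reported in total
ray status
```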
4. Start the vLLM server
> On Intel CPUs, add the `--enforce-eager` flag.
```bash
vllm serve /path/to/DeepSeek-V3.2-bf16 \
--trust-remote-code \
--distributed-executor-backend ray \
--dtype bfloat16 \
--tensor-parallel-size 32 \
--max-model-len 32768 \
--port 8001
```
Once the server is up, it can be queried as follows:
```bash
curl http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/path/to/DeepSeek-V3.2-bf16",
"messages": [
{
"role": "user",
"content": "请介绍下你自己。"
}
],
"temperature": 1.0,
"chat_template_kwargs": {
"thinking": true
}
}'
```
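Two optional follow-ups, assuming vLLM's standard OpenAI-compatible routes: list the served model as a liveness check, and send a non-thinking request by setting `"thinking": false` (whether this switch takes effect depends on the patched entrypoints installed earlier):
```bash
# Liveness check: list the served model
curl http://localhost:8001/v1/models

# Non-thinking mode request (assumes the chat template's thinking switch is honored)
curl http://localhost:8001/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/path/to/DeepSeek-V3.2-bf16",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 1.0,
    "chat_template_kwargs": {"thinking": false}
  }'
```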
## Results
<div align=center>
<img src="./doc/result-dcu.png"/>
</div>
### Accuracy
Inference accuracy on DCU matches that on GPU; inference framework: vLLM.
## Pretrained Weights
| Model | Weight Size | DCU Model | Minimum Cards | Download |
|:-----:|:----------:|:----------:|:---------------------:|:----------:|
| DeepSeek-V3.2 | 685B | BW1000 | 32 | [Download](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) |
## Source Repository and Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/deepseek-v3.2_vllm
## References
- https://huggingface.co/deepseek-ai/DeepSeek-V3.2
{
"architectures": [
"DeepseekV3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"hidden_size": 7168,
"index_head_dim": 128,
"index_n_heads": 64,
"index_topk": 2048,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": true,
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": false,
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.44.2",
"use_cache": true,
"v_head_dim": 128,
"vocab_size": 129280
}
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
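# Walk every *.safetensors shard, dequantize FP8 weights back to BF16 using their paired
# *_scale_inv tensors, and write new shards plus an index without the scale entries.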
def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
from typing import Tuple
import torch
import triton
import triton.language as tl
from triton import Config
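# act_quant: per-block activation quantization to FP8 (e4m3). Each BLOCK_SIZE-element block
# along the last dimension is divided by its absmax / 448 (the e4m3 maximum), and the
# per-block scales are returned alongside the quantized tensor.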
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous()
assert x.size(-1) % block_size == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
# @triton.jit
# def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
# pid_m = tl.program_id(axis=0)
# pid_n = tl.program_id(axis=1)
# n = tl.cdiv(N, BLOCK_SIZE)
# offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# offs = offs_m[:, None] * N + offs_n[None, :]
# mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
# x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
# s = tl.load(s_ptr + pid_m * n + pid_n)
# y = x * s
# tl.store(y_ptr + offs, y, mask=mask)
# def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
# assert x.is_contiguous() and s.is_contiguous()
# assert x.dim() == 2 and s.dim() == 2
# M, N = x.size()
# y = torch.empty_like(x, dtype=torch.get_default_dtype())
# grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
# weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
# return y
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
# Compute the expected shape of s
s_M = (M + block_size - 1) // block_size  # ceil division
s_N = (N + block_size - 1) // block_size  # ceil division
# Check that s has the expected shape
assert s.size(0) == s_M and s.size(1) == s_N, \
f"s should have shape ({s_M}, {s_N}), but got {s.size()}"
# Expand s to the same shape as x
s_expanded = s.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
# Trim s_expanded to match the shape of x
s_expanded = s_expanded[:M, :N]
# Element-wise multiplication
y = x.to(torch.float32) * s_expanded
y = y.to(torch.bfloat16)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
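# Blockwise-scaled FP8 GEMM: activations carry one scale per 1x128 block (from act_quant)
# and weights one scale per 128x128 block, so each K-step of the accumulation is rescaled
# by a_s and b_s before being added to the FP32 accumulator.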
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c
# Unique model identifier
modelCode=1847
# Model name
modelName=deepseek-v3.2_vllm
# Model description
modelDescription=DeepSeek-V3.2 is the first model from DeepSeek to bring thinking into tool use; it supports tool calls in both thinking and non-thinking modes.
# Run type
processType=推理
# Application category
appCategory=对话问答
# Framework type
frameType=vllm
# Accelerator type
accelerateType=BW1000
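# Chat-template helpers for DeepSeek-V3.2: render OpenAI-style messages (system, developer,
# user, assistant, tool) into the model's DSML prompt format, and parse completion text back
# into an assistant message with content, reasoning_content, and tool_calls.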
from typing import Any, Dict, List, Union, Optional, Tuple
import copy
import json
import re
TOOLS_SYSTEM_TEMPLATE = """## Tools
You have access to a set of tools you can use to answer the user's question.
You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user:
<{dsml_token}function_calls>
<{dsml_token}invoke name="$FUNCTION_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$FUNCTION_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}function_calls>
String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects).
If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example:
<{dsml_token}function_calls>
...
</{dsml_token}function_calls>
<function_results>
...
</function_results>
{thinking_start_token}...thinking about results{thinking_end_token}
Here are the functions available in JSONSchema format:
<functions>
{tool_schemas}
</functions>
"""
bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
dsml_token: str = "|DSML|"
system_msg_template: str = "{content}"
user_msg_template: str = "<|User|>{content}<|Assistant|>"
assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>"
thinking_template = "{reasoning_content}"
response_format_template: str = (
"## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
tool_call_template: str = (
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
tool_calls_template = (
"<{dsml_token}function_calls>\n{tool_calls}\n</{dsml_token}function_calls>"
)
tool_output_template: str = (
"\n<result>{content}</result>"
)
def to_json(value: Any) -> str:
try:
return json.dumps(value, ensure_ascii=False)
except:
return json.dumps(value, ensure_ascii=True)
def tools_from_openai_format(tools):
return [tool["function"] for tool in tools]
def tool_calls_from_openai_format(tool_calls):
return [
{
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
for tool_call in tool_calls
]
def tool_calls_to_openai_format(tool_calls):
return [
{
"type": "function",
"function": {
"name": tool_call["name"],
"arguments": tool_call["arguments"],
}
}
for tool_call in tool_calls
]
def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str:
p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>"""
P_dsml_strs = []
arguments = json.loads(tool_call["arguments"])
for k, v in arguments.items():
p_dsml_str = p_dsml_template.format(
dsml_token=dsml_token,
key=k,
is_str="true" if isinstance(v, str) else "false",
value=v if isinstance(v, str) else to_json(v),
)
P_dsml_strs.append(p_dsml_str)
return "\n".join(P_dsml_strs)
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
def _decode_value(key: str, value: str, string: str):
if string == "true":
value = to_json(value)
return f"{to_json(key)}: {value}"
tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
return dict(name=tool_name, arguments=tool_args_json)
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
tools_json = [to_json(t) for t in tools]
return TOOLS_SYSTEM_TEMPLATE.format(
tool_schemas="\n".join(tools_json),
dsml_token=dsml_token,
thinking_start_token=thinking_start_token,
thinking_end_token=thinking_end_token,
)
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
last_user_index = -1
for idx in range(len(messages)-1, -1, -1):
if messages[idx].get("role") in ["user", "developer"]:
last_user_index = idx
break
return last_user_index
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str) -> str:
assert 0 <= index < len(messages)
assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
prompt = ""
msg = messages[index]
last_user_idx = find_last_user_index(messages)
role = msg.get("role")
content = msg.get("content")
tools = msg.get("tools")
response_format = msg.get("response_format")
tool_calls = msg.get("tool_calls")
reasoning_content = msg.get("reasoning_content")
if tools:
tools = tools_from_openai_format(tools)
if tool_calls:
tool_calls = tool_calls_from_openai_format(tool_calls)
if role == "system":
prompt += system_msg_template.format(content=content or "")
if tools:
prompt += "\n\n" + render_tools(tools)
if response_format:
prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
elif role == "developer":
assert content, f"Invalid message for role `{role}`: {msg}"
content_developer = ""
if tools:
content_developer += "\n\n" + render_tools(tools)
if response_format:
content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
content_developer += "\n\n# The user's message is: {}".format(content)
prompt += user_msg_template.format(content=content_developer)
if index == last_user_idx and thinking_mode == "thinking":
prompt += thinking_start_token
else:
prompt += thinking_end_token
elif role == "user":
prompt += user_msg_template.format(content=content)
if index == last_user_idx and thinking_mode == "thinking":
prompt += thinking_start_token
else:
prompt += thinking_end_token
elif role == "tool":
prev_assistant_idx = index - 1
assistant_msg = messages[prev_assistant_idx]
while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool":
prev_assistant_idx -= 1
assistant_msg = messages[prev_assistant_idx]
assert index == 0 or prev_assistant_idx >= 0 and assistant_msg.get("role") == "assistant", f"Invalid messages at {index}:\n{assistant_msg}"
tool_call_order = index - prev_assistant_idx
assistant_tool_calls = assistant_msg.get("tool_calls")
assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, "No tool calls but found tool output"
if tool_call_order == 1:
prompt += "\n\n<function_results>"
prompt += tool_output_template.format(content=content)
if tool_call_order == len(assistant_tool_calls):
prompt += "\n</function_results>"
if index >= last_user_idx and thinking_mode == "thinking":
prompt += "\n\n" + thinking_start_token
else:
prompt += "\n\n" + thinking_end_token
elif role == "assistant":
prev_assistant_idx = index
thinking_part = ""
tool_calls_content = ""
if tool_calls:
tool_calls = [
tool_call_template.format(
dsml_token=dsml_token,
name=tool_call.get("name"),
arguments=encode_arguments_to_dsml(tool_call)
)
for tool_call in tool_calls
]
tool_calls_content += "\n\n" + tool_calls_template.format(
dsml_token=dsml_token,
tool_calls="\n".join(tool_calls)
)
summary_content = content or ""
if thinking_mode == "thinking" and index > last_user_idx:
assert reasoning_content or tool_calls, f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message"
thinking_part = thinking_template.format(reasoning_content=reasoning_content or "") + thinking_end_token
prompt += assistant_msg_template.format(
reasoning=thinking_part,
content=summary_content,
tool_calls=tool_calls_content,
)
else:
raise NotImplementedError(f"Unknown role: {role}")
return prompt
def drop_thinking_messages(messages: List[Dict[str, Any]], last_user_idx: Optional[int]=None) -> List[Dict[str, Any]]:
messages_wo_thinking: List[Dict[str, Any]] = []
last_user_idx = find_last_user_index(messages) if last_user_idx is None else last_user_idx
for idx, msg in enumerate(messages):
role = msg.get("role")
if role in ["user", "system", "tool"] or idx >= last_user_idx:
messages_wo_thinking.append(msg)
continue
elif role == "assistant":
msg_wo_thinking = copy.copy(msg)
msg_wo_thinking.pop("reasoning_content", None)
messages_wo_thinking.append(msg_wo_thinking)
return messages_wo_thinking
def encode_messages(messages: List[Dict[str, Any]], thinking_mode: str, context: Optional[List[Dict[str, Any]]] = None, drop_thinking: bool = True, add_default_bos_token: bool = True) -> str:
context = context if context else []
full_messages = context + messages
prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
if thinking_mode == "thinking" and drop_thinking:
full_messages = drop_thinking_messages(full_messages)
for idx in range(len(messages)):
prompt += render_message(idx + len(context), full_messages, thinking_mode=thinking_mode)
return prompt
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
min_pos = len(text)
matched_stop = None
for s in stop:
pos = text.find(s, index)
if pos != -1 and pos < min_pos:
min_pos = pos
matched_stop = s
if matched_stop:
content = text[index:min_pos]
return min_pos + len(matched_stop), content, matched_stop
else:
content = text[index:]
return len(text), content, None
def parse_tool_calls(index: int, text: str):
tool_calls: List[Dict[str, Any]] = []
stop_token = None
tool_calls_end_token = f"</{dsml_token}function_calls>"
while index < len(text):
index, _, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
assert _ == ">\n", "Tool call format error"
if stop_token == tool_calls_end_token:
break
assert stop_token is not None, "Missing special token"
index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
assert len(p_tool_name) == 1, "Tool name format error"
tool_name = p_tool_name[0]
tool_args: Dict[str, Tuple[str, str]] = {}
while stop_token == f"<{dsml_token}parameter":
index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
assert len(param_kv) == 1, "Parameter format error"
param_name, string, param_value = param_kv[0]
assert param_name not in tool_args, "Duplicate parameter name"
tool_args[param_name] = (param_value, string)
index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"</{dsml_token}invoke"])
assert content == ">\n", "Parameter format error"
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
tool_calls.append(tool_call)
return index, stop_token, tool_calls
# NOTE: This function is designed to parse only correctly formatted string and will not attempt to correct malformed output that may be generated by the model.
def parse_message_from_completion_text(text: str, thinking_mode: str):
summary_content, reasoning_content, tool_calls = "", "", []
index, stop_token = 0, None
tool_calls_start_token = f"\n\n<{dsml_token}function_calls"
is_thinking, is_tool_calling = thinking_mode == "thinking", False
if is_thinking:
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
reasoning_content = content_delta
assert stop_token == thinking_end_token, "Invalid thinking format"
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
summary_content = content_delta
if stop_token == tool_calls_start_token:
is_tool_calling = True
else:
assert stop_token == eos_token, "Invalid summary format"
if is_tool_calling:
index, stop_token, tool_calls = parse_tool_calls(index, text)
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
assert not tool_ends_text, "Unexpected content after tool calls"
assert len(text) == index and stop_token in [eos_token, None], "Unexpected content at end"
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
assert sp_token not in summary_content and sp_token not in reasoning_content, "Unexpected special token in content"
return {
"role": "assistant",
"content": summary_content,
"reasoning_content": reasoning_content,
"tool_calls": tool_calls_to_openai_format(tool_calls)
}