import contextlib
import io
import json
import os
import re
import sys
import traceback
import fire
from vllm import LLM, SamplingParams
max_turns = 5
system_prompt_template = """You are an AI Agent who is proficient in solving complicated tasks.
At each step you should write executable code to fulfill the user query. Any response without code means the task is completed, and you will not have another chance to submit code.
You are equipped with a code interpreter: you submit code and receive its execution result. You should use the code interpreter in the following format:
<|execute_start|>
```python
<your code>
```
<|execute_end|>
WARNING: Do not use cv2.waitKey(0) or cv2.destroyAllWindows()!!! Or the program will be destroyed.
Each round, your answer should ALWAYS use the following format (each of your responses should contain code, until you complete the task):
Analyse: (Analyse the message you received and plan what you should do)
This Step Todo: one subtask that needs to be done at this step
Code (WARNING: MAKE SURE YOUR CODE FOLLOWS THE FORMAT AND CONTAINS CODE, OR THE TASK WILL FAIL):
<|execute_start|>
```python
<your code>
```
<|execute_end|>
You will get the result of your code after each step. When the code for the previous subtask executes successfully, you can write and execute the code for the next subtask.
When all the code you write has been executed and the results fulfill the user query, you should summarize the previous analysis process and make a formal response to the user. The response should follow this format:
WARNING: MAKE SURE THE EXECUTED CODE RESULTS FULFILL ALL REQUIREMENTS OF THE USER BEFORE USING "Finished"
Finished: <Answer to user query>
Some notices:
1. When you want to draw a plot, use plt.savefig() and print the image path in markdown format instead of plt.show()
2. Save anything to ./output folder
3. End the process whenever you complete the task. When you do not have an Action (Code), use: Finished: <summarize the analysis process and make a response>
4. Do not ask for user input in your python code.
"""
def execute_code(code):
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
    # Note: here we simply imitate notebook output.
    # To run more complex tasks, consider using nbclient to execute the code.
lines = code.strip().split('\n')
last_expr = lines[-1].strip()
if '=' in last_expr:
value = last_expr.split('=')[0].strip()
code += f"\nprint({value})"
with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
try:
# execute code here
exec(code)
except Exception as e:
return {'output': stdout_capture.getvalue(), 'error': str(e)}
return {'output': stdout_capture.getvalue(), 'error': stderr_capture.getvalue()}
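# Example (sanity check, not part of the pipeline): a trailing assignment is
# echoed back through the injected print(), so
#   execute_code("x = 2 ** 10")
# is expected to return {'output': '1024\n', 'error': ''}.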
class DemoLLM:
def __init__(self, model_path):
# Initialize default sampling parameters
params_dict = {
"n": 1,
"best_of": None,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"repetition_penalty": 1.02,
"temperature": 1.0,
"top_p": 0.85,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1.0,
"early_stopping": False,
"stop": None,
"stop_token_ids": None,
"ignore_eos": False,
"max_tokens": 300,
"logprobs": None,
"prompt_logprobs": None,
"skip_special_tokens": True,
}
# Create a SamplingParams object
self.sampling_params = SamplingParams(**params_dict)
# Initialize the language model
self.llm = LLM(
model=model_path,
tensor_parallel_size=1,
trust_remote_code=True,
enforce_eager=True
)
def apply_template(self, messages):
"""Formats messages into a prompt string for the LLM."""
formatted_messages = [
f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
for msg in messages
]
formatted_messages.append("<|im_start|>assistant\n")
return ''.join(formatted_messages)
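    # Example: apply_template([{"role": "user", "content": "hi"}]) returns
    # "<|im_start|>user\nhi<|im_end|>\n<|im_start|>assistant\n",
    # i.e. ChatML-style turns plus a trailing assistant header for generation.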
def generate(self, messages):
"""Generates a response from the LLM based on the input messages."""
raw_input = self.apply_template(messages)
response = self.llm.generate(raw_input, self.sampling_params)
if response:
return response[0].outputs[0].text
return None
def extract_code(text):
""" Extracts Python code blocks from the given text. """
# Define a regular expression pattern to match Python code blocks
pattern = r'```python\s+(.*?)\s+```'
matches = re.findall(pattern, text, re.DOTALL)
return matches
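# Example: extract_code('<|execute_start|>\n```python\nprint(1)\n```\n<|execute_end|>')
# returns ['print(1)']; re.DOTALL lets a code block span multiple lines.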
def process(model_path):
"""
    Processes interactions with the DemoLLM using the provided model path.
Args:
model_path (str): The path to the language model directory.
"""
# Initialize the language model
llm = DemoLLM(model_path)
# Define initial messages
messages = [
{"role": "system", "content": system_prompt_template},
{"role": "user", "content": "2 的 100 次方是多少?"},
]
for index in range(max_turns):
print(f"Turn {index+1} start...")
# Generate response from the LLM
raw_resp = llm.generate(messages)
print(f"Raw response: {raw_resp}")
# Check if the response contains the termination keyword
if "Finished" in raw_resp:
break
# Extract code from the raw response
code_list = extract_code(raw_resp)
if not code_list:
break
# Execute the extracted code
code_str = code_list[-1]
run_result = execute_code(code_str)
executor_response = run_result['output'] if run_result['error'] == "" else run_result['error']
print(f"Code execution result: {run_result}")
        # Append the assistant turn and the execution result to the messages
        messages.append({"role": "assistant", "content": raw_resp})
        messages.append({"role": "user", "content": executor_response})
if __name__ == "__main__":
fire.Fire(process)
# MiniCPM FunctionCall
1. Start the vLLM function-call server
```shell
python -m vllm.entrypoints.openai.api_server \
--model openbmb/MiniCPM3-4B \
--dtype auto \
--api-key token-abc123 \
--tensor-parallel-size 1 \
--trust-remote-code \
--enable-auto-tool-choice \
--tool-call-parser minicpm \
--tool-parser-plugin minicpm_tool_parser.py
```
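Once the server is running, you can sanity-check the endpoint (a minimal sketch, assuming the default port 8000 and the API key above):
```python
from openai import OpenAI

# List the models served by the vLLM OpenAI-compatible server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
print([model.id for model in client.models.list().data])
```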
2. Function-call client example
```python
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
}
}
]
messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]
completion = client.chat.completions.create(
model="openbmb/MiniCPM3-4B",
messages=messages,
tools=tools,
tool_choice="auto"
)
print(completion)
```
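If the model decides to call a tool, the calls parsed by the `minicpm` parser appear on the returned message; a minimal sketch for inspecting them:
```python
# Print each parsed tool call, or the plain content if no tool was called.
message = completion.choices[0].message
if message.tool_calls:
    for tool_call in message.tool_calls:
        print(tool_call.function.name, tool_call.function.arguments)
else:
    print(message.content)
```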
3. Run function-call inference locally
```shell
python functioncall.py
```
# Thanks
- resolve_ast_call and resolve_ast_by_type from [gorilla](https://github.com/ShishirPatil/gorilla)
- MiniCPM chat template with tools from @CISCai
#!/usr/bin/env python
# encoding: utf-8
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from minicpm_tool_parser import fc2dict
import json
model_path = "openbmb/MiniCPM3-4B"
tools = [
{
"type": "function",
"function": {
"name": "get_delivery_date",
"description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
"parameters": {
"type": "object",
"properties": {
"order_id": {
"type": "string",
"description": "The customer's order ID.",
},
},
"required": ["order_id"],
"additionalProperties": False,
},
},
}
]
messages = [
{
"role": "system",
"content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
},
{
"role": "user",
"content": "Hi, can you tell me the delivery date for my order? The order id is 1234 and 4321.",
},
# {
# "content": "",
# "tool_calls": [
# {
# "type": "function",
# "function": {
# "name": "get_delivery_date",
# "arguments": {"order_id": "1234"},
# },
# "id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
# },
# {
# "type": "function",
# "function": {
# "name": "get_delivery_date",
# "arguments": {"order_id": "4321"},
# },
# "id": "call_628965479dd84794bbb72ab9bdda0c39",
# },
# ],
# "role": "assistant",
# },
# {
# "role": "tool",
# "content": '{"delivery_date": "2024-09-05", "order_id": "1234"}',
# "tool_call_id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
# },
# {
# "role": "tool",
# "content": '{"delivery_date": "2024-09-05", "order_id": "4321"}',
# "tool_call_id": "call_628965479dd84794bbb72ab9bdda0c39",
# },
# {
# "content": "Both your orders will be delivered on 2024-09-05.",
# "role": "assistant",
# "thought": "\nI have the information you need, both orders will be delivered on the same date, 2024-09-05.\n",
# },
]
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tools=tools, tokenize=False, add_generation_prompt=True
)
llm = LLM(model_path, trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1000)
def fake_tool_execute(toolcall):
data = {
"delivery_date": "2024-09-05",
"order_id": toolcall.get("function", {})
.get("arguments", {})
.get("order_id", "order_id"),
}
return json.dumps(data)
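# Example: fake_tool_execute({"function": {"arguments": {"order_id": "1234"}}})
# returns '{"delivery_date": "2024-09-05", "order_id": "1234"}'.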
while True:
prompt = tokenizer.apply_chat_template(
messages, tools=tools, tokenize=False, add_generation_prompt=True
)
outputs = llm.generate([prompt], sampling_params)
response = outputs[0].outputs[0].text
msg = fc2dict(response)
if (
"tool_calls" in msg
and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0
):
messages.append(msg)
print(msg)
for toolcall in msg["tool_calls"]:
tool_response = fake_tool_execute(toolcall)
tool_msg = {
"role": "tool",
"content": tool_response,
"tool_call_id": toolcall["id"],
}
messages.append(tool_msg)
print(tool_msg)
else:
messages.append(msg)
print(msg)
break
{%- macro json_to_python_type(param_name, json_spec) %}
{%- set basic_type_map = {
'string': 'str',
'number': 'float',
'integer': 'int',
'boolean': 'bool',
'null': 'None'
} %}
{%- if json_spec.enum %}
{{- param_name|title }}
{%- elif basic_type_map[json_spec.type] is defined %}
{{- basic_type_map[json_spec.type] }}
{%- elif json_spec.type == 'array' %}
{{- 'List[' + json_to_python_type(param_name, json_spec['items']) + ']' }}
{%- elif json_spec.type == 'object' %}
{{- 'Dict[str, ' + json_to_python_type(param_name, json_spec.additionalProperties if json_spec.additionalProperties else 'Any') + ']' if not json_spec.properties else param_name|title }}
{%- elif json_spec.type is iterable %}
{{- 'Union[' }}
{%- for t in json_spec.type %}
{{- json_to_python_type(param_name, {'type': t}) }}
{{- ', ' if not loop.last }}
{%- endfor %}
{{- ']' }}
{%- else %}
{{- 'Any' }}
{%- endif %}
{%- endmacro %}
{%- macro object_to_fields(json_spec, field_indent) %}
{%- set o_ns = namespace(f = caller()) %}
{%- for param_name, param_fields in json_spec.properties|items %}
{%- if param_fields.enum %}
{{- '\n\nclass ' + param_name|title + '(Enum):\n' }}
{%- for enum_option in param_fields.enum %}
{{- ' enum_' + loop.index0|string + ' = ' + enum_option|tojson + '\n' }}
{%- endfor %}
{%- elif param_fields.type == 'object' and param_fields.properties %}
{%- call object_to_fields(param_fields, ' ') %}
{{- '\n\nclass ' + param_name|title + '(BaseModel):\n' }}
{%- endcall %}
{%- elif param_fields.type == 'array' and param_fields['items'] and param_fields['items'].type == 'object' and param_fields['items'].properties %}
{%- call object_to_fields(param_fields['items'], ' ') %}
{{- '\n\nclass ' + param_name|title + '(BaseModel):\n' }}
{%- endcall %}
{%- endif %}
{%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}
{%- set o_ns.f = o_ns.f + field_indent + param_name + ': ' %}
{%- set o_ns.f = o_ns.f + ('Optional[' + json_to_python_type(param_name, param_fields) + ']' if param_name not in json_spec.required else json_to_python_type(param_name, param_fields)) %}
{%- if not param_fields.title and not param_fields.description and not param_fields.pattern %}
{%- set o_ns.f = o_ns.f + (' = ' + param_default if param_name not in json_spec.required else '') %}
{%- else %}
{%- set o_ns.f = o_ns.f + (' = Field(...' if param_name in json_spec.required else ' = Field(' + param_default) %}
{%- set o_ns.f = o_ns.f + (', description=' + param_fields.description|tojson if param_fields.description else '') %}
{%- set o_ns.f = o_ns.f + (', regex=' + param_fields.pattern|tojson if param_fields.pattern else '') %}
{%- set o_ns.f = o_ns.f + (', title=' + param_fields.title|tojson if param_fields.title else '') %}
{%- set o_ns.f = o_ns.f + ')' %}
{%- endif %}
{%- set o_ns.f = o_ns.f + '\n' %}
{%- endfor %}
{{- o_ns.f }}
{%- endmacro %}
{%- macro tool_parser(tools) %}
{%- for tool in tools %}
{%- if tool.type is not defined or tool.type == 'function' %}
{%- if tool.function is defined %}
{%- set tool = tool.function %}
{%- endif %}
{%- set tool_params = tool.parameters if tool.parameters is defined else none %}
{%- call object_to_fields(tool_params, ' ') %}
{{- '\n\ndef ' + tool.name + '(' }}
{%- if tool_params %}
{%- for param_name, param_fields in tool_params.properties|items %}
{%- set param_default = param_fields.default|tojson if param_fields.default is string else param_fields.default|string if param_fields.default is defined else 'None' %}
{{- ', ' if loop.index0 != 0 }}
{{- param_name }}
{{- '=' + param_default if param_name not in tool_params.required }}
{%- endfor %}
{%- endif %}
{{- '):\n """' }}
{{- tool.description }}
{{- '\n\n Args:\n' if tool_params else '\n' }}
{%- endcall %}
{{- ' """\n' }}
{%- endif %}
{%- endfor %}
{%- endmacro %}
{%- if messages[0]['role'] == 'system' %}
{%- set loop_messages = messages[1:] %}
{%- set system_message = messages[0]['content'] %}
{%- else %}
{%- set loop_messages = messages %}
{%- set system_message = '' %}
{%- endif %}
{{- '<|im_start|>system\n' + system_message if system_message or tools }}
{%- if tools %}
{{- '\n# Functions\nHere is a list of functions that you can invoke:\n```python\nfrom enum import Enum\nfrom typing import List, Dict, Optional\nfrom pydantic import BaseModel, Field\n\n' }}
{{- tool_parser(tools) }}
{{- "\n```\n\n# Function Call Rule and Output Format\n- If the user's question can be answered without calling any function, please answer the user's question directly. In this situation, you should return your thought and answer the user's question directly.\n- If the user cannot be answered without calling any function, and the user does not provide enough information to call functions, please ask the user for more information. In this situation, you should return your thought and ask the user for more information.\n- If the user's question cannot be answered without calling any function, and the user has provided enough information to call functions to solve it, you should call the functions. In this situation, the assistant should return your thought and call the functions.\n- Use default parameters unless the user has specified otherwise.\n- You should answer in the following format:\n\n<|thought_start|>\n{explain why the user's question can be answered without calling a function or why you should ask the user for more information or why you should call one or more functions and your plan to solve the user's question.}\n<|thought_end|>\n<|tool_call_start|>\n```python\nfunc1(params_name=params_value, params_name2=params_value2...)\nfunc2(params)\n```\n<|tool_call_end|>\n{answer the user's question directly or ask the user for more information}" }}
{%- endif %}
{{- '<|im_end|>\n' if system_message or tools }}
{%- for message in loop_messages %}
{%- set content = message.content %}
{%- if message.role == 'assistant' and message.tool_calls %}
{{- '<|im_start|>' + message.role + '\n' }}
{{- '<|thought_start|>\n' + message.thought + '\n<|thought_end|>\n' if message.thought }}
{{- '<|tool_call_start|>\n```python\n' }}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- tool_call.name + '(' }}
{%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %}
{%- for param_name, param_value in tool_call.arguments|items %}
{{- param_name + '=' + param_value|tojson }}
{{- ',' if not loop.last }}
{%- endfor %}
{%- endif %}
{{- ')\n' }}
{%- endfor %}
{{- '```\n<|tool_call_end|>\n' }}
{{- content if content and not content.startswith('<|tool_call_start|>') }}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'assistant' and message.thought %}
{{- '<|im_start|>' + message.role + '\n' + '<|thought_start|>\n' + message.thought + '\n<|thought_end|>\n' + content + '<|im_end|>\n' }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
#!/usr/bin/env python
# encoding: utf-8
import ast
import json
import keyword
import re
import traceback
from typing import Dict, List, Sequence, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ToolParserManager.register_module("minicpm")
class MiniCPMToolParser(ToolParser):
"""
Tool call parser for MiniCPM3 4B models intended for use with the
examples/tool_chat_template_minicpm3.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser minicpm are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.thought_start_token = "<|thought_start|>"
self.thought_end_token = "<|thought_end|>"
self.tool_call_start_token = "<|tool_call_start|>"
self.tool_call_end_token = "<|tool_call_end|>"
self.stop_token_ids = [2, 73440]
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
"""
Extract the tool calls from a complete model response.
"""
msg = fc2dict(model_output)
if ("tool_calls" in msg and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0):
tool_calls: List[ToolCall] = [
ToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(raw_function_call["arguments"],
ensure_ascii=False),
),
) for raw_function_call in msg["tool_calls"]
]
# get any content before the tool call
ret = ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=msg.get("content", None),
)
return ret
        else:
            # No tool calls were extracted from the model output.
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=msg.get("content", None),
            )
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# if no tools are provided, we don't need to parse tool calls
if not request.tools:
return DeltaMessage(content=delta_text)
if self.thought_end_token not in current_text:
return None
useful_text = current_text.split(self.thought_end_token)[1]
if (current_token_ids[-1]
in self.stop_token_ids): # case 3: stream generation ended
msg = fc2dict(current_text)
if ("tool_calls" in msg and msg["tool_calls"] is not None
and len(msg["tool_calls"]) > 0):
self.prev_tool_call_arr = msg["tool_calls"]
self.streamed_args_for_tool = ["" for tc in msg["tool_calls"]]
delta_message = DeltaMessage(
role="assistant",
content=msg.get("content", None),
)
return delta_message
else:
return DeltaMessage(content=msg.get("content", None))
elif (self.tool_call_start_token in useful_text
and self.tool_call_end_token
in useful_text): # case 2: tool call ended
return None
elif (self.tool_call_start_token
in useful_text): # case 1: tool call started
# Extract function name and arguments, handling nested parentheses
pattern = r"(\w+)\(((?:[^()]*|\([^()]*\))*)\)"
matches = re.finditer(pattern, useful_text)
tool_calls: List[Dict] = []
delta = None
for idx, match in enumerate(matches):
if self.current_tool_id < idx:
self.current_tool_id = idx
func_name = match.group(1)
func_args = match.group(2)
tool_call_string = f"{func_name}({func_args})\n"
parsed = ast.parse(tool_call_string)
for elem in parsed.body:
assert isinstance(elem.value, ast.Call) # type: ignore
calls = resolve_ast_call(elem.value) # type: ignore
for func_name, func_args in calls.items():
                        this_call = {
                            "name": func_name,
                            "arguments": json.dumps(func_args,
                                                    ensure_ascii=False),
                        }
                        # Record the call so prev_tool_call_arr below is populated.
                        tool_calls.append(this_call)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
**this_call).model_dump(exclude_none=True),
)
])
self.prev_tool_call_arr = tool_calls
self.streamed_args_for_tool = ["" for x in tool_calls]
self.current_tool_name_sent = True
return delta
else:
return None
def fc2dict(
sequence: str,
tool_call_start="<|tool_call_start|>",
tool_call_end="<|tool_call_end|>",
thought_start="<|thought_start|>",
thought_end="<|thought_end|>",
):
if thought_end in sequence and thought_start in sequence:
thought_string, sequence = sequence.rsplit(thought_end, 1)
thought_string = thought_string.split(thought_start, 1)[1]
else:
thought_string = ""
if tool_call_start in sequence and tool_call_end in sequence:
tool_call_string, content = sequence.rsplit(tool_call_end, 1)
tool_call_string = tool_call_string.split(tool_call_start, 1)[1]
try:
tool_calls = []
tool_call_string = tool_call_string.strip()
if tool_call_string.startswith("```"):
tool_call_string = tool_call_string[3:].strip()
if tool_call_string.startswith("python"):
tool_call_string = tool_call_string.lstrip(
"python").strip()
if tool_call_string.endswith("```"):
tool_call_string = tool_call_string[:-3].strip()
for kw in keyword.kwlist:
tool_call_string = tool_call_string.replace(
"," + kw + "=", "," + kw + "_=")
tool_call_string = tool_call_string.replace(
" " + kw + "=", " " + kw + "_=")
tool_call_string = tool_call_string.replace(
"(" + kw + "=", "(" + kw + "_=")
parsed: ast.Module = ast.parse(tool_call_string)
for elem in parsed.body:
assert isinstance(elem.value, ast.Call) # type: ignore
calls = resolve_ast_call(elem.value) # type: ignore
for func_name, func_args in calls.items():
new_args = {}
for k, v in func_args.items():
for kw in keyword.kwlist:
if k == kw + "_":
k = kw
new_args[k] = v
this_one = {"name": func_name, "arguments": new_args}
tool_calls.append(this_one)
return {
"content": content.strip(),
"tool_calls": tool_calls,
"role": "assistant",
}
except Exception as e:
logger.error("Error parsing tool call: %s", str(e))
logger.error(traceback.format_exc())
return {
"content": content.strip(),
"role": "assistant",
"thought": thought_string,
}
else:
return {
"content": sequence.strip(),
"role": "assistant",
"thought": thought_string,
}
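# Example: for a model output such as
#   <|thought_start|>\nplan\n<|thought_end|>\n<|tool_call_start|>\n```python\n
#   get_delivery_date(order_id="1234")\n```\n<|tool_call_end|>\nOK
# fc2dict returns
#   {"content": "OK",
#    "tool_calls": [{"name": "get_delivery_date",
#                    "arguments": {"order_id": "1234"}}],
#    "role": "assistant"}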
# from ShishirPatil/gorilla
def resolve_ast_call(elem):
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
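# Example: resolve_ast_call on the AST of foo.bar(x=1, y="a")
# returns {"foo.bar": {"x": 1, "y": "a"}}.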
def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
output = "..." if value.value is Ellipsis else value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value # type: ignore
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {
resolve_ast_by_type(k): resolve_ast_by_type(v)
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value,
ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(
value, ast.BinOp
): # Added this condition to handle function calls as arguments
output = ast.literal_eval(ast.unparse(value)) # type: ignore
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value) # type: ignore
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = ast.literal_eval(
ast.unparse( # type: ignore
value.body[0].value)) # type: ignore
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value) # type: ignore
except Exception as e:
logger.error("Error parsing tool call: %s", str(e))
output = (
ast.unparse(value.value) + "[" + # type: ignore
ast.unparse(value.slice) + "]") # type: ignore
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
<div align="center">
<img src="./assets/logo.png" alt="MiniCPM-4-MCP 标志" width="400em"></img>
</div>
<p align="center">
[中文 | <a href="README_en.md">English</a>]
</p>
<p align="center">
<a href="#概述">概述</a>
<a href="#安装">安装</a>
<a href="#模型训练">模型训练</a>
<a href="https://huggingface.co/openbmb/MiniCPM4-MCP">模型下载</a>
<a href="#推理">推理</a>
<a href="#模型评估">模型评估</a>
</p>
## News
* [2025-06-05] 🚀🚀🚀 We have open-sourced MiniCPM4-MCP, built on MiniCPM4-8B, which can call a wide variety of MCP tools and delivers performance comparable to much larger models.
## Overview
**MiniCPM-4-MCP** is an open-source on-device LLM agent model jointly developed by the [THUNLP lab at Tsinghua University](https://nlp.csai.tsinghua.edu.cn), Renmin University of China, and [ModelBest](https://modelbest.cn/en). Built on [MiniCPM-4](https://huggingface.co/openbmb/MiniCPM4-8B) with 8 billion parameters, it can interact with a variety of tools and data resources through MCP to solve a wide range of real-world tasks. To date, MiniCPM4-MCP supports:
- Tool use across 16 MCP servers: these servers span categories such as office, lifestyle, communication, information, and work management.
- Single-tool use: it can perform single- or multi-step calls to a single MCP-compliant tool.
- Cross-tool use: it can combine different MCP-compliant tools.
Demo case (1x speed):
https://github.com/user-attachments/assets/b60b3ed6-f175-4ab7-ace2-4488ef82cb69
## Installation
The required package versions are listed in `./requirements.txt` to ensure compatibility.
```
pip install -r requirements.txt
```
## Training
We primarily adopt a learning-from-demonstration approach: demonstration data is generated through continuous interaction between an LLM and the MCP environment, and MiniCPM learns from these demonstrations via supervised fine-tuning (SFT). We use LLaMa-Factory as the SFT framework; the adapted version of LLaMa-Factory will be released.
### Model Download
Download the model via [this link](https://huggingface.co/openbmb/MiniCPM4-MCP).
### Data Format
```json
{
"conversations": [
{"from": "human", "value": "你好,我想把500美元换成欧元,你能帮我吗?"},
{"from": "gpt", "value": "<|thought_start|>\n我将调用 get_currency_exchange_rate 函数将 500 美元兑换成欧元。\n<|thought_end|>\n<|tool_call_start|>\n```python\nget_currency_exchange_rate(from_currency=\"USD\",to_currency=\"EUR\",amount=500)\n```\n<|tool_call_end|>\n"},
{"from": "tool", "value": "{\"converted_amount\": 425.50, \"exchange_rate\": 0.851}"},
{"from": "gpt", "value": "<|thought_start|>\n助手认为用户想进行货币兑换,这是可以通过 get_currency_exchange_rate 函数完成的,并已获得执行该函数所需的参数。\n<|thought_end|>\n当然,500 美元大约可兑换为 425.50 欧元。当前汇率为 0.851。"}
],
"tools": "[{\"name\": \"get_currency_exchange_rate\", \"description\": \"获取两种货币之间的汇率\", \"parameters\": {\"type\": \"object\", \"properties\": {\"from_currency\": {\"type\": \"string\", \"description\": \"原始货币\"}, \"to_currency\": {\"type\": \"string\", \"description\": \"目标货币\"}, \"amount\": {\"type\": \"number\", \"description\": \"兑换金额\"}}, \"required\": [\"from_currency\", \"to_currency\", \"amount\"]}}, {\"name\": \"generate_random_password\", \"description\": \"生成符合特定要求的随机密码\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"密码长度\"}, \"include_numbers\": {\"type\": \"boolean\", \"description\": \"是否包含数字\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"是否包含符号\"}}, \"required\": [\"length\"]}}]",
"system": "你是一个具有函数调用能力的智能助手。请在需要时使用这些函数。"
}
```
### Single-Node Training
To run training on a single machine:
```bash
llamafactory-cli train /path/to/config.yaml
```
Example:
```bash
llamafactory-cli train ./LLaMA-Factory/examples/train_full/minicpm4/sft.yaml
```
### Distributed Training
All nodes must have:
- Identical software environments
- Network connectivity (e.g., TCP port 29500 open)
- Access to the same training data
Steps:
1. Get the master node IP (node 0):
```bash
hostname -I | awk '{print $1}'
```
2. Launch training:
On the master node (node 0):
```bash
export MASTER_ADDR=$(hostname -I | awk '{print $1}')
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=$MASTER_ADDR MASTER_PORT=29500 \
llamafactory-cli train /path/to/config.yaml
```
On the worker node (node 1):
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=$MASTER_ADDR MASTER_PORT=29500 \
llamafactory-cli train /path/to/config.yaml
```
## Inference
### MCP Servers Deployment
The MCP servers supported by MiniCPM4-MCP include
[Airbnb](https://github.com/openbnb-org/mcp-server-airbnb),
[Amap-Maps](https://github.com/zxypro1/amap-maps-mcp-server),
[Arxiv-MCP-Server](https://github.com/blazickjp/arxiv-mcp-server),
[Calculator](https://github.com/githejie/mcp-server-calculator),
[Computer-Control-MCP](https://github.com/AB498/computer-control-mcp),
[Desktop-commander](https://github.com/wonderwhy-er/DesktopCommanderMCP),
[Filesystem](https://github.com/mark3labs/mcp-filesystem-server),
[Github](https://github.com/modelcontextprotocol/servers-archived/tree/main/src/github),
[Gaode](https://github.com/perMAIN/gaode),
[MCP-Code-Executor](https://github.com/bazinga012/mcp_code_executor),
[MCP-DOCx](https://github.com/MeterLong/MCP-Doc),
[PPT](https://github.com/GongRzhe/Office-PowerPoint-MCP-Server),
[PPTx](https://github.com/supercurses/powerpoint),
[Simple-Time-Server](https://github.com/andybrandt/mcp-simple-timeserver),
[Slack](https://github.com/modelcontextprotocol/servers-archived/tree/main/src/slack),
[Whisper](https://github.com/arcaputo3/mcp-server-whisper)
Follow the instructions in each server's repository to deploy it. Note that not all tools contained in these servers run reliably in every environment; some are unstable and may return errors such as timeouts or HTTP errors. During training-data construction, tools with excessively high failure rates (e.g., tools for which the LLM failed to construct a single successfully executed query even after hundreds of attempts) were filtered out.
### MCP Client Setup
We modified the MCP Client already implemented in the [mcp-cli](https://github.com/chrishayuk/mcp-cli) repository to enable interaction between MiniCPM and MCP servers. A sample of the tool list obtained after the MCP Client handshakes with a server is given in `available_tool_example.json`. Once the available tools and the user query are obtained, results can be generated with the specified model following the logic in the script below:
```bash
python generate_example.py \
--tokenizer_path {path to the MiniCPM4 tokenizer} \
--base_url {vLLM deployment URL} \
--model {model name used in the vLLM deployment} \
--output_path {path to save results}
```
MiniCPM4 generates tool calls in the following format:
````
<|tool_call_start|>
```python
read_file(path="/path/to/file")
```
<|tool_call_end|>
````
You can implement a custom parser for MiniCPM4 tool calls based on this format; the relevant parsing logic is located in `generate_example.py`.
Since the [mcp-cli](https://github.com/chrishayuk/mcp-cli) repository supports the vLLM inference framework, MiniCPM4-MCP can also be adapted to mcp-cli directly by modifying vLLM. Specifically, follow the instructions in [this link](https://github.com/OpenBMB/MiniCPM/tree/main/demo/minicpm3/function_call) to modify vLLM so that a client running the MiniCPM4-MCP model can communicate with MCP servers.
## Evaluation
Once generation is complete, run the following script for evaluation:
```bash
python eval_scripts.py \
--input_path {path where the generated results are saved}
```
This script evaluates the model's function-name prediction in single-turn tool calls. In multi-turn scenarios, given the ground-truth information from previous steps, the accuracy of the tool call generated at the current step can be evaluated; the per-step evaluation logic is the same as in the single-turn setting.
### Evaluation Results
| MCP Server | | gpt-4o | | | qwen3 | | | minicpm4 | |
| -------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
| | func | param | value | func | param | value | func | param | value |
| Airbnb | 89.3 | 67.9 | 53.6 | 92.8 | 60.7 | 50.0 | 96.4 | 67.9 | 50.0 |
| Amap-Maps | 79.8 | 77.5 | 50.0 | 74.4 | 72.0 | 41.0 | 89.3 | 85.7 | 39.9 |
| Arxiv-MCP-Server | 85.7 | 85.7 | 85.7 | 81.8 | 54.5 | 50.0 | 57.1 | 57.1 | 52.4 |
| Calculator | 100.0 | 100.0 | 20.0 | 80.0 | 80.0 | 13.3 | 100.0 | 100.0 | 6.67 |
| Computer-Control-MCP | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 86.7 |
| Desktop-Commander | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 |
| Filesystem | 63.5 | 63.5 | 31.3 | 69.7 | 69.7 | 26.0 | 83.3 | 83.3 | 42.7 |
| Github | 92.0 | 80.0 | 58.0 | 80.5 | 50.0 | 27.7 | 62.8 | 25.7 | 17.1 |
| Gaode | 71.1 | 55.6 | 17.8 | 68.8 | 46.6 | 24.4 | 68.9 | 46.7 | 15.6 |
| MCP-Code-Executor | 85.0 | 80.0 | 70.0 | 80.0 | 80.0 | 70.0 | 90.0 | 90.0 | 65.0 |
| MCP-Docx | 95.8 | 86.7 | 67.1 | 94.9 | 81.6 | 60.1 | 95.1 | 86.6 | 76.1 |
| PPT | 72.6 | 49.8 | 40.9 | 85.9 | 50.7 | 37.5 | 91.2 | 72.1 | 56.7 |
| PPTx | 64.2 | 53.7 | 13.4 | 91.0 | 68.6 | 20.9 | 91.0 | 58.2 | 26.9 |
| Simple-Time-Server | 90.0 | 70.0 | 70.0 | 90.0 | 90.0 | 90.0 | 90.0 | 60.0 | 60.0 |
| Slack | 100.0 | 90.0 | 70.0 | 100.0 | 100.0 | 65.0 | 100.0 | 100.0 | 100.0 |
| Whisper | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 30.0 |
| **Average** | **80.2** | **70.2** | **49.1** | **83.5** | **67.7** | **43.8** | **88.3** | **76.1** | **51.2** |
<div align="center">
<img src="./assets/logo.png" alt="MiniCPM-4-MCP Logo" width="400em"></img>
</div>
<p align="center">
[English | <a href="README.md">中文</a>]
</p>
<p align="center">
<a href="#-overview">Overview</a>
<a href="#%EF%B8%8F-training">Training</a>
<a href="https://huggingface.co/openbmb/MiniCPM4-MCP">Model</a>
<a href="#-inference">Inference</a>
<a href="#-evaluation">Evaluation</a>
</p>
## News
* [2025-06-05] 🚀🚀🚀 We have open-sourced MiniCPM4-MCP, built on MiniCPM4-8B, which is capable of calling a variety of MCP tools and delivers performance comparable to larger models.
## 🚩 Overview
**MiniCPM4-MCP** is an open-source on-device LLM agent model jointly developed by [THUNLP](https://nlp.csai.tsinghua.edu.cn), Renmin University of China and [ModelBest](https://modelbest.cn/en), built on [MiniCPM-4](https://huggingface.co/openbmb/MiniCPM4-8B) with 8 billion parameters. It is capable of solving a wide range of real-world tasks by interacting with various tool and data resources through MCP. As of now, MiniCPM4-MCP supports the following:
- Utilization of tools across 16 MCP servers: These servers span various categories, including office, lifestyle, communication, information, and work management.
- Single-tool-calling capability: It can perform single- or multi-step tool calls using a single tool that complies with the MCP.
- Cross-tool-calling capability: It can perform single- or multi-step tool calls using different tools that comply with the MCP.
Demo Case (1 x speed):
https://github.com/user-attachments/assets/e5aa0780-0c18-42f2-a787-8171110a7706
## 🛠️ Installation
The required package versions are listed in `./requirements.txt` to ensure compatibility.
```
pip install -r requirements.txt
```
## 📽️ Training
We primarily adopt a learning-from-demonstration approach to train our model. The demonstrations are generated through continuous interactions between an LLM and the MCP environment. MiniCPM learns from these demonstrations through Supervised Fine-Tuning (SFT). We employ LLaMa-Factory as our SFT framework, with an adapted version for MiniCPM.
### Model Download
Download the model in this [link](https://huggingface.co/openbmb/MiniCPM4-MCP).
### Data Format
```json
{
"conversations": [
{"from": "human", "value": "Hi, I need to convert 500 US dollars to Euros. Can you help me with that?"},
{"from": "gpt", "value": "<|thought_start|>\nI will call the get_currency_exchange_rate function to convert 500 US dollars to Euros.\n<|thought_end|>\n<|tool_call_start|>\n```python\nget_currency_exchange_rate(from_currency=\"USD\",to_currency=\"EUR\",amount=500)\n```\n<|tool_call_end|>\n"},
{"from": "tool", "value": "{\"converted_amount\": 425.50, \"exchange_rate\": 0.851}"},
{"from": "gpt", "value": "<|thought_start|>\nThe assistant thought that the user asked for a currency conversion, which is a task that can be handled by the 'get_currency_exchange_rate' function, and the assistant has received the necessary parameters to execute this function.\n<|thought_end|>\nSure, 500 US dollars will convert to approximately 425.50 Euros. The current exchange rate is 0.851."}
],
"tools": "[{\"name\": \"get_currency_exchange_rate\", \"description\": \"Get the exchange rate between two currencies\", \"parameters\": {\"type\": \"object\", \"properties\": {\"from_currency\": {\"type\": \"string\", \"description\": \"The currency to convert from\"}, \"to_currency\": {\"type\": \"string\", \"description\": \"The currency to convert to\"}, \"amount\": {\"type\": \"number\", \"description\": \"The amount to convert\"}}, \"required\": [\"from_currency\", \"to_currency\", \"amount\"]}}, {\"name\": \"generate_random_password\", \"description\": \"Generate a random password with specified requirements\", \"parameters\": {\"type\": \"object\", \"properties\": {\"length\": {\"type\": \"integer\", \"description\": \"The length of the password\"}, \"include_numbers\": {\"type\": \"boolean\", \"description\": \"Include numbers in the password\"}, \"include_symbols\": {\"type\": \"boolean\", \"description\": \"Include symbols in the password\"}}, \"required\": [\"length\"]}}]",
"system": "You are a helpful assistant with access to some functions. Use them if required."}
```
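Note that the `tools` field is itself a JSON-encoded string; a minimal sketch of reading one sample (the file name `sample.json` is hypothetical):
```python
import json

# The "tools" field is a JSON string, so it needs a second parse.
with open("sample.json") as f:
    sample = json.load(f)
tools = json.loads(sample["tools"])
print([tool["name"] for tool in tools])
```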
### Single-Node Training
To run training on a single machine, simply use the following command:
```bash
llamafactory-cli train /path/to/config.yaml
```
Example:
```bash
llamafactory-cli train ./LLaMA-Factory/examples/train_full/minicpm4/sft.yaml
```
### Distributed Training Setup (2 Nodes Example)
All nodes must have:
- Identical software environments
- Network connectivity (e.g., TCP port 29500 open)
- Access to the same training data
To run training across 2 machines, follow these steps:
1. Determine Master Node IP
On your master node (node 0), run:
```bash
hostname -I | awk '{print $1}'
```
2. Launch Training
On master node (node 0):
```bash
export MASTER_ADDR=$(hostname -I | awk '{print $1}')
FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=$MASTER_ADDR MASTER_PORT=29500 \
llamafactory-cli train /path/to/config.yaml
```
On worker node (node 1):
```bash
FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=$MASTER_ADDR MASTER_PORT=29500 \
llamafactory-cli train /path/to/config.yaml
```
## 📖 Inference
### MCP Servers Deployment
The MCP Servers supported by MiniCPM4-MCP include
[Airbnb](https://github.com/openbnb-org/mcp-server-airbnb),
[Amap-Maps](https://github.com/zxypro1/amap-maps-mcp-server),
[Arxiv-MCP-Server](https://github.com/blazickjp/arxiv-mcp-server),
[Calculator](https://github.com/githejie/mcp-server-calculator),
[Computer-Control-MCP](https://github.com/AB498/computer-control-mcp),
[Desktop-commander](https://github.com/wonderwhy-er/DesktopCommanderMCP),
[Filesystem](https://github.com/mark3labs/mcp-filesystem-server),
[Github](https://github.com/modelcontextprotocol/servers-archived/tree/main/src/github),
[Gaode](https://github.com/perMAIN/gaode),
[MCP-Code-Executor](https://github.com/bazinga012/mcp_code_executor),
[MCP-DOCx](https://github.com/MeterLong/MCP-Doc),
[PPT](https://github.com/GongRzhe/Office-PowerPoint-MCP-Server),
[PPTx](https://github.com/supercurses/powerpoint),
[Simple-Time-Server](https://github.com/andybrandt/mcp-simple-timeserver),
[Slack](https://github.com/modelcontextprotocol/servers-archived/tree/main/src/slack), and
[Whisper](https://github.com/arcaputo3/mcp-server-whisper). Follow the instructions provided in each server's repository for successful deployment. Note that not all tools in these servers will function properly in every environment. Some tools are unstable and may return errors such as timeouts or HTTP errors. During training data construction, tools with consistently high failure rates (e.g., those for which the LLM fails to produce a successful query even after hundreds of attempts) are filtered out.
### MCP Client Setup
We modified the existing MCP Client from the [mcp-cli](https://github.com/chrishayuk/mcp-cli) repository to enable interaction between MiniCPM and MCP Servers.
After the MCP Client performs a handshake with a Server, it retrieves a list of available tools. An example of tool information contained in this list is provided in `available_tool_example.json`.
Once the available tools and user query are obtained, results can be generated using the following script logic:
```bash
python generate_example.py \
--tokenizer_path {path to MiniCPM4 tokenizer} \
--base_url {vllm deployment URL} \
--model {model name used in vllm deployment} \
--output_path {path to save results}
```
where MiniCPM4 generates tool calls in the following format:
````
<|tool_call_start|>
```python
read_file(path="/path/to/file")
```
<|tool_call_end|>
````
You can build a custom parser for MiniCPM4 tool calls based on this format. The relevant parsing logic is located in `generate_example.py`.
Since the [mcp-cli](https://github.com/chrishayuk/mcp-cli) repository supports the vLLM inference framework, MiniCPM4-MCP can also be integrated into `mcp-cli` by modifying vLLM accordingly.
Specifically, follow the instructions in [this link](https://github.com/OpenBMB/MiniCPM/tree/main/demo/minicpm3/function_call) to enable interaction between a client running the MiniCPM4-MCP model and the MCP Server.
## 📈 Evaluation
Once generation is complete, run the following example evaluation script:
```bash
python eval_scripts.py \
--input_path {path where the results generated by `generate` are saved}
```
This script is used to evaluate the model's performance in predicting function names during single-turn tool calls. In multi-turn scenarios, the accuracy of the tool call generated at the current step can be evaluated by providing the ground-truth information from previous steps. The evaluation logic for each step is the same as that of the single-turn setting.
### Evaluation Results
| MCP Server | | gpt-4o | | | qwen3 | | | minicpm4 | |
|-----------------------|----------------|--------------|--------------|---------------|--------------|--------------|----------------|--------------|--------------|
| | func | param | value | func | param | value | func | param | value |
| Airbnb | 89.3 | 67.9 | 53.6 | 92.8 | 60.7 | 50.0 | 96.4 | 67.9 | 50.0 |
| Amap-Maps | 79.8 | 77.5 | 50.0 | 74.4 | 72.0 | 41.0 | 89.3 | 85.7 | 39.9 |
| Arxiv-MCP-Server | 85.7 | 85.7 | 85.7 | 81.8 | 54.5 | 50.0 | 57.1 | 57.1 | 52.4 |
| Calculator | 100.0 | 100.0 | 20.0 | 80.0 | 80.0 | 13.3 | 100.0 | 100.0 | 6.67 |
| Computer-Control-MCP | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 86.7 |
| Desktop-Commander | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 | 100.0 |
| Filesystem | 63.5 | 63.5 | 31.3 | 69.7 | 69.7 | 26.0 | 83.3 | 83.3 | 42.7 |
| Github | 92.0 | 80.0 | 58.0 | 80.5 | 50.0 | 27.7 | 62.8 | 25.7 | 17.1 |
| Gaode | 71.1 | 55.6 | 17.8 | 68.8 | 46.6 | 24.4 | 68.9 | 46.7 | 15.6 |
| MCP-Code-Executor | 85.0 | 80.0 | 70.0 | 80.0 | 80.0 | 70.0 | 90.0 | 90.0 | 65.0 |
| MCP-Docx | 95.8 | 86.7 | 67.1 | 94.9 | 81.6 | 60.1 | 95.1 | 86.6 | 76.1 |
| PPT | 72.6 | 49.8 | 40.9 | 85.9 | 50.7 | 37.5 | 91.2 | 72.1 | 56.7 |
| PPTx | 64.2 | 53.7 | 13.4 | 91.0 | 68.6 | 20.9 | 91.0 | 58.2 | 26.9 |
| Simple-Time-Server | 90.0 | 70.0 | 70.0 | 90.0 | 90.0 | 90.0 | 90.0 | 60.0 | 60.0 |
| Slack | 100.0 | 90.0 | 70.0 | 100.0 | 100.0 | 65.0 | 100.0 | 100.0 | 100.0 |
| Whisper | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 90.0 | 30.0 |
| **Average** | **80.2** | **70.2** | **49.1** | **83.5** | **67.7** | **43.8** | **88.3** | **76.1** | **51.2** |
{
"available_tools": [
{
"function": {
"description": "",
"name": "searchPOI",
"parameters": {
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"properties": {
"city": {
"description": "搜索的城市名称",
"type": "string"
},
"extensions": {
"description": "返回结果控制,base返回基本信息,all返回详细信息",
"enum": [
"base",
"all"
],
"type": "string"
},
"keywords": {
"description": "要搜索的关键词",
"type": "string"
},
"offset": {
"description": "每页记录数,默认为20",
"type": "number"
},
"page": {
"description": "页码,默认为1",
"type": "number"
},
"types": {
"description": "POI类型",
"type": "string"
}
},
"required": [
"keywords"
],
"type": "object"
}
},
"type": "function"
},
{
"function": {
"description": "",
"name": "getWeather",
"parameters": {
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": false,
"properties": {
"city": {
"description": "城市编码",
"type": "string"
},
"extensions": {
"description": "气象类型:base(实况天气)、all(预报天气)",
"enum": [
"base",
"all"
],
"type": "string"
}
},
"required": [
"city"
],
"type": "object"
}
},
"type": "function"
}
]
}
import os
import json
from argparse import ArgumentParser
def normalize_value(value) -> object:
if isinstance(value, str):
try:
return float(value) if '.' in value else int(value)
except ValueError:
if value.lower() in ["true", "false"]:
return value.lower() == "true"
return value.lower().strip()
return value
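# Examples: normalize_value("3.14") -> 3.14, normalize_value("True") -> True,
# normalize_value("HELLO ") -> "hello" (strings are lower-cased for comparison).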
def validate_prediction(pred_json):
return (
pred_json is not None
)
def evaluate_function_calls(test_data: list, eval_data: list) :
stats = {
"total": 0,
"function_correct": 0,
"argument_name_correct": 0,
"argument_value_correct": 0,
"invalid_gt": 0,
"invalid_pred": 0
}
function_accuracy_errors = []
argument_name_accuracy_errors = []
argument_value_accuracy_errors = []
    for i in range(len(test_data)):
        stats["total"] += 1
        try:
            pred_data = test_data[i]
            gt_data = eval_data[i]
        except IndexError:
            # Mismatched prediction/ground-truth lengths: count as invalid.
            stats["invalid_gt"] += 1
            continue
if not validate_prediction(pred_data):
stats["invalid_pred"] += 1
continue
if isinstance(pred_data,list):
find = False
for data in pred_data:
if data["name"] == gt_data["name"]:
pred_data = data
find = True
break
if not find:
pred_data = pred_data[0]
if pred_data["name"] == gt_data["name"]:
stats["function_correct"] += 1
pred_args = pred_data.get("arguments", {})
gt_args = gt_data.get("arguments", {})
try:
if set(pred_args.keys()) == set(gt_args.keys()):
stats["argument_name_correct"] += 1
all_values_match = True
for key in gt_args.keys():
normalized_pred = normalize_value(pred_args[key])
normalized_gt = normalize_value(gt_args[key])
if normalized_pred != normalized_gt:
all_values_match = False
argument_value_accuracy_errors.append({"prediction": pred_data, "ground_truth": gt_data})
break
if all_values_match:
stats["argument_value_correct"] += 1
else:
argument_name_accuracy_errors.append({"prediction": pred_data, "ground_truth": gt_data})
except Exception as e:
print(e)
continue
else:
function_accuracy_errors.append({"prediction": repr(pred_data["name"]), "ground_truth": repr(gt_data["name"])})
return {
"function_accuracy": round(stats["function_correct"] / stats["total"], 4) if stats["total"] > 0 else 0.0,
"argument_name_accuracy": round(stats["argument_name_correct"] / stats["total"], 4) if stats["function_correct"] > 0 else 0.0,
"argument_value_accuracy": round(stats["argument_value_correct"] / stats["total"], 4) if stats["argument_name_correct"] > 0 else 0.0,
"total_samples": stats["total"],
"function_correct": stats["function_correct"],
"argument_name_correct": stats["argument_name_correct"],
"argument_value_correct": stats["argument_value_correct"],
"invalid_ground_truth": stats["invalid_gt"],
"invalid_predictions": stats["invalid_pred"]
}
gt_example = [
{
"name": "getWeather",
"arguments": {
"city": "杭州",
"extensions": "all"
}
}
]
if __name__ == "__main__":
    argument_parser = ArgumentParser()
    argument_parser.add_argument("--input_path", type=str, required=True)
    args = argument_parser.parse_args()
with open(args.input_path,"r") as f:
data = json.load(f)
print(evaluate_function_calls(test_data=[data[-1].get("function_call")],eval_data=gt_example))
import openai
import json
import keyword
import ast
import uuid
from transformers import AutoTokenizer
from argparse import ArgumentParser
def resolve_ast_call(elem):
# Handle nested attributes for deeply nested module paths
func_parts = []
func_part = elem.func
while isinstance(func_part, ast.Attribute):
func_parts.append(func_part.attr)
func_part = func_part.value
if isinstance(func_part, ast.Name):
func_parts.append(func_part.id)
func_name = ".".join(reversed(func_parts))
args_dict = {}
for arg in elem.keywords:
output = resolve_ast_by_type(arg.value)
args_dict[arg.arg] = output
return {func_name: args_dict}
def resolve_ast_by_type(value):
if isinstance(value, ast.Constant):
output = "..." if value.value is Ellipsis else value.value
elif isinstance(value, ast.UnaryOp):
output = -value.operand.value # type: ignore
elif isinstance(value, ast.List):
output = [resolve_ast_by_type(v) for v in value.elts]
elif isinstance(value, ast.Dict):
output = {
resolve_ast_by_type(k): resolve_ast_by_type(v)
for k, v in zip(value.keys, value.values)
}
elif isinstance(
value,
ast.NameConstant): # Added this condition to handle boolean values
output = value.value
elif isinstance(
value, ast.BinOp
): # Added this condition to handle function calls as arguments
output = ast.literal_eval(ast.unparse(value)) # type: ignore
elif isinstance(value, ast.Name):
output = value.id
elif isinstance(value, ast.Call):
if len(value.keywords) == 0:
output = ast.unparse(value) # type: ignore
else:
output = resolve_ast_call(value)
elif isinstance(value, ast.Tuple):
output = tuple(resolve_ast_by_type(v) for v in value.elts)
elif isinstance(value, ast.Lambda):
output = ast.literal_eval(
ast.unparse( # type: ignore
value.body[0].value)) # type: ignore
elif isinstance(value, ast.Ellipsis):
output = "..."
elif isinstance(value, ast.Subscript):
try:
output = ast.unparse(value.body[0].value) # type: ignore
except Exception as e:
output = (
ast.unparse(value.value) + "[" + # type: ignore
ast.unparse(value.slice) + "]") # type: ignore
else:
raise Exception(f"Unsupported AST type: {type(value)}")
return output
def parse_tool_for_minicpm3(
sequence: str,
tool_call_start="<|tool_call_start|>",
tool_call_end="<|tool_call_end|>",
):
try:
if tool_call_start in sequence and tool_call_end in sequence:
tool_call_string, content = sequence.rsplit(tool_call_end, 1)
tool_call_string = tool_call_string.split(tool_call_start, 1)[1]
tool_calls = []
tool_call_string = tool_call_string.strip()
if tool_call_string.startswith("```"):
tool_call_string = tool_call_string[3:].strip()
if tool_call_string.startswith("python"):
tool_call_string = tool_call_string.lstrip(
"python").strip()
if tool_call_string.endswith("```"):
tool_call_string = tool_call_string[:-3].strip()
for kw in keyword.kwlist:
tool_call_string = tool_call_string.replace(
"," + kw + "=", "," + kw + "_=")
tool_call_string = tool_call_string.replace(
" " + kw + "=", " " + kw + "_=")
tool_call_string = tool_call_string.replace(
"(" + kw + "=", "(" + kw + "_=")
need_replace = False
replaced_tool_call_string = tool_call_string.replace("-","_")
if replaced_tool_call_string != tool_call_string:
need_replace = True
tool_call_string = replaced_tool_call_string
parsed: ast.Module = ast.parse(tool_call_string)
for elem in parsed.body:
assert isinstance(elem.value, ast.Call) # type: ignore
calls = resolve_ast_call(elem.value) # type: ignore
for func_name, func_args in calls.items():
new_args = {}
for k, v in func_args.items():
for kw in keyword.kwlist:
if k == kw + "_":
k = kw
new_args[k] = v
this_one = {"name": func_name, "arguments": new_args}
tool_calls.append({ "id":str(uuid.uuid4()),"function":this_one,"type":"function"})
if need_replace:
for tool_call in tool_calls:
tool_call["function"]["name"] = tool_call["function"]["name"].replace("_","-")
return tool_calls
else:
return []
except:
return []
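# Example: a response containing
#   <|tool_call_start|>```python\ngetWeather(city="杭州")\n```<|tool_call_end|>
# parses to
#   [{"id": "<uuid4>", "type": "function",
#     "function": {"name": "getWeather", "arguments": {"city": "杭州"}}}]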
def generate(
client,
tokenizer,
model : str,
messages_minicpm : list,
tools : list
):
prompt = tokenizer.apply_chat_template(
messages_minicpm, tools=tools, tokenize=False, add_generation_prompt=True
)
response = client.completions.create(model=model,prompt = prompt,max_tokens = 8192)
response_dict = response.model_dump()
first_choice = response_dict['choices'][0]
return first_choice["text"]
example_messages_history = [
{
"role": "system",
"content": "You are an intelligent assistant with access to various tools. Your task is to answer questions by using these tools when needed.\n\nCRITICAL INSTRUCTIONS:\n\n1. Tool use is expected and encouraged, especially when information cannot be inferred from the conversation context. However, if you have gathered enough information to confidently provide the final answer, you may do so directly — but only after tool usage has been attempted or proven unnecessary.\n\n2. DO NOT describe or talk about using tools — actually CALL them using the tool_calls mechanism.\n\n3. NEVER fabricate answers. Always rely on tool results or clearly indicate when no useful result is available.\n\n4. If a tool returns an error OR fails to provide useful or new information (e.g., empty results, no content, or repeated output), DO NOT call it again with the same inputs. Avoid repeating the same failed tool calls. If a tool fails, try alternative tools if available.\n\n5. You MUST consider previous tool_calls and tool responses when deciding what to do next. Use this history to avoid redundant or circular behavior.\n\n6. If ALL relevant tools have been tried and none provide helpful results, you may gracefully end the conversation with a best-effort response, acknowledging that tools did not yield a definitive answer.\n\n7. When delivering the final answer, use the following format:\n - First provide a concise analysis or summary of your reasoning and tool findings.\n - Then end with: **\"The answer is: [your final answer]\"**\n\nTECHNICAL DETAILS:\n\n- For any step involving tool use, your response must include a \"tool_calls\" field.\n- The only valid response without tool_calls is when delivering the FINAL ANSWER after attempting or ruling out tool usage.\n\nEXAMPLES OF ACCEPTABLE BEHAVIOR:\n- Trying a tool, analyzing the response, and choosing a different tool when appropriate.\n- Avoiding re-use of failed tool calls by checking prior results.\n- Stopping and concluding if all tool paths have been exhausted.\n\nNEVER:\n- Repeat failed tool calls unnecessarily.\n- Respond with general knowledge if tools are required to verify the answer.\n\nRemember: your goal is to reason with tool assistance. Use tools thoughtfully and adaptively to solve the user's question.\n "
},
{
"role": "user",
"content": "I'm searching for movie theaters in Hangzhou and wondering about the weather forecast for this evening."
},
{
"role": "assistant",
"content": "<|tool_call_start|>\n```python\nsearchPOI(city=\"杭州\",extensions=\"base\",keywords=\"电影院\")\n```\n<|tool_call_end|>\nI'll help you find movie theaters in Hangzhou and check the weather forecast for this evening. Let me gather that information for you."
},
{
"role": "tool",
"content": "{\"status\":\"1\",\"count\":229,\"pois\":[{\"id\":\"B0FFIQG5YR\",\"name\":\"万达影城(砂之船国际生活广场店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"解放东路8号砂之船国际生活广场B1层\",\"location\":\"120.213532,30.244310\",\"tel\":\"0571-81106343;0571-81106969\",\"distance\":[],\"photos\":[]},{\"id\":\"B0H63L9Z18\",\"name\":\"SFC上影国际影城(高德置地广场店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"市民中心站G口旁(高德置地广场5楼)\",\"location\":\"120.208656,30.242733\",\"tel\":\"0571-87390565\",\"distance\":[],\"photos\":[]},{\"id\":\"B0J0XOKKZF\",\"name\":\"星光嘉映影城(杭州来福士店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"来福士中心07层03号\",\"location\":\"120.213035,30.248775\",\"tel\":\"17300922881\",\"distance\":[],\"photos\":[]},{\"id\":\"B0I6AMXU40\",\"name\":\"万象影城(杭州万象城店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"富春路701号杭州万象城5楼\",\"location\":\"120.214835,30.251509\",\"tel\":\"15657105178\",\"distance\":[],\"photos\":[]},{\"id\":\"B023B0BI2S\",\"name\":\"卢米埃影城(银泰百货庆春店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"景昙路18-26号庆春银泰6层\",\"location\":\"120.205439,30.260063\",\"tel\":[],\"distance\":[],\"photos\":[]},{\"id\":\"B0FFG10DAO\",\"name\":\"金逸影城-脱口秀剧场(杭州五福天虹购物中心B座店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"新塘路108号杭州五福天虹购物中心B座6层606号\",\"location\":\"120.209629,30.264973\",\"tel\":\"0571-87702702;0571-87706773\",\"distance\":[],\"photos\":[]},{\"id\":\"B0G0SZ0C22\",\"name\":\"保利万和CFR国际影城(钱江世纪城店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"钱江世纪城钱江世纪公园A区13幢\",\"location\":\"120.239038,30.244492\",\"tel\":\"0571-83822017\",\"distance\":[],\"photos\":[]},{\"id\":\"B0FFFV2EZ4\",\"name\":\"德信影城(万泰城店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"望江东路332号万泰城3-4层\",\"location\":\"120.192637,30.229449\",\"tel\":\"0571-87173399;0571-87710207\",\"distance\":[],\"photos\":[]},{\"id\":\"B0FFIQ15SN\",\"name\":\"至潮影城(庆春路店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"庆春路87号锦和大厦5层\",\"location\":\"120.176942,30.257545\",\"tel\":\"13336013203\",\"distance\":[],\"photos\":[]},{\"id\":\"B0H2PCRWZ9\",\"name\":\"INF无限时空电影剧场\",\"type\":\"体育休闲服务;影剧院;剧场\",\"address\":\"建国北路286号凤起农贸市场一号楼二层(桐江小院楼上)\",\"location\":\"120.181682,30.265485\",\"tel\":\"17767151776\",\"distance\":[],\"photos\":[]},{\"id\":\"B0FFI2YFAE\",\"name\":\"中影国际影城(星光大道二期店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"闻涛路1766号华联·星光大道2期4层\",\"location\":\"120.206202,30.211833\",\"tel\":\"0571-88997727\",\"distance\":[],\"photos\":[]},{\"id\":\"B0GRC5CDQF\",\"name\":\"西戏·XIXI 
LIVE(星澜里店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"西兴街道西兴路2333号星澜大厦4幢301室\",\"location\":\"120.221481,30.215414\",\"tel\":\"13516855490\",\"distance\":[],\"photos\":[]},{\"id\":\"B0GR07XGIF\",\"name\":\"海马国际影城(江和美亲子广场店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"江和美亲子广场4层(三堡地铁站B1口步行230米)\",\"location\":\"120.227284,30.268501\",\"tel\":\"0571-85771502\",\"distance\":[],\"photos\":[]},{\"id\":\"B0J6CCKZ3T\",\"name\":\"万达影城(星耀城1期九宜城店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"星耀城1期3层(江陵路地铁站A口步行350米)\",\"location\":\"120.215352,30.212717\",\"tel\":\"0571-81138116\",\"distance\":[],\"photos\":[]},{\"id\":\"B023B08Q8G\",\"name\":\"新华影都(庆春路店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"娃哈哈同乐舫大酒店3层\",\"location\":\"120.169271,30.257827\",\"tel\":\"0571-87046523;0571-87212554\",\"distance\":[],\"photos\":[]},{\"id\":\"B0FFG58CAX\",\"name\":\"海上明珠国际影城(银泰百货西湖店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"延安路98号西湖银泰城A馆5层\",\"location\":\"120.164970,30.243828\",\"tel\":\"0571-87002038\",\"distance\":[],\"photos\":[]},{\"id\":\"B023B19440\",\"name\":\"中影·国际影城(杭州滨江星光大道店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"江南大道228号星光国际广场2幢4层\",\"location\":\"120.209128,30.208094\",\"tel\":\"0571-88924880;0571-88924988\",\"distance\":[],\"photos\":[]},{\"id\":\"B0IAJ167YH\",\"name\":\"尚橙电影工场(利星名品广场店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"中山南路77号利星名品广场4层\",\"location\":\"120.168520,30.228409\",\"tel\":\"0571-56668320;0571-56668321;13646859933\",\"distance\":[],\"photos\":[]},{\"id\":\"B0HADOGYLE\",\"name\":\"百美汇影城(杭州嘉里中心店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"延安路385号杭州嘉里中心L4层\",\"location\":\"120.161926,30.260416\",\"tel\":\"0571-81181775\",\"distance\":[],\"photos\":[]},{\"id\":\"B0H6OAS72Q\",\"name\":\"德信影城(杭州之翼购物中心店)\",\"type\":\"体育休闲服务;影剧院;电影院\",\"address\":\"鸿泰路133号杭州之翼购物中心4F\",\"location\":\"120.221653,30.294449\",\"tel\":\"0571-88881601\",\"distance\":[],\"photos\":[]}],\"mapUrl\":\"https://restapi.amap.com/v3/staticmap?key=9baeae2ef50243c0d49141fa31c1dcf9&zoom=12&size=750*500&scale=2&location=120.213532,30.244310&markers=mid,0xFF0000,A:120.213532,30.244310&markers=mid,0xFF0000,B:120.208656,30.242733&markers=mid,0xFF0000,C:120.213035,30.248775&markers=mid,0xFF0000,D:120.214835,30.251509&markers=mid,0xFF0000,E:120.205439,30.260063&markers=mid,0xFF0000,F:120.209629,30.264973&markers=mid,0xFF0000,G:120.239038,30.244492&markers=mid,0xFF0000,H:120.192637,30.229449&markers=mid,0xFF0000,I:120.176942,30.257545&markers=mid,0xFF0000,J:120.181682,30.265485&markers=mid,0xFF0000,K:120.206202,30.211833&markers=mid,0xFF0000,L:120.221481,30.215414&markers=mid,0xFF0000,M:120.227284,30.268501&markers=mid,0xFF0000,N:120.215352,30.212717&markers=mid,0xFF0000,O:120.169271,30.257827&markers=mid,0xFF0000,P:120.164970,30.243828&markers=mid,0xFF0000,Q:120.209128,30.208094&markers=mid,0xFF0000,R:120.168520,30.228409&markers=mid,0xFF0000,S:120.161926,30.260416&markers=mid,0xFF0000,T:120.221653,30.294449&path=120.213532,30.244310;120.208656,30.242733;120.213035,30.248775;120.214835,30.251509;120.205439,30.260063;120.209629,30.264973;120.239038,30.244492;120.192637,30.229449;120.176942,30.257545;120.181682,30.265485;120.206202,30.211833;120.221481,30.215414;120.227284,30.268501;120.215352,30.212717;120.169271,30.257827;120.164970,30.243828;120.209128,30.208094;120.168520,30.228409;120.161926,30.260416;120.221653,30.294449\"}",
}
]
with open("available_tool_example.json","r") as f:
available_tools = json.load(f)["available_tools"]
if __name__ == "__main__":
argument_parser = ArgumentParser()
argument_parser.add_argument("--tokenizer_path",type=str,required=True)
argument_parser.add_argument("--base_url",type=str,required=True)
argument_parser.add_argument("--model",type=str,required=True)
argument_parser.add_argument("--output_path",type=str,required=True)
args = argument_parser.parse_args()
client = openai.OpenAI(
api_key="1",
base_url=args.base_url
)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
generate_result = generate(
client=client,
tokenizer=tokenizer,
model=args.model,
messages_minicpm=example_messages_history,
tools=available_tools
)
tool_calls = parse_tool_for_minicpm3(generate_result)
result = []
for message in example_messages_history:
if message["role"] == "system":
result.append(
{"system":message["content"]}
)
elif message["role"] == "user":
result.append(
{"human":message["content"]}
)
elif message["role"] == "assistant":
history_tool_calls = parse_tool_for_minicpm3(
message["content"]
)
result.append(
{
"gpt":message["content"],
"function_call":history_tool_calls
}
)
elif message["role"] == "tool":
if "observation" not in result[-1]:
result[-1]["observation"] = []
result[-1]["observation"].append(message["content"])
result.append(
{
"gpt": generate_result,
"function_call":tool_calls
}
)
with open(args.output_path,"w") as f:
json.dump(result,f,ensure_ascii=False,indent=4)
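
    # Illustrative shape of the dumped record list (values elided); the
    # "observation" key appears only on assistant turns that were followed
    # by tool messages:
    # [
    #     {"system": "..."},
    #     {"human": "..."},
    #     {"gpt": "...", "function_call": [...], "observation": ["..."]},
    #     {"gpt": "<newly generated turn>", "function_call": [...]}
    # ]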
accelerate==1.2.1
datasets==3.2.0
deepspeed==0.16.5
einops==0.8.1
flash_attn==2.7.4.post1
flashinfer-python==0.2.5
huggingface-hub==0.30.2
peft==0.15.1
rouge-chinese==1.0.3
safetensors==0.5.3
sentencepiece==0.2.0
tokenizers==0.21.0
torch==2.6.0
transformers==4.49.0
trl==0.9.6
vllm==0.8.4
wandb==0.19.9
xformers==0.0.29.post2
<div align="center">
<img src="./assets/main.png" alt="MiniCPM4-Survey MAIN" width="400em"></img>
</div>
<p align="center">
【English | <a href="README.md">中文</a>】
</p>
<p align="center">
<a href="https://huggingface.co/openbmb/MiniCPM4-Survey">Model</a>
<a href="https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf">Technical Report</a>
</p>
## News
* [2025-06-05] 🚀🚀🚀 We have open-sourced **MiniCPM4-Survey**, a model built upon MiniCPM4-8B that is capable of generating trustworthy, long-form survey papers while maintaining competitive performance relative to significantly larger models.
## Overview
**MiniCPM4-Survey** is an open-source LLM agent model jointly developed by [THUNLP](https://nlp.csai.tsinghua.edu.cn), Renmin University of China, and [ModelBest](https://modelbest.cn/en). Built on [MiniCPM4](https://github.com/OpenBMB/MiniCPM4) with 8 billion parameters, it accepts users' queries as input and autonomously generates trustworthy, long-form survey papers.
Key features include:
- **Plan-Retrieve-Write Survey Generation Framework** — We propose a multi-agent generation framework, which operates through three core stages: planning (defining the overall structure of the survey), retrieval (generating appropriate retrieval keywords), and writing (synthesizing the retrieved information to generate coherent section-level content); a minimal sketch follows this list.
- **High-Quality Dataset Construction** — We gather and process a large collection of expert-written survey papers to construct a high-quality training dataset. Meanwhile, we collect a large number of research papers to build a retrieval database.
- **Multi-Aspect Reward Design** — We carefully design a reward system with three aspects (structure, content, and citations) to evaluate the quality of the surveys, which is used as the reward function in the RL training stage.
- **Multi-Step RL Training Strategy** — We propose a *Context Manager* to ensure retention of essential information while facilitating efficient reasoning, and we construct a *Parallel Environment* to maintain efficient RL training cycles.
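For intuition, the loop below is a minimal, hypothetical sketch of the three stages named above; the `llm` and `retriever` interfaces and every function name are illustrative stand-ins, not the repository's actual API.
```python
# Hypothetical plan-retrieve-write loop; the real implementation in this
# repository differs in structure and detail.
def generate_survey(query: str, llm, retriever) -> str:
    # Planning: draft the overall outline of the survey.
    outline = llm.generate(f"Draft a survey outline for: {query}")

    sections = []
    for section_title in outline.splitlines():
        # Retrieval: produce keywords for this section and fetch papers.
        keywords = llm.generate(f"Retrieval keywords for: {section_title}")
        papers = retriever.search(keywords, top_k=10)

        # Writing: synthesize the retrieved papers into section content.
        context = "\n\n".join(p["abstract"] for p in papers)
        sections.append(
            llm.generate(f"Write the section '{section_title}' using:\n{context}")
        )
    return "\n\n".join(sections)
```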
**Demo Case**:
https://github.com/user-attachments/assets/b47b1567-8af1-4be0-bec4-8f86f2207d60
## Quick Start
### Download the model
Download [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey) from Hugging Face and place it in `model/MiniCPM4-Survey`.
We recommend using [MiniCPM-Embedding-Light](https://huggingface.co/openbmb/MiniCPM-Embedding-Light) as the embedding model, which can be downloaded from Hugging Face and placed in `model/MiniCPM-Embedding-Light`.
### Prepare the environment
Download the [paper data](https://www.kaggle.com/datasets/Cornell-University/arxiv) from Kaggle and extract it. Then run `data_process.py` to process the data and generate the retrieval database, and `build_index.py` to build the retrieval index:
```bash
curl -L -o ~/Downloads/arxiv.zip \
  https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
unzip ~/Downloads/arxiv.zip -d .
mkdir data
python ./src/preprocess/data_process.py
mkdir index
python ./src/preprocess/build_index.py
```
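The repository's `data_process.py` and `build_index.py` are the authoritative implementation. For a rough picture of what dense-index construction typically involves, here is a minimal sketch assuming `faiss-cpu` and `sentence-transformers` are available; the file paths, JSON field name, and model-loading call are assumptions for illustration, not the scripts' actual code.
```python
# Minimal dense-index sketch; see src/preprocess/build_index.py for the
# real implementation.
import json

import faiss  # assumes the faiss-cpu package
import numpy as np
from sentence_transformers import SentenceTransformer

# Assumption: the embedding model loads from the recommended local path.
model = SentenceTransformer("model/MiniCPM-Embedding-Light", trust_remote_code=True)

# Assumption: processed arXiv metadata stored as JSON lines with an
# "abstract" field.
with open("data/papers.jsonl") as f:
    texts = [json.loads(line)["abstract"] for line in f]

# Normalized embeddings + inner-product index == cosine-similarity search.
embeddings = model.encode(texts, normalize_embeddings=True, convert_to_numpy=True)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings.astype(np.float32))
faiss.write_index(index, "index/papers.faiss")
```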
### Model Inference
You can run the following commands to build the retrieval environment and start the inference:
```bash
python ./src/retriever.py
bash ./scripts/run.sh
```
If you want to run with the frontend, you can run the following commands:
```bash
python ./src/retriever.py
bash ./scripts/run_with_frontend.sh
cd frontend/minicpm4-survey
npm install
npm run dev
```
Then you can visit `http://localhost:5173` in your browser to use the model.
## Performance Evaluation
| Method | Relevance | Coverage | Depth | Novelty | Avg. | Fact Score |
|---------------------------------------------|-----------|----------|-------|---------|-------|------------|
| Naive RAG (driven by G2FT) | 3.25 | 2.95 | 3.35 | 2.60 | 3.04 | 43.68 |
| AutoSurvey (driven by G2FT) | 3.10 | 3.25 | 3.15 | **3.15**| 3.16 | 46.56 |
| Webthinker (driven by WTR1-7B) | 3.30 | 3.00 | 2.75 | 2.50 | 2.89 | -- |
| Webthinker (driven by QwQ-32B) | 3.40 | 3.30 | 3.30 | 2.50 | 3.13 | -- |
| OpenAI Deep Research (driven by GPT-4o)     | 3.50      | **3.95** | 3.55  | 3.00    | **3.50** | --      |
| MiniCPM4-Survey | 3.45 | 3.70 | **3.85** | 3.00 | **3.50** | **68.73** |
| &nbsp;&nbsp;&nbsp;*w/o* RL | **3.55** | 3.35 | 3.30 | 2.25 | 3.11 | 50.24 |
*Performance comparison of the survey generation systems. "G2FT" stands for Gemini-2.0-Flash-Thinking, and "WTR1-7B" denotes Webthinker-R1-7B. FactScore evaluation was omitted for Webthinker, as it does not include citation functionality, and for OpenAI Deep Research, which does not provide citations when exporting the results.*
<div align="center">
<img src="./assets/main.png" alt="MiniCPM4-Survey MAIN" width="400em"></img>
</div>
<p align="center">
<a href="README-en.md">English</a> | 中文】
</p>
<p align="center">
<a href="https://huggingface.co/openbmb/MiniCPM4-Survey">模型</a>
<a href="https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf">技术报告</a>
</p>
## News
* [2025-06-05] 🚀🚀🚀 We have open-sourced **MiniCPM4-Survey**, built on MiniCPM4-8B, which generates trustworthy, long-form survey papers with performance comparable to much larger models.
## Overview
MiniCPM4-Survey is an open-source LLM agent jointly developed by [THUNLP](https://nlp.csai.tsinghua.edu.cn), Renmin University of China, and [ModelBest](https://modelbest.cn). Built on the 8-billion-parameter [MiniCPM4](https://github.com/OpenBMB/MiniCPM4) base model, it accepts users' queries as input and autonomously generates trustworthy, long-form survey papers.
Key features include:
- **Plan-Retrieve-Write Survey Generation Framework** — We propose a multi-agent generation framework with three core stages: planning (defining the overall structure of the survey), retrieval (generating appropriate retrieval keywords), and writing (synthesizing the retrieved information into coherent section-level content).
- **High-Quality Dataset Construction** — We collect and process a large number of expert-written survey papers to build a high-quality training dataset. Meanwhile, we collect a large corpus of research papers to build the retrieval database.
- **Multi-Aspect Reward Design** — We carefully design rewards covering structure, content, and citations to evaluate survey quality, which serve as the reward function during the RL training stage.
- **Multi-Step RL Training Strategy** — We propose a *Context Manager* to retain essential information while enabling efficient reasoning, and construct a *Parallel Environment* to keep RL training cycles efficient.
**Demo**:
https://github.com/user-attachments/assets/b47b1567-8af1-4be0-bec4-8f86f2207d60
## Usage
### Download the model
Download [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey) from Hugging Face and place it in `model/MiniCPM4-Survey`.
We recommend [MiniCPM-Embedding-Light](https://huggingface.co/openbmb/MiniCPM-Embedding-Light) as the embedding model; download it from Hugging Face and place it in `model/MiniCPM-Embedding-Light`.
### Prepare the environment
Download the paper data from Kaggle and extract it. Run `data_process.py` to process the data and generate the retrieval database, then run `build_index.py` to build the retrieval index:
```bash
curl -L -o ~/Downloads/arxiv.zip \
  https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
unzip ~/Downloads/arxiv.zip -d .
mkdir data
python ./src/preprocess/data_process.py
mkdir index
python ./src/preprocess/build_index.py
```
### Model Inference
Run the following commands to build the retrieval environment and start inference:
```bash
python ./src/retriever.py
bash ./scripts/run.sh
```
If you want to run with the frontend, run the following commands:
```bash
python ./src/retriever.py
bash ./scripts/run_with_frontend.sh
cd frontend/minicpm4-survey
npm install
npm run dev
```
Then visit `http://localhost:5173` in your browser to use the model.
## Performance
| Method | Relevance | Coverage | Depth | Novelty | Avg. | Fact Score |
|---------------------------------------------|-----------|----------|-------|---------|-------|------------|
| Naive RAG (driven by G2FT) | 3.25 | 2.95 | 3.35 | 2.60 | 3.04 | 43.68 |
| AutoSurvey (driven by G2FT) | 3.10 | 3.25 | 3.15 | **3.15**| 3.16 | 46.56 |
| Webthinker (driven by WTR1-7B) | 3.30 | 3.00 | 2.75 | 2.50 | 2.89 | -- |
| Webthinker (driven by QwQ-32B) | 3.40 | 3.30 | 3.30 | 2.50 | 3.13 | -- |
| OpenAI Deep Research (driven by GPT-4o)     | 3.50      | **3.95** | 3.55  | 3.00    | **3.50** | --      |
| MiniCPM4-Survey | 3.45 | 3.70 | **3.85** | 3.00 | **3.50** | **68.73** |
| &nbsp;&nbsp;&nbsp;*w/o* RL | **3.55** | 3.35 | 3.30 | 2.25 | 3.11 | 50.24 |
*Performance comparison of survey generation systems, as judged by GPT-4o. "G2FT" stands for Gemini-2.0-Flash-Thinking and "WTR1-7B" for Webthinker-R1-7B. FactScore evaluation was omitted for Webthinker, which lacks citation functionality, and for OpenAI Deep Research, which does not provide citations when exporting results. Details of the evaluation are given in our technical report.*
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
export default [
{ ignores: ['dist'] },
{
files: ['**/*.{js,jsx}'],
languageOptions: {
ecmaVersion: 2020,
globals: globals.browser,
parserOptions: {
ecmaVersion: 'latest',
ecmaFeatures: { jsx: true },
sourceType: 'module',
},
},
plugins: {
'react-hooks': reactHooks,
'react-refresh': reactRefresh,
},
rules: {
...js.configs.recommended.rules,
...reactHooks.configs.recommended.rules,
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
},
]
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/png" href="/openbmb.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>MiniCPM4-Survey</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.jsx"></script>
</body>
</html>