Commit 0505b7c0 authored by mashun1's avatar mashun1
Browse files

toolace

parents
Pipeline #2723 canceled with stages
__pycache__
# Base image: Hygon DCU PyTorch 2.4.1 build (Ubuntu 22.04, DTK 25.04, Python 3.10)
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
\ No newline at end of file
# ToolACE
## 论文
`ToolACE: Winning the Points of LLM Function Calling`
* https://arxiv.org/pdf/2409.00920
## 模型结构
模型基于 Llama 3.1 训练
![alt text](readme_imgs/arch.png)
## 算法原理
Multi-Head Attention是一种并行注意力机制,它通过多个子空间中的注意力头协同工作,从不同角度捕捉序列中元素之间的关系,从而增强模型的表达能力。
![alt text](readme_imgs/alg.png)
## 环境配置
### Docker(方法一)
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
docker run --shm-size 100g --network=host --name=toolace --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -U transformers
pip install accelerate
### Dockerfile(方法二)
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 100g --network=host --name=toolace --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v 项目地址(绝对路径):/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -U transformers
pip install accelerate
### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.sourcefind.cn/tool/
```
DTK驱动:dtk25.04
python:python3.10
torch:2.4.1
triton:3.0
flash-attn:2.6.1
deepspeed:0.14.2
apex:1.4.0
```
2、其他非特殊库直接按照requirements.txt安装
```
pip install -U transformers
pip install accelerate
```
## 数据集
[ToolACE](https://huggingface.co/datasets/Team-ACE/ToolACE)
本项目提供适用于`llama-factory`的数据格式,位于`datasets/data_1.json`
## 训练
### 安装llama-factory
详情参考[llama-factory](https://developer.sourcefind.cn/codes/OpenDAS/llama-factory/-/tree/0.9.2-parallel_tool).
### Lora
```bash
llamafactory-cli train examples/train_lora/llama3_lora_sft_ds3.yaml
```
在运行前需要修改`examples/train_lora/llama3_lora_sft_ds3.yaml`中相应参数。
## 推理
```bash
python inference.py
```
## result
```python
[sales_growth.calculate(company="XYZ", years=3), financial_ratios interest_coverage(company_name="XYZ", years=3)]
```
注意:结果可能根据参数不同而变化。
### 精度
loss: 0.3
## 应用场景
### 算法类别
`对话问答`
### 热点应用行业
`电商,教育,广媒`
## 预训练权重
[ToolACE-2-Llama-3.1-8B](https://hf-mirror.com/Team-ACE/ToolACE-2-Llama-3.1-8B)
## 源码仓库及问题反馈
*
## 参考资料
* https://hf-mirror.com/Team-ACE/ToolACE-2-Llama-3.1-8B
This diff is collapsed.
icon.png

53.8 KB

# Minimal inference example for the ToolACE-2 function-calling model:
# load the checkpoint from the Hugging Face hub, build a tool-calling
# prompt, and greedily decode the model's function-call answer.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Team-ACE/ToolACE-2-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 'auto' lets accelerate choose dtype and device placement (multi-GPU/CPU offload).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    device_map='auto'
)
# You can modify the prompt for your task
system_prompt = """You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out.
You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.\n{functions}\n
"""
# User query
query = "Find me the sales growth rate for company XYZ for the last 3 years and also the interest coverage ratio for the same duration."
# Available tools in JSON format (OpenAI-format).
# NOTE: the third tool (weather_forecast) is an intentional distractor the
# model should not call for this query.
tools = [
    {
        "name": "financial_ratios.interest_coverage", "description": "Calculate a company's interest coverage ratio given the company name and duration",
        "arguments": {
            "type": "dict",
            "properties": {
                "company_name": {
                    "type": "string",
                    "description": "The name of the company."
                },
                "years": {
                    "type": "integer",
                    "description": "Number of past years to calculate the ratio."
                }
            },
            "required": ["company_name", "years"]
        }
    },
    {
        "name": "sales_growth.calculate",
        "description": "Calculate a company's sales growth rate given the company name and duration",
        "arguments": {
            "type": "dict",
            "properties": {
                "company": {
                    "type": "string",
                    "description": "The company that you want to get the sales growth rate for."
                },
                "years": {
                    "type": "integer",
                    "description": "Number of past years for which to calculate the sales growth rate."
                }
            },
            "required": ["company", "years"]
        }
    },
    {
        "name": "weather_forecast",
        "description": "Retrieve a weather forecast for a specific location and time frame.",
        "arguments": {
            "type": "dict",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city that you want to get the weather for."
                },
                "days": {
                    "type": "integer",
                    "description": "Number of days for the forecast."
                }
            },
            "required": ["location", "days"]
        }
    }
]
# The tool list is injected into the system prompt via str.format.
messages = [
    {'role': 'system', 'content': system_prompt.format(functions=tools)},
    {'role': 'user', 'content': query}
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
# Greedy decoding (do_sample=False) for a deterministic function-call string.
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
# Slice off the prompt tokens so only the newly generated text is printed.
print(tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True))
# 模型编码
modelCode=1558
# 模型名称
modelName=ToolACE_pytorch
# 模型描述
modelDescription=华为研发的工具调用模型
# 应用场景
appScenario=训练,推理,对话问答,电商,教育,广媒
# 框架类型
frameType=pytorch
import json
import re
def is_function_call(text):
    """
    Check whether *text* looks like a bracketed function-call list, e.g.
    ``[Func1(arg="val"), Func2(...)]``.

    Returns True only when the whole string is wrapped in ``[...]`` and
    its content consists solely of ``name(args)`` items separated by
    commas and whitespace (nothing left over once the calls are removed).
    """
    candidate = text.strip()
    # Guard: must be a bracketed list at the outermost level.
    if not candidate.startswith('['):
        return False
    if not candidate.endswith(']'):
        return False
    body = candidate[1:-1].strip()
    call_re = re.compile(r'([^\(\)]+)\(([^()]*)\)')
    # At least one name(args) item must be present.
    if not call_re.findall(body):
        return False
    # Strip out every recognized call; only separators may remain.
    leftover = call_re.sub('', body)
    for sep in (',', ' '):
        leftover = leftover.replace(sep, '')
    return leftover == ''
def parse_function_call_list(text):
    """
    Parse a bracketed function-call string into a JSON-ready list.

    Example:
        [SEC Filings(identifier="AAPL"), United States Away from Home Mobility API(string="2025-05-17")]
    becomes:
        [
            {"name": "SEC Filings", "arguments": {"identifier": "AAPL"}},
            {"name": "United States Away from Home Mobility API", "arguments": {"string": "2025-05-17"}}
        ]

    All argument values are kept as strings; surrounding single or double
    quotes are removed.  Nested parentheses inside argument values are not
    supported by the regex.
    """
    inner = text.strip()[1:-1].strip()
    pattern = re.compile(r'([^\(\)]+)\((.*?)\)')
    matches = pattern.findall(inner)
    functions = []
    for func_name, args_str in matches:
        # Bug fix: for every call after the first, the name group also
        # captures the preceding ", " separator (yielding names like
        # ", sales_growth.calculate"), so strip commas as well as spaces.
        func_name = func_name.strip(' ,')
        args = {}
        if args_str.strip():
            # Split only on commas that start a new "key=" token, so commas
            # inside quoted values survive intact.
            parts = re.split(r',\s*(?=\w+=)', args_str)
            for part in parts:
                key_val = part.split('=', 1)
                if len(key_val) == 2:
                    key = key_val[0].strip()
                    val = key_val[1].strip()
                    # Drop matching surrounding quotes, if any.
                    if (val.startswith('"') and val.endswith('"')) or (val.startswith("'") and val.endswith("'")):
                        val = val[1:-1]
                    args[key] = val
        functions.append({
            "name": func_name,
            "arguments": args
        })
    return functions
def convert_conversation(conversations):
    """
    Map ShareGPT-style turns onto llama-factory roles.

    user -> human; assistant -> function_call (when the reply parses as a
    bracketed call list, serialized via json.dumps) or gpt otherwise;
    tool -> observation.  Turns with any other role are silently dropped.
    """
    converted = []
    for turn in conversations:
        role = turn.get("from", "")
        content = turn.get("value", "")
        if role == "user":
            converted.append({"from": "human", "value": content})
        elif role == "assistant":
            # Decide whether the assistant turn is a tool invocation.
            if is_function_call(content):
                calls = parse_function_call_list(content)
                converted.append({
                    "from": "function_call",
                    "value": json.dumps(calls, ensure_ascii=False),
                })
            else:
                converted.append({"from": "gpt", "value": content})
        elif role == "tool":
            converted.append({"from": "observation", "value": content})
    return converted
def transform_data(data):
    """
    Convert a whole dataset: keep each record's ``system`` text untouched
    and rewrite its ``conversations`` via convert_conversation.
    """
    return [
        {
            "system": record.get("system", ""),
            "conversations": convert_conversation(record.get("conversations", [])),
        }
        for record in data
    ]
def main(input_file, output_file):
    """Load a JSON dataset, convert it, and write the result as JSON."""
    with open(input_file, "r", encoding="utf-8") as src:
        records = json.load(src)
    converted = transform_data(records)
    # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable.
    with open(output_file, "w", encoding="utf-8") as dst:
        json.dump(converted, dst, ensure_ascii=False, indent=2)
if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI: --input_file <path> --output_file <path>
    cli = ArgumentParser()
    cli.add_argument("--input_file", type=str)
    cli.add_argument("--output_file", type=str)
    opts = cli.parse_args()
    main(opts.input_file, opts.output_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment