# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert JoeYing/ReTool-SFT to standard multi-turn tool calling messages.
"""
import json
import os
import re
from typing import Any
import datasets
from omegaconf import OmegaConf
code_pattern = re.compile(r"```python(.*?)```", re.DOTALL)
def extract_code_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "", ""
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
code = content[i + len(start) : j]
matches = code_pattern.findall(code)
if matches:
code = matches[0].strip()
message = {
"role": "assistant",
"content": content[:i].strip(),
"tool_calls": [
{
"type": "function",
"function": {
"name": "code_interpreter",
"arguments": {"code": code},
},
},
],
}
return message, content[j + len(stop) :]
def extract_answer_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "", ""
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
answer = content[:i] + content[i + len(start) : j]
message = {
"role": "assistant",
"content": answer.strip(),
}
return message, content[j + len(stop) :]
def extract_interpreter_message(content: str) -> tuple[dict[str, Any], str]:
start, stop = "", ""
i = content.find(start)
if i == -1:
return None, content
j = content.find(stop)
assert j > i
interpreter = content[i + len(start) : j]
message = {
"role": "tool",
"content": interpreter.strip(),
}
return message, content[j + len(stop) :]
def process(row: dict, *, tools: str):
messages = []
# extract problem
content = row["messages"][0]["content"]
start = "*user question:*"
i = content.find(start)
assert i != -1
prompt = content[i + len(start) :].replace("", "").replace("", "").strip()
messages.append(
{
"role": "user",
"content": prompt,
}
)
# extract multi turns
content = row["messages"][1]["content"]
role = "assistant"
while len(content) > 0:
if role == "assistant":
message, content = extract_code_message(content)
if message is None:
message, content = extract_answer_message(content)
assert message is not None
messages.append(message)
role = "tool"
else:
message, content = extract_interpreter_message(content)
assert message is not None
messages.append(message)
role = "assistant"
tools = json.loads(tools)
return {"messages": messages, "tools": tools}
if __name__ == "__main__":
tools_config_file = "recipe/retool/sandbox_fusion_tool_config.yaml"
tools_config = OmegaConf.load(tools_config_file)
tool_schema = OmegaConf.to_container(tools_config["tools"][0]["tool_schema"])
tools = json.dumps([tool_schema])
data = datasets.load_dataset("JoeYing/ReTool-SFT")["train"]
data = data.map(process, fn_kwargs={"tools": tools})
save_path = os.path.expanduser("~/ReTool-SFT/data/train-00000-of-00001.parquet")
data.to_parquet(save_path)