# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
from abc import ABC, abstractmethod

import regex as re
from pydantic import BaseModel

from verl.utils.rollout_trace import rollout_trace_op

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class FunctionCall(BaseModel):
    arguments: str
    """
    The arguments to call the function with, as generated by the model in JSON
    format. Note that the model does not always generate valid JSON, and may
    hallucinate parameters not defined by your function schema. Validate the
    arguments in your code before calling your function.
    """

    name: str
    """The name of the function to call."""


class ToolParser(ABC):
    """Base class for parsers that extract function calls from generated token ids.

    Concrete parsers register themselves under a string name with
    :meth:`register` and are instantiated by that name via
    :meth:`get_tool_parser`.
    """

    # Registered parser name -> parser class. Defined once on the base class and
    # only mutated (never reassigned), so every subclass shares the same mapping.
    _registry: dict[str, type["ToolParser"]] = {}

    def __init__(self, tokenizer) -> None:
        # NOTE(review): subclasses call `tokenizer.decode(list[int])`; confirm any
        # other tokenizer methods a new parser relies on.
        self.tokenizer = tokenizer

    @abstractmethod
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Extract tool calls from the responses.

        Args:
            responses_ids (list[int]): The ids of the responses.

        Returns:
            tuple[str, list[FunctionCall]]: Content and extracted tool calls.
        """
        raise NotImplementedError

    @classmethod
    def get_tool_parser(cls, name: str, tokenizer):
        """Instantiate the parser registered under ``name`` with ``tokenizer``.

        Raises:
            ValueError: If no parser has been registered under ``name``.
        """
        if name not in cls._registry:
            raise ValueError(f"Unknown tool parser: {name}")
        return cls._registry[name](tokenizer)

    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a parser class under ``name``.

        A later registration under the same name replaces the earlier one.
        """

        def decorator(subclass: type[ToolParser]) -> type[ToolParser]:
            cls._registry[name] = subclass
            return subclass

        return decorator


@ToolParser.register("hermes")
class HermesToolParser(ToolParser):
    """Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py"""

    def __init__(self, tokenizer) -> None:
        super().__init__(tokenizer)

        # Hermes-style output wraps each call in explicit delimiters:
        #   <tool_call>{"name": ..., "arguments": {...}}</tool_call>
        # These must be non-empty. An empty token is "in" every string, and an
        # empty lazy regex matches at every position, so nothing would ever parse.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        # Derive the pattern from the tokens so the two can never drift apart.
        # DOTALL lets a call body span multiple lines.
        self.tool_call_regex = re.compile(
            re.escape(self.tool_call_start_token) + r"(.*?)" + re.escape(self.tool_call_end_token),
            re.DOTALL,
        )

    @rollout_trace_op
    async def extract_tool_calls(self, responses_ids: list[int]) -> tuple[str, list[FunctionCall]]:
        """Decode ``responses_ids`` and extract every ``<tool_call>...</tool_call>`` span.

        Each span body is parsed as JSON and read for its ``name`` and
        ``arguments`` keys. Spans that fail to parse or validate are logged and
        skipped (best effort), so one malformed call does not discard the others.

        Args:
            responses_ids (list[int]): The ids of the responses.

        Returns:
            tuple[str, list[FunctionCall]]: The decoded text with every complete
            tool-call span (delimiters included) removed, and the successfully
            parsed calls with ``arguments`` re-serialized as a JSON string. If
            either delimiter is absent, the full decoded text and an empty list.
        """
        loop = asyncio.get_running_loop()
        # tokenizer.decode is synchronous; run it in the default executor so it
        # does not block the event loop.
        text = await loop.run_in_executor(None, self.tokenizer.decode, responses_ids)
        if self.tool_call_start_token not in text or self.tool_call_end_token not in text:
            return text, []

        function_calls = []
        for match in self.tool_call_regex.findall(text):
            try:
                function_call = json.loads(match)
                name, arguments = function_call["name"], function_call["arguments"]
                function_calls.append(FunctionCall(name=name, arguments=json.dumps(arguments, ensure_ascii=False)))
            except Exception as e:
                # Best effort: log and skip this call (bad JSON, a non-dict
                # payload, missing keys, or model validation failure).
                logger.error("Failed to decode tool call: %s", e)

        # Remaining text excludes the tool-call spans, delimiters included.
        content = self.tool_call_regex.sub("", text)
        return content, function_calls