# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
from typing import Any
from uuid import uuid4

from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, register
from verl.experimental.agent_loop.tool_parser import FunctionCall, ToolParser
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.profiler import simple_timer
from verl.utils.rollout_trace import rollout_trace_op

# Use __name__ (dotted module path) rather than __file__ so this logger
# participates in the standard package logger hierarchy.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
    """Multi-turn agent loop that interleaves LLM generation with tool calls.

    Each iteration generates an assistant turn, parses tool calls out of it,
    executes the tools concurrently, and feeds the tool responses back as the
    next user turn, until a turn/length limit is hit or no tool is called.
    """

    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        """One-time class-level initialization of tokenizer, limits, and tools.

        Idempotent: calls after the first successful one are no-ops.

        Args:
            config: hydra/omegaconf config; reads ``actor_rollout_ref.rollout``.
            tokenizer: HF tokenizer used for chat templating and tokenization.
        """
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ToolAgentLoop initialization")

        # Hoist the deeply nested config sections once for readability.
        rollout_config = config.actor_rollout_ref.rollout
        multi_turn_config = rollout_config.multi_turn

        # Initialize tools from config file
        cls.tokenizer = tokenizer
        cls.max_user_turns = multi_turn_config.max_user_turns
        cls.max_assistant_turns = multi_turn_config.max_assistant_turns
        cls.max_parallel_calls = multi_turn_config.max_parallel_calls
        cls.max_tool_response_length = multi_turn_config.max_tool_response_length
        cls.tool_response_truncate_side = multi_turn_config.tool_response_truncate_side
        tool_config_path = multi_turn_config.tool_config_path
        tool_list = initialize_tools_from_config(tool_config_path) if tool_config_path else []
        cls.tools = {tool.name: tool for tool in tool_list}
        cls.tool_schemas = [tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True) for tool in tool_list]
        cls.tool_parser = ToolParser.get_tool_parser(multi_turn_config.format, cls.tokenizer)
        print(f"Initialized tools: {cls.tools}")

        cls.prompt_length = rollout_config.prompt_length
        cls.response_length = rollout_config.response_length
        # Token ids of an "empty" templated conversation; used in run() to strip
        # the prefix the chat template prepends when re-templating tool responses.
        cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)

    @rollout_trace_op
    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
        """Run one multi-turn rollout and return prompt/response token ids.

        Args:
            messages: initial chat messages (role/content dicts).
            sampling_params: decoding parameters forwarded to the server.

        Returns:
            AgentLoopOutput with prompt_ids, response_ids, a response_mask
            (1 = LLM-generated token, 0 = tool-response token), turn count,
            and timing metrics.
        """
        metrics = {}
        request_id = uuid4().hex
        # Tokenization is CPU-bound/blocking; run it off the event loop.
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True
            ),
        )
        response_mask = []

        user_turns, assistant_turns = 0, 0
        while True:
            with simple_timer("generate_sequences", metrics):
                response_ids = await self.server_manager.generate(
                    request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
                )
            prompt_ids += response_ids
            response_mask += [1] * len(response_ids)
            assistant_turns += 1

            # reach max response length
            if len(response_mask) >= self.response_length:
                break

            # reach max assistant turns
            if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:
                break

            # reach max user turns
            if self.max_user_turns and user_turns >= self.max_user_turns:
                break

            # no tool calls
            _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
            if not tool_calls:
                break

            # Call tools concurrently, capped at max_parallel_calls per turn.
            tasks = [self._call_tool(tool_call) for tool_call in tool_calls[: self.max_parallel_calls]]
            with simple_timer("tool_calls", metrics):
                tool_responses = await asyncio.gather(*tasks)
            # _call_tool returns the exception object (not raised) on failure;
            # any failure ends the rollout.
            if any(isinstance(item, Exception) for item in tool_responses):
                break

            # append tool_response_ids
            tool_response_ids = await self.loop.run_in_executor(
                None,
                lambda messages=tool_responses: self.tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=True
                ),
            )
            # Drop the system-prompt prefix the chat template re-adds.
            tool_response_ids = tool_response_ids[len(self.system_prompt) :]

            # NOTE: last turn should not be user turn, or the EOS token reward
            # can't be propagated to previous token in GAE.
            if len(response_mask) + len(tool_response_ids) >= self.response_length:
                break

            prompt_ids += tool_response_ids
            response_mask += [0] * len(tool_response_ids)
            user_turns += 1

        # Split the accumulated ids back into (prompt, response) at the point
        # where generation started.
        response_ids = prompt_ids[-len(response_mask) :]
        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]

        output = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[: self.response_length],
            response_mask=response_mask[: self.response_length],
            num_turns=user_turns + assistant_turns + 1,
            metrics=metrics,
        )
        return output

    async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:
        """Call tool and return tool response.

        On failure the exception object is *returned* (not raised) so the
        caller can detect it via ``isinstance(..., Exception)`` after
        ``asyncio.gather``. Successful responses are truncated to
        ``max_tool_response_length`` according to ``tool_response_truncate_side``.
        """
        tool, instance_id = None, None
        try:
            # TODO: append malformed tool_call to the prompt: invalid function name or arguments
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
            tool = self.tools[tool_name]
            instance_id = await tool.create()
            tool_response, _, _ = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.exception(f"Error when executing tool: {e}")
            return e
        finally:
            if tool and instance_id:
                # Best-effort cleanup: a failing release must not mask the tool
                # result (or the original error) by raising out of the finally.
                try:
                    await tool.release(instance_id)
                except Exception:
                    logger.exception("Error when releasing tool instance")

        if len(tool_response) > self.max_tool_response_length:
            if self.tool_response_truncate_side == "left":
                # Keep the leading part of the response.
                tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
            elif self.tool_response_truncate_side == "right":
                # Keep the trailing part of the response.
                tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
            else:
                # Keep both ends, drop the middle.
                length = self.max_tool_response_length // 2
                tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]

        return {
            "role": "tool",
            "content": tool_response,
        }