#!/usr/bin/env python
# coding=utf-8

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .. import is_torch_available
from ..utils import logging as transformers_logging
from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
from .python_interpreter import evaluate_python_code
from .tools import (
    DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
    Tool,
    get_tool_description_with_args,
    load_tool,
)


if is_pygments_available():
    from pygments import highlight
    from pygments.formatters import Terminal256Formatter
    from pygments.lexers import PythonLexer


class CustomFormatter(logging.Formatter):
    """
    Logging formatter that wraps the bare log message in an ANSI color chosen from the record's level.

    Besides the standard levels, the numeric levels 31, 32 and 33 are mapped here; the agents in this module
    log at those levels for code listings (31), outputs/answers (32) and the task statement (33).
    """

    grey = "\x1b[38;20m"
    bold_yellow = "\x1b[33;1m"
    red = "\x1b[31;20m"
    green = "\x1b[32;20m"
    bold_red = "\x1b[31;1m"
    bold_white = "\x1b[37;1m"
    reset = "\x1b[0m"
    # Named `message_format` rather than `format` so that it is not silently overwritten by the `format` method.
    message_format = "%(message)s"
    FORMATS = {
        logging.DEBUG: grey + message_format + reset,
        logging.INFO: message_format,
        logging.WARNING: bold_yellow + message_format + reset,
        31: reset + message_format + reset,
        32: green + message_format + reset,
        33: bold_white + message_format + reset,
        logging.ERROR: red + message_format + reset,
        logging.CRITICAL: bold_red + message_format + reset,
    }

    def format(self, record):
        """Format `record` with the template registered for its level (the `logging` default if there is none)."""
        log_fmt = self.FORMATS.get(record.levelno)
        formatter = logging.Formatter(log_fmt)
        return formatter.format(record)


logger = transformers_logging.get_logger(__name__)
logger.propagate = False
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
logger.addHandler(ch)


def parse_json_blob(json_blob: str) -> Dict[str, str]:
    """
    Extract and decode the JSON object spanning from the first `{` to the last `}` of `json_blob`.

    Backslash-escaped double quotes are replaced by single quotes before decoding, and decoding is done with
    `strict=False` so that control characters are accepted inside strings.

    Args:
        json_blob (`str`): Text containing a JSON object, possibly surrounded by other text.

    Returns:
        The decoded JSON data.

    Raises:
        ValueError: If no `{`...`}` span is found or if the span is not valid JSON.
    """
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = json_blob.rfind("}")
        if first_accolade_index == -1 or last_accolade_index == -1:
            raise ValueError("no JSON object delimited by '{' and '}' was found")
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace('\\"', "'")
        return json.loads(json_blob, strict=False)
    except json.JSONDecodeError as e:
        place = e.pos
        # Clamp the start: a negative start would be read from the end of the string and give a wrong excerpt.
        excerpt = json_blob[max(place - 4, 0) : place + 5]
        raise ValueError(
            f"The JSON blob you used is invalid: due to the following error: {e}. JSON blob was: {json_blob}, decoding failed at '{excerpt}'."
        ) from e
    except Exception as e:
        raise ValueError(f"Error in parsing the JSON blob: {e}") from e


def parse_code_blob(code_blob: str) -> str:
    """
    Return the stripped content of the first triple-backtick fenced block (optionally tagged `py`/`python`).

    Args:
        code_blob (`str`): Text containing a fenced code block.

    Raises:
        ValueError: If no fenced block matching the pattern is found.
    """
    try:
        pattern = r"```(?:py|python)?\n(.*?)```"
        match = re.search(pattern, code_blob, re.DOTALL)
        return match.group(1).strip()
    except Exception as e:
        raise ValueError(
            f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
        ) from e


def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
    """
    Parse a JSON tool call and return its `action` and `action_input` entries.

    Code fences are removed from `json_blob` before it is handed to `parse_json_blob`.

    Raises:
        ValueError: If the blob cannot be parsed or if one of the two keys is missing.
    """
    json_blob = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(json_blob)
    if "action" in tool_call and "action_input" in tool_call:
        return tool_call["action"], tool_call["action_input"]
    else:
        raise ValueError(
            f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
        )


def parse_text_tool_call(text: str) -> Tuple[str, Union[str, Dict[str, str]]]:
    """
    Expects a text in the format: 'Action:', 'Action input:', 'Observation:'. 'Action input:' contains a json string with input arguments.

    Returns:
        The tool name and its input: the parsed JSON data when the input contains a `{`, otherwise the stripped
        input string with double quotes removed.

    Raises:
        ValueError: If the text does not follow the expected format.
    """
    try:
        if "Observation:" in text:
            text = text.split("Observation:")[0]
        if "Action:" in text:
            text = text.split("Action:")[1]
        tool_name, tool_input = text.split("Action input:")
        if "{" in tool_input:
            tool_input = parse_json_blob(tool_input)
        else:
            tool_input = tool_input.strip().replace('"', "")
        return tool_name.strip().replace('"', "").replace("\\", ""), tool_input
    except Exception as e:
        raise ValueError(
            f"Error in parsing the text tool call: {e}. Be sure to provide the correct format. DO NOT repeat your previous incorrect tool call."
        ) from e


def to_text(input: Union[List[Dict[str, str]], Dict[str, str], str]) -> str:
    """
    Flatten a message, a list of messages, or a plain string to text.

    A list is rendered as the newline-joined `"content"` of its messages, a dict as its `"content"`, and anything
    else is returned unchanged.
    """
    if isinstance(input, list):
        return "\n".join([m["content"] for m in input])
    elif isinstance(input, dict):
        return input["content"]
    else:
        return input


# Module-level cache of the default tools, filled on the first call to `Toolbox.add_base_tools`.
HUGGINGFACE_DEFAULT_TOOLS = {}
_tools_are_initialized = False


class Toolbox:
    """
    The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
    manage them.

    Args:
        tools (`List[Tool]`):
            The list of tools to instantiate the toolbox with
        add_base_tools (`bool`, *optional*, defaults to `False`):
            Whether to add the tools available within `transformers` to the toolbox.
    """

    def __init__(self, tools: List[Tool], add_base_tools: bool = False):
        self._tools = {tool.name: tool for tool in tools}
        if add_base_tools:
            self.add_base_tools()
        self._load_tools_if_needed()

    def add_base_tools(self, add_python_interpreter: bool = False):
        """
        Add the default tools to the toolbox, setting them up once per process on first use.

        Args:
            add_python_interpreter (`bool`, *optional*, defaults to `False`):
                Whether to also add the default tool named `python_interpreter`.

        Raises:
            KeyError: If a default tool has the same name as a tool already in the toolbox (see `add_tool`).
        """
        global _tools_are_initialized
        global HUGGINGFACE_DEFAULT_TOOLS
        if not _tools_are_initialized:
            HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools(logger)
            _tools_are_initialized = True
        for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
            if tool.name != "python_interpreter" or add_python_interpreter:
                self.add_tool(tool)
        self._load_tools_if_needed()

    @property
    def tools(self) -> Dict[str, Tool]:
        """Get all tools currently in the toolbox"""
        return self._tools

    def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
        """
        Returns the description of all tools in the toolbox

        Args:
            tool_description_template (`str`, *optional*):
                The template to use to describe the tools. If not provided, the default template will be used.
        """
        return "\n".join(
            [get_tool_description_with_args(tool, tool_description_template) for tool in self._tools.values()]
        )

    def add_tool(self, tool: Tool):
        """
        Adds a tool to the toolbox

        Args:
            tool (`Tool`):
                The tool to add to the toolbox.

        Raises:
            KeyError: If a tool with the same name is already in the toolbox.
        """
        if tool.name in self._tools:
            raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
        self._tools[tool.name] = tool

    def remove_tool(self, tool_name: str):
        """
        Removes a tool from the toolbox

        Args:
            tool_name (`str`):
                The tool to remove from the toolbox.

        Raises:
            KeyError: If no tool with that name is in the toolbox.
        """
        if tool_name not in self._tools:
            raise KeyError(
                f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
            )
        del self._tools[tool_name]

    def update_tool(self, tool: Tool):
        """
        Updates a tool in the toolbox according to its name.

        Args:
            tool (`Tool`):
                The tool to update to the toolbox.

        Raises:
            KeyError: If no tool with that name is in the toolbox.
        """
        if tool.name not in self._tools:
            raise KeyError(
                f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
            )
        self._tools[tool.name] = tool

    def clear_toolbox(self):
        """Clears the toolbox"""
        self._tools = {}

    def _load_tools_if_needed(self):
        """Replace every entry that is not a `Tool` instance by the tool loaded from its `repo_id` (or `task`)."""
        # Iterate over a snapshot since entries are reassigned inside the loop.
        for name, tool in list(self._tools.items()):
            if not isinstance(tool, Tool):
                task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
                self._tools[name] = load_tool(task_or_repo_id)

    def __repr__(self):
        toolbox_description = "Toolbox contents:\n"
        for tool in self._tools.values():
            toolbox_description += f"\t{tool.name}: {tool.description}\n"
        return toolbox_description


def format_prompt_with_tools(toolbox: Toolbox, prompt_template: str, tool_description_template: str) -> str:
    """
    Fill the tool placeholders of `prompt_template` with the descriptions and names of the toolbox's tools.

    Args:
        toolbox (`Toolbox`): The toolbox whose tools are described.
        prompt_template (`str`): The prompt containing the placeholders.
        tool_description_template (`str`): The template used to describe each tool.
    """
    # NOTE(review): both placeholders below are the same literal "<>", so the second branch only triggers when the
    # tool descriptions themselves contain "<>". The placeholder names look truncated - confirm against the
    # templates defined in `.prompts` before relying on this.
    tool_descriptions = toolbox.show_tool_descriptions(tool_description_template)
    prompt = prompt_template.replace("<>", tool_descriptions)
    if "<>" in prompt:
        tool_names = [f"'{tool_name}'" for tool_name in toolbox.tools.keys()]
        prompt = prompt.replace("<>", ", ".join(tool_names))
    return prompt


class AgentError(Exception):
    """Base class for other agent-related exceptions"""

    def __init__(self, message):
        super().__init__(message)
        self.message = message


class AgentParsingError(AgentError):
    """Exception raised for errors in parsing in the agent"""

    pass


class AgentExecutionError(AgentError):
    """Exception raised for errors in execution in the agent"""

    pass


class AgentMaxIterationsError(AgentError):
    """Exception recorded when the agent reaches its maximum number of iterations"""

    pass


class AgentGenerationError(AgentError):
    """Exception raised for errors in generation in the agent"""

    pass


class Agent:
    """
    Base class of the agents: holds the LLM engine, the toolbox, the formatted system prompt and the run logs.

    Args:
        tools (`List[Tool]` or `Toolbox`): The tools available to the agent.
        llm_engine (`Callable`): Called with a list of messages and a `stop_sequences` keyword argument.
        system_prompt (`str`): The system prompt template, formatted with the tools of the toolbox.
        tool_description_template (`str`, *optional*): Template used to describe each tool; falls back to
            `DEFAULT_TOOL_DESCRIPTION_TEMPLATE` when empty or `None`.
        additional_args (`dict`, *optional*): Stored as `self.additional_args`; a new empty dict when `None`.
        max_iterations (`int`, *optional*, defaults to 6): Stored as `self.max_iterations`.
        tool_parser (`Callable`): Function turning the action text of the LLM into a tool name and arguments.
        add_base_tools (`bool`, *optional*, defaults to `False`): Whether to add the default tools.
        verbose (`int`, *optional*, defaults to 0): 0, 1 and 2 set the logger level to WARNING, INFO and DEBUG.
        memory_verbose (`bool`, *optional*, defaults to `False`): Whether the reminder messages written by
            `write_inner_memory_from_logs` also include past observations and errors.
    """

    # NOTE: the `HfEngine()` default is evaluated once, when the class is defined, and is therefore shared by all
    # agents relying on the default.
    def __init__(
        self,
        tools: Union[List[Tool], Toolbox],
        llm_engine: Callable = HfEngine(),
        system_prompt=DEFAULT_REACT_JSON_SYSTEM_PROMPT,
        tool_description_template=None,
        additional_args: Optional[Dict[str, Any]] = None,
        max_iterations: int = 6,
        tool_parser=parse_json_tool_call,
        add_base_tools: bool = False,
        verbose: int = 0,
        memory_verbose: bool = False,
    ):
        self.agent_name = self.__class__.__name__
        self.llm_engine = llm_engine
        self.system_prompt_template = system_prompt
        self.tool_description_template = (
            tool_description_template if tool_description_template else DEFAULT_TOOL_DESCRIPTION_TEMPLATE
        )
        # A fresh dict per instance: a `{}` default in the signature would be shared across all agents.
        self.additional_args = additional_args if additional_args is not None else {}
        self.max_iterations = max_iterations
        self.logger = logger
        self.tool_parser = tool_parser

        if isinstance(tools, Toolbox):
            self._toolbox = tools
            if add_base_tools:
                if not is_torch_available():
                    raise ImportError("Using the base tools requires torch to be installed.")

                self._toolbox.add_base_tools(add_python_interpreter=(self.__class__ == ReactJsonAgent))
        else:
            self._toolbox = Toolbox(tools, add_base_tools=add_base_tools)

        self.system_prompt = format_prompt_with_tools(
            self._toolbox, self.system_prompt_template, self.tool_description_template
        )
        self.prompt = None
        self.logs = []
        self.task = None
        self.memory_verbose = memory_verbose

        if verbose == 0:
            logger.setLevel(logging.WARNING)
        elif verbose == 1:
            logger.setLevel(logging.INFO)
        elif verbose == 2:
            logger.setLevel(logging.DEBUG)

    @property
    def toolbox(self) -> Toolbox:
        """Get the toolbox currently available to the agent"""
        return self._toolbox

    def initialize_for_run(self, task: str, **kwargs):
        """
        Reset the agent for a new task.

        Stores the task (followed by a mention of `kwargs` when some are given), copies `kwargs` into
        `self.state`, re-formats the system prompt from the current toolbox and restarts `self.logs` with a
        single entry holding the system prompt and the task.
        """
        self.task = task
        if len(kwargs) > 0:
            self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
        self.state = kwargs.copy()
        self.system_prompt = format_prompt_with_tools(
            self._toolbox, self.system_prompt_template, self.tool_description_template
        )
        self.logs = [{"system_prompt": self.system_prompt, "task": self.task}]
        # `Logger.warn` is deprecated in favor of `Logger.warning`.
        self.logger.warning("======== New task ========")
        self.logger.log(33, self.task)
        self.logger.debug("System prompt is as follows:")
        self.logger.debug(self.system_prompt)

    def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
        """
        Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
        that can be used as input to the LLM.

        The first two messages are the system prompt and the task taken from `self.logs[0]`. Whenever the number
        of messages is a multiple of 3 after a step has been added, a reminder of the task summarizing the tool
        calls made so far is appended as a user message.
        """
        prompt_message = {"role": MessageRole.SYSTEM, "content": self.logs[0]["system_prompt"]}
        task_message = {
            "role": MessageRole.USER,
            "content": "Task: " + self.logs[0]["task"],
        }
        memory = [prompt_message, task_message]
        for i, step_log in enumerate(self.logs[1:]):
            if "llm_output" in step_log:
                thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
                memory.append(thought_message)

            message_content = None
            if "error" in step_log:
                message_content = (
                    "Error: "
                    + str(step_log["error"])
                    + "\nNow let's retry: take care not to repeat previous errors! Try to adopt different approaches.\n"
                )
            elif "observation" in step_log:
                message_content = f"Observation: {step_log['observation']}"
            # A step with neither an error nor an observation has no tool response to report: skip it instead of
            # failing on an unbound name or repeating the content of the previous step.
            if message_content is not None:
                tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
                memory.append(tool_response_message)

            if len(memory) % 3 == 0:
                reminder_content = (
                    "Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
                )
                reminder_content += "\nHere is a summary of your past tool calls and their results:"
                # `self.logs[0]` is the system prompt/task entry, so step number k is stored in `self.logs[k]`.
                for step_number, past_log in enumerate(self.logs[1 : i + 2], start=1):
                    reminder_content += "\nStep " + str(step_number)
                    if "tool_call" in past_log:
                        reminder_content += "\nTool call:" + str(past_log["tool_call"])
                    if self.memory_verbose:
                        if "observation" in past_log:
                            reminder_content += "\nObservation:" + str(past_log["observation"])
                        if "error" in past_log:
                            reminder_content += "\nError:" + str(past_log["error"])
                memory.append(
                    {
                        "role": MessageRole.USER,
                        "content": reminder_content,
                    }
                )
        return memory

    def extract_action(self, llm_output: str, split_token: str) -> Tuple[str, str]:
        """
        Parse action from the LLM output

        Args:
            llm_output (`str`): Output of the LLM
            split_token (`str`): Separator for the action. Should match the example in the system prompt.

        Returns:
            The rationale (text before the last `split_token`) and the action (text after it).

        Raises:
            AgentParsingError: If `split_token` is not present in `llm_output`.
        """
        try:
            split = llm_output.split(split_token)
            rationale, action = (
                split[-2],
                split[-1],
            )  # NOTE: using indexes starting from the end solves for when you have more than one split_token in the output
        except Exception as e:
            self.logger.error(e, exc_info=1)
            raise AgentParsingError(
                f"Error: No '{split_token}' token provided in your output.\nYour output:\n{llm_output}\n. Be sure to include an action, prefaced with '{split_token}'!"
            ) from e
        return rationale, action

    def execute_tool_call(self, tool_name: str, arguments: Dict[str, str]) -> Any:
        """
        Execute tool with the provided input and returns the result.
        This method replaces arguments with the actual values from the state if they refer to state variables.

        Args:
            tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
            arguments (Dict[str, str]): Arguments passed to the Tool. A plain string is passed to the tool as
                its single positional argument.

        Raises:
            AgentExecutionError: If the tool is unknown or if the tool call fails.
        """
        if tool_name not in self.toolbox.tools:
            error_msg = f"Error: unknown tool {tool_name}, should be instead one of {list(self.toolbox.tools.keys())}."
            # No exception is being handled here, so no `exc_info` (it would only log "NoneType: None").
            self.logger.error(error_msg)
            raise AgentExecutionError(error_msg)

        try:
            if isinstance(arguments, str):
                observation = self.toolbox.tools[tool_name](arguments)
            else:
                # If a value is the name of a state variable like "image.png", replace it with the actual value.
                # The substitution is done on a copy: `arguments` is also stored in the logs by the callers and
                # must keep the names chosen by the LLM.
                resolved_arguments = {
                    key: (self.state[value] if isinstance(value, str) and value in self.state else value)
                    for key, value in arguments.items()
                }
                observation = self.toolbox.tools[tool_name](**resolved_arguments)
            return observation
        except Exception as e:
            raise AgentExecutionError(
                f"Error in tool call execution: {e}\nYou should only use this tool with a correct input.\n"
                f"As a reminder, this tool's description is the following:\n{get_tool_description_with_args(self.toolbox.tools[tool_name])}"
            ) from e

    def log_code_action(self, code_action: str) -> None:
        """Log the code about to be executed, syntax-highlighted when pygments is available."""
        self.logger.warning("==== Agent is executing the code below:")
        if is_pygments_available():
            self.logger.log(
                31, highlight(code_action, PythonLexer(ensurenl=False), Terminal256Formatter(style="nord"))
            )
        else:
            self.logger.log(31, code_action)
        self.logger.warning("====")

    def run(self, **kwargs):
        """To be implemented in the child class"""
        raise NotImplementedError


class CodeAgent(Agent):
    """
    A class for an agent that solves the given task using a single block of code. It plans all its actions, then executes all in one shot.
    """

    def __init__(
        self,
        tools: List[Tool],
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
        **kwargs,
    ):
        super().__init__(
            tools=tools,
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
            **kwargs,
        )

        if not is_pygments_available():
            transformers_logging.warning_once(
                logger,
                "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
                "CodeAgent.",
            )

        self.python_evaluator = evaluate_python_code

    def parse_code_blob(self, result: str) -> str:
        """
        Override this method if you want to change the way the code is
        cleaned in the `run` method.
        """
        return parse_code_blob(result)

    def run(self, task: str, return_generated_code: bool = False, **kwargs):
        """
        Runs the agent for the given task.

        Args:
            task (`str`): The task to perform
            return_generated_code (`bool`, *optional*, defaults to `False`): Whether to return the generated code instead of running it
            kwargs (additional keyword arguments, *optional*):
                Any keyword argument to send to the agent when evaluating the code.

        Returns:
            The raw LLM output when `return_generated_code` is set, otherwise the output of the evaluated code,
            or an error message string if parsing or executing the code failed.

        Example:

        ```py
        from transformers.agents import CodeAgent, PythonInterpreterTool

        python_interpreter = PythonInterpreterTool()
        agent = CodeAgent(tools=[python_interpreter])
        agent.run("What is the result of 2 power 3.7384?")
        ```
        """
        self.initialize_for_run(task, **kwargs)

        # Run LLM
        prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
        task_message = {
            "role": MessageRole.USER,
            "content": "Task: " + self.task,
        }

        self.prompt = [prompt_message, task_message]
        self.logger.info("====Executing with this prompt====")
        self.logger.info(self.prompt)
        # NOTE(review): an empty stop sequence looks like a truncated token - confirm the intended value.
        llm_output = self.llm_engine(self.prompt, stop_sequences=[""])

        if return_generated_code:
            return llm_output

        # Parse
        _, code_action = self.extract_action(llm_output=llm_output, split_token="Code:")

        try:
            code_action = self.parse_code_blob(code_action)
        except Exception as e:
            error_msg = f"Error in code parsing: {e}. Be sure to provide correct code"
            self.logger.error(error_msg, exc_info=1)
            return error_msg

        # Execute
        self.log_code_action(code_action)
        try:
            available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
            output = self.python_evaluator(code_action, available_tools, state=self.state)
            self.logger.info(self.state["print_outputs"])
            return output
        except Exception as e:
            error_msg = f"Error in execution: {e}. Be sure to provide correct code."
            self.logger.error(error_msg, exc_info=1)
            return error_msg


class ReactAgent(Agent):
    """
    This agent that solves the given task step by step, using the ReAct framework:
    While the objective is not reached, the agent will perform a cycle of thinking and acting.
    The action will be parsed from the LLM output: it consists in calls to tools from the toolbox, with arguments chosen by the LLM engine.

    `run` relies on a `step` method, which is defined by the child classes.
    """

    def __init__(
        self,
        tools: List[Tool],
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
        **kwargs,
    ):
        super().__init__(
            tools=tools,
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
            **kwargs,
        )
        if "final_answer" not in self._toolbox.tools:
            self._toolbox.add_tool(FinalAnswerTool())

    def run(self, task: str, **kwargs):
        """
        Runs the agent for the given task.

        Steps are run until one of them returns a final answer or `max_iterations` is reached. In the latter
        case, the LLM engine is asked to answer directly from the agent's memory.

        Args:
            task (`str`): The task to perform

        Example:

        ```py
        from transformers.agents import ReactJsonAgent, PythonInterpreterTool

        python_interpreter = PythonInterpreterTool()
        agent = ReactJsonAgent(tools=[python_interpreter])
        agent.run("What is the result of 2 power 3.7384?")
        ```
        """
        self.initialize_for_run(task, **kwargs)

        final_answer = None
        iteration = 0
        while final_answer is None and iteration < self.max_iterations:
            try:
                final_answer = self.step()
            except AgentError as e:
                # The error is stored on the current step so that the next prompt reports it to the LLM.
                self.logger.error(e, exc_info=1)
                self.logs[-1]["error"] = e
            finally:
                iteration += 1

        if final_answer is None and iteration == self.max_iterations:
            error_message = "Reached max iterations."
            self.logs.append({"error": AgentMaxIterationsError(error_message)})
            # No exception is being handled here, so no `exc_info` (it would only log "NoneType: None").
            self.logger.error(error_message)

            self.prompt = [
                {
                    "role": MessageRole.SYSTEM,
                    "content": "An agent tried to answer a user query but it failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:",
                }
            ]
            # The first memory message is the agent's own system prompt, replaced by the one above.
            self.prompt += self.write_inner_memory_from_logs()[1:]
            self.prompt += [
                {
                    "role": MessageRole.USER,
                    "content": f"Based on the above, please provide an answer to the following user request:\n{task}",
                }
            ]
            try:
                final_answer = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
            except Exception as e:
                final_answer = f"Error in generating final llm output: {e}."

        return final_answer


class ReactJsonAgent(ReactAgent):
    """
    This agent that solves the given task step by step, using the ReAct framework:
    While the objective is not reached, the agent will perform a cycle of thinking and acting.
    The tool calls will be formulated by the LLM in JSON format, then parsed and executed.
    """

    def __init__(
        self,
        tools: List[Tool],
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
        **kwargs,
    ):
        super().__init__(
            tools=tools,
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
            **kwargs,
        )

    def step(self):
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        The errors are raised here, they are caught and logged in the run() method.

        Returns:
            The final answer if the `final_answer` tool was called, `None` otherwise.
        """
        agent_memory = self.write_inner_memory_from_logs()
        self.logs[-1]["agent_memory"] = agent_memory.copy()
        self.prompt = agent_memory
        self.logger.debug("===== New step =====")

        # Add new step in logs
        self.logs.append({})

        self.logger.info("===== Calling LLM with this last message: =====")
        self.logger.info(self.prompt[-1])

        try:
            llm_output = self.llm_engine(self.prompt, stop_sequences=["Observation:"])
        except Exception as e:
            raise AgentGenerationError(f"Error in generating llm output: {e}.") from e
        self.logger.debug("===== Output message of the LLM: =====")
        self.logger.debug(llm_output)
        self.logs[-1]["llm_output"] = llm_output

        # Parse
        self.logger.debug("===== Extracting action =====")
        rationale, action = self.extract_action(llm_output=llm_output, split_token="Action:")

        try:
            tool_name, arguments = self.tool_parser(action)
        except Exception as e:
            raise AgentParsingError(f"Could not parse the given action: {e}.") from e

        self.logs[-1]["rationale"] = rationale
        self.logs[-1]["tool_call"] = {"tool_name": tool_name, "tool_arguments": arguments}

        # Execute
        self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
        if tool_name == "final_answer":
            # Arguments without an "answer" entry are returned as they are rather than failing on a `KeyError`,
            # which is not an `AgentError` and would abort the whole run.
            if isinstance(arguments, dict) and "answer" in arguments:
                answer = arguments["answer"]
            else:
                answer = arguments
            # If the answer is the name of a state variable, return the value. Only strings are looked up (as
            # in `execute_tool_call`): an unhashable answer would raise a `TypeError` on the membership test.
            if isinstance(answer, str) and answer in self.state:
                answer = self.state[answer]
            return answer
        else:
            observation = self.execute_tool_call(tool_name, arguments)
            if isinstance(observation, AgentText):
                updated_information = str(observation).strip()
            else:
                # TODO: observation naming could allow for different names of same type
                if isinstance(observation, AgentImage):
                    observation_name = "image.png"
                elif isinstance(observation, AgentAudio):
                    observation_name = "audio.mp3"
                else:
                    observation_name = "object.object"

                self.state[observation_name] = observation
                updated_information = f"Stored '{observation_name}' in memory."

            self.logger.info(updated_information)
            self.logs[-1]["observation"] = updated_information
            return None


class ReactCodeAgent(ReactAgent):
    """
    This agent that solves the given task step by step, using the ReAct framework:
    While the objective is not reached, the agent will perform a cycle of thinking and acting.
    The tool calls will be formulated by the LLM in code format, then parsed and executed.
    """

    def __init__(
        self,
        tools: List[Tool],
        llm_engine: Callable = HfEngine(),
        system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
        tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
        **kwargs,
    ):
        super().__init__(
            tools=tools,
            llm_engine=llm_engine,
            system_prompt=system_prompt,
            tool_description_template=tool_description_template,
            **kwargs,
        )

        if not is_pygments_available():
            transformers_logging.warning_once(
                logger,
                "pygments isn't installed. Installing pygments will enable color syntax highlighting in the "
                "ReactCodeAgent.",
            )

        self.python_evaluator = evaluate_python_code

    def step(self):
        """
        Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
        The errors are raised here, they are caught and logged in the run() method.

        Returns:
            The result of the executed code if one of its lines starts with `final_answer`, `None` otherwise.
        """
        agent_memory = self.write_inner_memory_from_logs()
        self.logs[-1]["agent_memory"] = agent_memory.copy()
        self.prompt = agent_memory.copy()

        self.logger.debug("===== New step =====")

        # Add new step in logs
        self.logs.append({})

        self.logger.info("===== Calling LLM with these last messages: =====")
        self.logger.info(self.prompt[-2:])

        try:
            # NOTE(review): the empty first stop sequence looks like a truncated token - confirm the intended value.
            llm_output = self.llm_engine(self.prompt, stop_sequences=["", "Observation:"])
        except Exception as e:
            raise AgentGenerationError(f"Error in generating llm output: {e}.") from e

        self.logger.debug("===== Output message of the LLM: =====")
        self.logger.debug(llm_output)
        self.logs[-1]["llm_output"] = llm_output

        # Parse
        self.logger.debug("===== Extracting action =====")
        rationale, raw_code_action = self.extract_action(llm_output=llm_output, split_token="Code:")

        try:
            code_action = parse_code_blob(raw_code_action)
        except Exception as e:
            error_msg = f"Error in code parsing: {e}. Make sure to provide correct code"
            raise AgentParsingError(error_msg) from e

        self.logs[-1]["rationale"] = rationale
        self.logs[-1]["tool_call"] = {"tool_name": "code interpreter", "tool_arguments": code_action}

        # Execute
        self.log_code_action(code_action)
        try:
            available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
            result = self.python_evaluator(code_action, available_tools, state=self.state)
            information = self.state["print_outputs"]
            self.logger.warning("Print outputs:")
            self.logger.log(32, information)
            self.logs[-1]["observation"] = information
        except Exception as e:
            error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
            if "'dict' object has no attribute 'read'" in str(e):
                error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
            raise AgentExecutionError(error_msg) from e
        # A line starting with `final_answer` means the code produced the final answer: return its result.
        for line in code_action.split("\n"):
            if line.startswith("final_answer"):
                self.logger.warning(">>> Final answer:")
                self.logger.log(32, result)
                return result
        return None