Unverified commit b3818805, authored by Aymeric Roucher and committed by GitHub
Browse files

Agents planning (#31702)

* Allow planning for agents
parent 0fdea860
...@@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available ...@@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText from .agent_types import AgentAudio, AgentImage, AgentText
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole from .llm_engine import HfEngine, MessageRole
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE,
SYSTEM_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import ( from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE, DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
...@@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]: ...@@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str:
    """
    Extract the code contained in the first fenced code block of an LLM output.

    The block must be opened by three backticks, optionally followed by `py` or `python`,
    and closed by three backticks on their own line start.

    Args:
        code_blob (`str`): Raw LLM output expected to contain a fenced code block.

    Returns:
        `str`: The content of the first matching block, stripped of surrounding whitespace.

    Raises:
        ValueError: If no matching block is found, or if the search itself fails
            (for instance because `code_blob` is not a string).
    """
    pattern = r"```(?:py|python)?\n(.*?)\n```"
    cause = None
    try:
        match = re.search(pattern, code_blob, re.DOTALL)
    except Exception as e:  # e.g. code_blob is not a string: still reported as a ValueError
        match, cause = None, e

    if match is not None:
        return match.group(1).strip()

    # Give the LLM an actionable reason instead of "'NoneType' object has no attribute 'group'".
    reason = cause if cause is not None else "no code block matching the expected pattern was found"
    raise ValueError(
        f"""
The code blob you used is invalid: due to the following error: {reason}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>"""
    ) from cause
...@@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]: ...@@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
tool_call = parse_json_blob(json_blob) tool_call = parse_json_blob(json_blob)
if "action" in tool_call and "action_input" in tool_call: if "action" in tool_call and "action_input" in tool_call:
return tool_call["action"], tool_call["action_input"] return tool_call["action"], tool_call["action_input"]
elif "action" in tool_call:
return tool_call["action"], None
else: else:
raise ValueError( raise ValueError(
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}" f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
...@@ -208,7 +229,7 @@ class Toolbox: ...@@ -208,7 +229,7 @@ class Toolbox:
The tool to add to the toolbox. The tool to add to the toolbox.
""" """
if tool.name in self._tools: if tool.name in self._tools:
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.") raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
self._tools[tool.name] = tool self._tools[tool.name] = tool
def remove_tool(self, tool_name: str): def remove_tool(self, tool_name: str):
...@@ -359,12 +380,8 @@ class Agent: ...@@ -359,12 +380,8 @@ class Agent:
"""Get the toolbox currently available to the agent""" """Get the toolbox currently available to the agent"""
return self._toolbox return self._toolbox
def initialize_for_run(self, task: str, **kwargs): def initialize_for_run(self):
self.token_count = 0 self.token_count = 0
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools( self.system_prompt = format_prompt_with_tools(
self._toolbox, self._toolbox,
self.system_prompt_template, self.system_prompt_template,
...@@ -380,7 +397,7 @@ class Agent: ...@@ -380,7 +397,7 @@ class Agent:
self.logger.debug("System prompt is as follows:") self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt) self.logger.debug(self.system_prompt)
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]: def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
""" """
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM. that can be used as input to the LLM.
...@@ -390,43 +407,51 @@ class Agent: ...@@ -390,43 +407,51 @@ class Agent:
"role": MessageRole.USER, "role": MessageRole.USER,
"content": "Task: " + self.logs[0]["task"], "content": "Task: " + self.logs[0]["task"],
} }
if summary_mode:
memory = [task_message]
else:
memory = [prompt_message, task_message] memory = [prompt_message, task_message]
for i, step_log in enumerate(self.logs[1:]): for i, step_log in enumerate(self.logs[1:]):
if "llm_output" in step_log: if "llm_output" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"} thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
memory.append(thought_message)
if "facts" in step_log:
thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
}
memory.append(thought_message)
if "plan" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
memory.append(thought_message) memory.append(thought_message)
if "tool_call" in step_log and summary_mode:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
}
memory.append(tool_call_message)
if "task" in step_log:
tool_call_message = {
"role": MessageRole.USER,
"content": "New task:\n" + step_log["task"],
}
memory.append(tool_call_message)
if "error" in step_log or "observation" in step_log:
if "error" in step_log: if "error" in step_log:
message_content = ( message_content = (
"Error: " f"[OUTPUT OF STEP {i}] Error: "
+ str(step_log["error"]) + str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n" + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
) )
elif "observation" in step_log: elif "observation" in step_log:
message_content = f"Observation: {step_log['observation']}" message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content} tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
memory.append(tool_response_message) memory.append(tool_response_message)
if len(memory) % 3 == 0:
reminder_content = (
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
)
reminder_content += "\nHere is a summary of your past tool calls and their results:"
for j in range(i + 1):
reminder_content += "\nStep " + str(j + 1)
if "tool_call" in self.logs[j]:
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
if self.memory_verbose:
if "observation" in self.logs[j]:
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
if "error" in self.logs[j]:
reminder_content += "\nError:" + str(self.logs[j]["error"])
memory.append(
{
"role": MessageRole.USER,
"content": reminder_content,
}
)
return memory return memory
def get_succinct_logs(self): def get_succinct_logs(self):
...@@ -459,7 +484,7 @@ class Agent: ...@@ -459,7 +484,7 @@ class Agent:
This method replaces arguments with the actual values from the state if they refer to state variables. This method replaces arguments with the actual values from the state if they refer to state variables.
Args: Args:
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox). tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool. arguments (Dict[str, str]): Arguments passed to the Tool.
""" """
if tool_name not in self.toolbox.tools: if tool_name not in self.toolbox.tools:
...@@ -559,7 +584,11 @@ class CodeAgent(Agent): ...@@ -559,7 +584,11 @@ class CodeAgent(Agent):
agent.run("What is the result of 2 power 3.7384?") agent.run("What is the result of 2 power 3.7384?")
``` ```
""" """
self.initialize_for_run(task, **kwargs) self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.initialize_for_run()
# Run LLM # Run LLM
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt} prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
...@@ -598,7 +627,8 @@ class CodeAgent(Agent): ...@@ -598,7 +627,8 @@ class CodeAgent(Agent):
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools} available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
output = self.python_evaluator( output = self.python_evaluator(
code_action, code_action,
available_tools, static_tools=available_tools,
custom_tools={},
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
...@@ -623,6 +653,7 @@ class ReactAgent(Agent): ...@@ -623,6 +653,7 @@ class ReactAgent(Agent):
llm_engine: Callable = HfEngine(), llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
...@@ -632,6 +663,7 @@ class ReactAgent(Agent): ...@@ -632,6 +663,7 @@ class ReactAgent(Agent):
tool_description_template=tool_description_template, tool_description_template=tool_description_template,
**kwargs, **kwargs,
) )
self.planning_interval = planning_interval
def provide_final_answer(self, task) -> str: def provide_final_answer(self, task) -> str:
""" """
...@@ -655,11 +687,13 @@ class ReactAgent(Agent): ...@@ -655,11 +687,13 @@ class ReactAgent(Agent):
except Exception as e: except Exception as e:
return f"Error in generating final llm output: {e}." return f"Error in generating final llm output: {e}."
def run(self, task: str, stream: bool = False, **kwargs): def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
""" """
Runs the agent for the given task. Runs the agent for the given task.
Args: Args:
task (`str`): The task to perform task (`str`): The task to perform
Example: Example:
```py ```py
from transformers.agents import ReactCodeAgent from transformers.agents import ReactCodeAgent
...@@ -667,14 +701,23 @@ class ReactAgent(Agent): ...@@ -667,14 +701,23 @@ class ReactAgent(Agent):
agent.run("What is the result of 2 power 3.7384?") agent.run("What is the result of 2 power 3.7384?")
``` ```
""" """
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
if reset:
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream: if stream:
return self.stream_run(task, **kwargs) return self.stream_run(task)
else: else:
return self.direct_run(task, **kwargs) return self.direct_run(task)
def stream_run(self, task: str, **kwargs):
self.initialize_for_run(task, **kwargs)
def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None final_answer = None
iteration = 0 iteration = 0
while final_answer is None and iteration < self.max_iterations: while final_answer is None and iteration < self.max_iterations:
...@@ -700,13 +743,16 @@ class ReactAgent(Agent): ...@@ -700,13 +743,16 @@ class ReactAgent(Agent):
yield final_answer yield final_answer
def direct_run(self, task: str, **kwargs): def direct_run(self, task: str):
self.initialize_for_run(task, **kwargs) """
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None final_answer = None
iteration = 0 iteration = 0
while final_answer is None and iteration < self.max_iterations: while final_answer is None and iteration < self.max_iterations:
try: try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step() step_logs = self.step()
if "final_answer" in step_logs: if "final_answer" in step_logs:
final_answer = step_logs["final_answer"] final_answer = step_logs["final_answer"]
...@@ -726,6 +772,96 @@ class ReactAgent(Agent): ...@@ -726,6 +772,96 @@ class ReactAgent(Agent):
return final_answer return final_answer
def planning_step(self, task, is_first_step: bool = False, iteration: Optional[int] = None):
    """
    Used periodically by the agent to plan the next steps to reach the objective.

    On the first step, the LLM engine is asked for a survey of facts, then for an initial plan.
    On later steps, it is asked to update the facts and then the plan, based on the agent's memory.
    In both cases a `{"plan": ..., "facts": ...}` entry is appended to `self.logs`.

    Args:
        task (`str`): The task to perform
        is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
        iteration (`int`, *optional*): The number of the current step, used as an indication for the LLM.
            Required when `is_first_step` is `False`, since it is used to compute the number of remaining steps.

    Raises:
        ValueError: If `is_first_step` is `False` and `iteration` is `None`.
    """
    # Fail fast: without this check the update branch would spend an LLM call on the facts
    # before crashing on `self.max_iterations - None`.
    if not is_first_step and iteration is None:
        raise ValueError(
            "`iteration` must be provided when `is_first_step` is False: it is needed to compute the number of remaining steps."
        )

    if is_first_step:
        message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
        message_prompt_task = {
            "role": MessageRole.USER,
            "content": f"""Here is the task:
```
{task}
```
Now begin!""",
        }

        # First LLM call: survey of the known and missing facts
        answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])

        message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
        message_user_prompt_plan = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_PLAN.format(
                task=task,
                tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
                answer_facts=answer_facts,
            ),
        }
        # Second LLM call: initial plan, generation stops at the end-of-plan tag
        answer_plan = self.llm_engine(
            [message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
        )

        final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
{answer_plan}
```"""
        final_facts_redaction = f"""Here are the facts that I know so far:
```
{answer_facts}
```""".strip()
        self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
        self.logger.debug("===== Initial plan: =====")
        self.logger.debug(final_plan_redaction)
    else:  # update plan
        # Non-summary memory is requested here.
        # NOTE(review): the previous comment said this "will not log the plan", but plans appear to be
        # skipped only when summary_mode=True - confirm which behavior was intended.
        agent_memory = self.write_inner_memory_from_logs(summary_mode=False)

        # Redact updated facts
        facts_update_system_prompt = {
            "role": MessageRole.SYSTEM,
            "content": SYSTEM_PROMPT_FACTS_UPDATE,
        }
        facts_update_message = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_FACTS_UPDATE,
        }
        facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])

        # Redact updated plan
        plan_update_message = {
            "role": MessageRole.SYSTEM,
            "content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
        }
        plan_update_message_user = {
            "role": MessageRole.USER,
            "content": USER_PROMPT_PLAN_UPDATE.format(
                task=task,
                tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
                facts_update=facts_update,
                remaining_steps=(self.max_iterations - iteration),
            ),
        }
        plan_update = self.llm_engine(
            [plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
        )

        # Log final facts and plan
        final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
        final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
        self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
        self.logger.debug("===== Updated plan: =====")
        self.logger.debug(final_plan_redaction)
class ReactJsonAgent(ReactAgent): class ReactJsonAgent(ReactAgent):
""" """
...@@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent): ...@@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine: Callable = HfEngine(), llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT, system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
...@@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent): ...@@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template, tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs, **kwargs,
) )
...@@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent): ...@@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}") self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer": if tool_name == "final_answer":
if isinstance(arguments, dict): if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"] answer = arguments["answer"]
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else:
answer = arguments
else: else:
answer = arguments answer = arguments
if answer in self.state: # if the answer is a state variable, return the value
answer = self.state[answer]
current_step_logs["final_answer"] = answer current_step_logs["final_answer"] = answer
return current_step_logs return current_step_logs
else: else:
...@@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent): ...@@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT, system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE, tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: Optional[List[str]] = None, additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs, **kwargs,
): ):
super().__init__( super().__init__(
...@@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent): ...@@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
llm_engine=llm_engine, llm_engine=llm_engine,
system_prompt=system_prompt, system_prompt=system_prompt,
tool_description_template=tool_description_template, tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs, **kwargs,
) )
...@@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent): ...@@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports)) self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.available_tools = { self.custom_tools = {}
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
} # This list can be augmented by the code agent creating some new functions
def step(self): def step(self):
""" """
...@@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent): ...@@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
try: try:
result = self.python_evaluator( result = self.python_evaluator(
code_action, code_action,
tools=self.available_tools, static_tools={
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
},
custom_tools=self.custom_tools,
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
...@@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent): ...@@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
self.logger.log(32, information) self.logger.log(32, information)
current_step_logs["observation"] = information current_step_logs["observation"] = information
except Exception as e: except Exception as e:
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}" error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e): if "'dict' object has no attribute 'read'" in str(e):
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string." error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
raise AgentExecutionError(error_msg) raise AgentExecutionError(error_msg)
......
...@@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool): ...@@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
def forward(self, code): def forward(self, code):
output = str( output = str(
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports) evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
) )
return output return output
......
...@@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task: ...@@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables. 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>> 8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
9. Don't give up! You're in charge of solving the task, not providing directions to solve it. 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
""" """
# System prompt asking the model for a preparatory survey of facts about a task,
# structured under three fixed headings (given / to look up / to derive).
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
---
### 1. Facts given in the task
List here the specific facts given in the task that could help you (there might be nothing here).
### 2. Facts to look up
List here any facts that we may need to look up.
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
### 3. Facts to derive
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
### 1. Facts given in the task
### 2. Facts to look up
### 3. Facts to derive
Do not add anything else."""
# System prompt asking the model for a high-level, step-by-step plan that ends with the <end_plan> tag.
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
# User-message template for the initial plan request.
# `str.format` placeholders: {task}, {tool_descriptions}, {answer_facts}.
USER_PROMPT_PLAN = """
Here is your task:
Task:
```
{task}
```
Your plan can leverage any of these tools:
{tool_descriptions}
List of facts that you know:
```
{answer_facts}
```
Now begin! Write your plan below."""
# System prompt asking the model to rebuild the list of facts from a task and the history of attempts.
SYSTEM_PROMPT_FACTS_UPDATE = """
You are a world expert at gathering known and unknown facts based on a conversation.
Below you will find a task, and a history of attempts made to solve the task. You will have to produce a list of these:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Find the task and history below."""
# User message appended after the history to request the updated list of facts under four fixed headings.
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
But since then, in your previous steps you may have learned useful new facts or invalidated some false ones.
Please update your list of facts based on the previous history, and provide these headings:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Now write your new list of facts below."""
# System-prompt template for a plan update; the record of past attempts is expected to follow it.
# `str.format` placeholder: {task}.
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
You have been given a task:
```
{task}
```
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
If the previous tries so far have met some success, you can make an updated plan based on these actions.
If you are stalled, you can make a completely new plan starting from scratch.
"""
# User-message template requesting an updated plan that ends with the <end_plan> tag.
# `str.format` placeholders: {task}, {tool_descriptions}, {facts_update}, {remaining_steps}.
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
```
{task}
```
You have access to these tools:
{tool_descriptions}
Here is the up to date list of facts that you know:
```
{facts_update}
```
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
Beware that you have {remaining_steps} steps remaining.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
Now write your new plan below."""
# Template for the first-person text wrapping an updated plan.
# `str.format` placeholders: {task}, {plan_update}.
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
```
{task}
```
Here is my new/updated plan of action to solve the task:
```
{plan_update}
```"""
...@@ -18,8 +18,17 @@ import ast ...@@ -18,8 +18,17 @@ import ast
import builtins import builtins
import difflib import difflib
from collections.abc import Mapping from collections.abc import Mapping
from importlib import import_module
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import numpy as np
from ..utils import is_pandas_available
if is_pandas_available():
import pandas as pd
class InterpreterError(ValueError): class InterpreterError(ValueError):
""" """
...@@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [ ...@@ -50,7 +59,8 @@ LIST_SAFE_MODULES = [
"unicodedata", "unicodedata",
] ]
PRINT_OUTPUTS = "" PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
class BreakException(Exception): class BreakException(Exception):
...@@ -75,8 +85,8 @@ def get_iterable(obj): ...@@ -75,8 +85,8 @@ def get_iterable(obj):
raise InterpreterError("Object is not iterable") raise InterpreterError("Object is not iterable")
def evaluate_unaryop(expression, state, tools): def evaluate_unaryop(expression, state, static_tools, custom_tools):
operand = evaluate_ast(expression.operand, state, tools) operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
if isinstance(expression.op, ast.USub): if isinstance(expression.op, ast.USub):
return -operand return -operand
elif isinstance(expression.op, ast.UAdd): elif isinstance(expression.op, ast.UAdd):
...@@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools): ...@@ -89,25 +99,25 @@ def evaluate_unaryop(expression, state, tools):
raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.") raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
    """
    Build a Python callable from a lambda AST node.

    The returned function binds its positional arguments to the lambda's parameter names
    in a copy of `state`, then evaluates the lambda body with `evaluate_ast`.
    """
    param_names = [param.arg for param in lambda_expression.args.args]

    def lambda_func(*values):
        # Work on a copy so that argument bindings do not leak into the enclosing state
        call_state = state.copy()
        call_state.update(zip(param_names, values))
        return evaluate_ast(lambda_expression.body, call_state, static_tools, custom_tools)

    return lambda_func
def evaluate_while(while_loop, state, tools): def evaluate_while(while_loop, state, static_tools, custom_tools):
max_iterations = 1000 max_iterations = 1000
iterations = 0 iterations = 0
while evaluate_ast(while_loop.test, state, tools): while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
for node in while_loop.body: for node in while_loop.body:
try: try:
evaluate_ast(node, state, tools) evaluate_ast(node, state, static_tools, custom_tools)
except BreakException: except BreakException:
return None return None
except ContinueException: except ContinueException:
...@@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools): ...@@ -118,11 +128,11 @@ def evaluate_while(while_loop, state, tools):
return None return None
def create_function(func_def, state, tools): def create_function(func_def, state, static_tools, custom_tools):
def new_func(*args, **kwargs): def new_func(*args, **kwargs):
func_state = state.copy() func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args] arg_names = [arg.arg for arg in func_def.args.args]
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults] default_values = [evaluate_ast(d, state, static_tools, custom_tools) for d in func_def.args.defaults]
# Apply default values # Apply default values
defaults = dict(zip(arg_names[-len(default_values) :], default_values)) defaults = dict(zip(arg_names[-len(default_values) :], default_values))
...@@ -158,7 +168,7 @@ def create_function(func_def, state, tools): ...@@ -158,7 +168,7 @@ def create_function(func_def, state, tools):
result = None result = None
try: try:
for stmt in func_def.body: for stmt in func_def.body:
result = evaluate_ast(stmt, func_state, tools) result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
except ReturnException as e: except ReturnException as e:
result = e.value result = e.value
return result return result
...@@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body): ...@@ -173,25 +183,25 @@ def create_class(class_name, class_bases, class_body):
return type(class_name, tuple(class_bases), class_dict) return type(class_name, tuple(class_bases), class_dict)
def evaluate_function_def(func_def, state, static_tools, custom_tools):
    """
    Register a function defined by the interpreted code.

    The callable is built with `create_function` and stored in `custom_tools` under the function's name, so that
    later code can call it.

    Args:
        func_def (`ast.FunctionDef`):
            The function definition node.
        state (`Dict[str, Any]`):
            Current variables, forwarded to `create_function`.
        static_tools (`Dict[str, Callable]`):
            Forwarded to `create_function`.
        custom_tools (`Dict[str, Callable]`):
            Mapping updated in place with the new function.

    Returns:
        The newly created callable.
    """
    new_function = create_function(func_def, state, static_tools, custom_tools)
    custom_tools[func_def.name] = new_function
    return new_function
def evaluate_class_def(class_def, state, tools): def evaluate_class_def(class_def, state, static_tools, custom_tools):
class_name = class_def.name class_name = class_def.name
bases = [evaluate_ast(base, state, tools) for base in class_def.bases] bases = [evaluate_ast(base, state, static_tools, custom_tools) for base in class_def.bases]
class_dict = {} class_dict = {}
for stmt in class_def.body: for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef): if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools)
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
for target in stmt.targets: for target in stmt.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(stmt.value, state, tools) class_dict[target.id] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools) class_dict[target.attr] = evaluate_ast(stmt.value, state, static_tools, custom_tools)
else: else:
raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}") raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
...@@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools): ...@@ -200,17 +210,17 @@ def evaluate_class_def(class_def, state, tools):
return new_class return new_class
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): def evaluate_augassign(expression, state, static_tools, custom_tools):
# Helper function to get current value and set new value based on the target type # Helper function to get current value and set new value based on the target type
def get_current_value(target): def get_current_value(target):
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
return state.get(target.id, 0) return state.get(target.id, 0)
elif isinstance(target, ast.Subscript): elif isinstance(target, ast.Subscript):
obj = evaluate_ast(target.value, state, tools) obj = evaluate_ast(target.value, state, static_tools, custom_tools)
key = evaluate_ast(target.slice, state, tools) key = evaluate_ast(target.slice, state, static_tools, custom_tools)
return obj[key] return obj[key]
elif isinstance(target, ast.Attribute): elif isinstance(target, ast.Attribute):
obj = evaluate_ast(target.value, state, tools) obj = evaluate_ast(target.value, state, static_tools, custom_tools)
return getattr(obj, target.attr) return getattr(obj, target.attr)
elif isinstance(target, ast.Tuple): elif isinstance(target, ast.Tuple):
return tuple(get_current_value(elt) for elt in target.elts) return tuple(get_current_value(elt) for elt in target.elts)
...@@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: ...@@ -220,7 +230,7 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
raise InterpreterError("AugAssign not supported for {type(target)} targets.") raise InterpreterError("AugAssign not supported for {type(target)} targets.")
current_value = get_current_value(expression.target) current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(expression.value, state, tools) value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
# Determine the operation and apply it # Determine the operation and apply it
if isinstance(expression.op, ast.Add): if isinstance(expression.op, ast.Add):
...@@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: ...@@ -256,28 +266,28 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.") raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
# Update the state # Update the state
set_value(expression.target, updated_value, state, tools) set_value(expression.target, updated_value, state, static_tools, custom_tools)
return updated_value return updated_value
def evaluate_boolop(node, state, static_tools, custom_tools):
    """
    Evaluate an `and` / `or` expression with Python's short-circuit semantics.

    Operands are evaluated left to right. As in Python, the value returned is the operand that decided the
    outcome (not a coerced `True`/`False`), so idioms such as `x = value or "default"` work as expected.

    Args:
        node (`ast.BoolOp`):
            The boolean operation; `node.op` is `ast.And` or `ast.Or`.
        state (`Dict[str, Any]`):
            Current variables.
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.

    Returns:
        For `or`, the first truthy operand, else the last operand. For `and`, the first falsy operand, else the
        last operand.
    """
    # `or` stops on the first truthy operand, `and` on the first falsy one.
    stop_on_truthy = isinstance(node.op, ast.Or)
    result = None
    for value in node.values:
        result = evaluate_ast(value, state, static_tools, custom_tools)
        if bool(result) == stop_on_truthy:
            return result
    return result
def evaluate_binop(binop, state, tools): def evaluate_binop(binop, state, static_tools, custom_tools):
# Recursively evaluate the left and right operands # Recursively evaluate the left and right operands
left_val = evaluate_ast(binop.left, state, tools) left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
right_val = evaluate_ast(binop.right, state, tools) right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
# Determine the operation based on the type of the operator in the BinOp # Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add): if isinstance(binop.op, ast.Add):
...@@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools): ...@@ -308,66 +318,92 @@ def evaluate_binop(binop, state, tools):
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.") raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
def evaluate_assign(assign, state, static_tools, custom_tools):
    """
    Evaluate an assignment statement and store its value.

    An `ast.Assign` node has several entries in `targets` only for chained assignments such as `a = b = value`;
    each of those targets receives the whole value. Tuple unpacking (`a, b = value`) is a single `ast.Tuple`
    target and is handled by `set_value`.

    Args:
        assign (`ast.Assign`):
            The assignment node.
        state (`Dict[str, Any]`):
            Current variables, updated through `set_value`.
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast` and `set_value`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast` and `set_value`.

    Returns:
        The value of the right-hand side.
    """
    # The right-hand side is evaluated exactly once, then bound to every target from left to right.
    result = evaluate_ast(assign.value, state, static_tools, custom_tools)
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools)
    return result
def set_value(target, value, state, static_tools, custom_tools):
    """
    Bind `value` to an assignment target.

    Args:
        target (`ast.AST`):
            The target node: a name, a tuple/list of targets (unpacking), a subscript or an attribute.
        value:
            The value to store.
        state (`Dict[str, Any]`):
            Current variables; plain names are written here.
        static_tools (`Dict[str, Callable]`):
            Names defined here cannot be assigned to.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast` when the target contains sub-expressions.

    Raises:
        InterpreterError: If the name is a static tool, if unpacking fails, or if the target kind is not
            supported.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        # `a, b = ...` and `[a, b] = ...` unpack the same way.
        if not isinstance(value, tuple):
            # Any iterable can be unpacked, but str/bytes are rejected here.
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        if len(target.elts) != len(value):
            raise InterpreterError("Cannot unpack tuple of wrong size")
        for elem, item in zip(target.elts, value):
            set_value(elem, item, state, static_tools, custom_tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools)
        setattr(obj, target.attr, value)
    else:
        # Fail loudly instead of silently discarding the assignment.
        raise InterpreterError(f"Assignment to {type(target).__name__} targets is not supported.")
def evaluate_call(call, state, tools): def evaluate_call(call, state, static_tools, custom_tools):
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpreterError( raise InterpreterError(f"This is not a correct function: {call.func}).")
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
)
if isinstance(call.func, ast.Attribute): if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, tools) obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
func_name = call.func.attr func_name = call.func.attr
if not hasattr(obj, func_name): if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}") raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name) func = getattr(obj, func_name)
elif isinstance(call.func, ast.Name): elif isinstance(call.func, ast.Name):
func_name = call.func.id func_name = call.func.id
if func_name in state: if func_name in state:
func = state[func_name] func = state[func_name]
elif func_name in tools: elif func_name in static_tools:
func = tools[func_name] func = static_tools[func_name]
elif func_name in custom_tools:
func = custom_tools[func_name]
elif func_name in ERRORS: elif func_name in ERRORS:
func = ERRORS[func_name] func = ERRORS[func_name]
else: else:
raise InterpreterError( raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
) )
args = [evaluate_ast(arg, state, tools) for arg in call.args] args = []
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords} for arg in call.args:
if isinstance(arg, ast.Starred):
args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
args.extend(unpacked)
else:
args.append(evaluate_ast(arg, state, static_tools, custom_tools))
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools) for keyword in call.keywords}
if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes if isinstance(func, type) and len(func.__module__.split(".")) > 1: # Check for user-defined classes
# Instantiate the class using its constructor # Instantiate the class using its constructor
...@@ -397,24 +433,31 @@ def evaluate_call(call, state, tools): ...@@ -397,24 +433,31 @@ def evaluate_call(call, state, tools):
output = " ".join(map(str, args)) output = " ".join(map(str, args))
global PRINT_OUTPUTS global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n" PRINT_OUTPUTS += output + "\n"
# cap the number of lines
return output return output
else: # Assume it's a callable object else: # Assume it's a callable object
output = func(*args, **kwargs) output = func(*args, **kwargs)
return output return output
def evaluate_subscript(subscript, state, tools): def evaluate_subscript(subscript, state, static_tools, custom_tools):
index = evaluate_ast(subscript.slice, state, tools) index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
value = evaluate_ast(subscript.value, state, tools) value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
if isinstance(index, slice):
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
return value[index]
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
return value[index]
elif isinstance(index, slice):
return value[index] return value[index]
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}") raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
return value[int(index)] return value[int(index)]
elif isinstance(value, str): elif isinstance(value, str):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)): if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}") raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
return value[index] return value[index]
...@@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools): ...@@ -427,11 +470,11 @@ def evaluate_subscript(subscript, state, tools):
raise InterpreterError(f"Could not index {value} with '{index}'.") raise InterpreterError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools): def evaluate_name(name, state, static_tools, custom_tools):
if name.id in state: if name.id in state:
return state[name.id] return state[name.id]
elif name.id in tools: elif name.id in static_tools:
return tools[name.id] return static_tools[name.id]
elif name.id in ERRORS: elif name.id in ERRORS:
return ERRORS[name.id] return ERRORS[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys())) close_matches = difflib.get_close_matches(name.id, list(state.keys()))
...@@ -440,9 +483,9 @@ def evaluate_name(name, state, tools): ...@@ -440,9 +483,9 @@ def evaluate_name(name, state, tools):
raise InterpreterError(f"The variable `{name.id}` is not defined.") raise InterpreterError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools): def evaluate_condition(condition, state, static_tools, custom_tools):
left = evaluate_ast(condition.left, state, tools) left = evaluate_ast(condition.left, state, static_tools, custom_tools)
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators] comparators = [evaluate_ast(c, state, static_tools, custom_tools) for c in condition.comparators]
ops = [type(op) for op in condition.ops] ops = [type(op) for op in condition.ops]
result = True result = True
...@@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools): ...@@ -450,63 +493,61 @@ def evaluate_condition(condition, state, tools):
for op, comparator in zip(ops, comparators): for op, comparator in zip(ops, comparators):
if op == ast.Eq: if op == ast.Eq:
result = result and (current_left == comparator) current_result = current_left == comparator
elif op == ast.NotEq: elif op == ast.NotEq:
result = result and (current_left != comparator) current_result = current_left != comparator
elif op == ast.Lt: elif op == ast.Lt:
result = result and (current_left < comparator) current_result = current_left < comparator
elif op == ast.LtE: elif op == ast.LtE:
result = result and (current_left <= comparator) current_result = current_left <= comparator
elif op == ast.Gt: elif op == ast.Gt:
result = result and (current_left > comparator) current_result = current_left > comparator
elif op == ast.GtE: elif op == ast.GtE:
result = result and (current_left >= comparator) current_result = current_left >= comparator
elif op == ast.Is: elif op == ast.Is:
result = result and (current_left is comparator) current_result = current_left is comparator
elif op == ast.IsNot: elif op == ast.IsNot:
result = result and (current_left is not comparator) current_result = current_left is not comparator
elif op == ast.In: elif op == ast.In:
result = result and (current_left in comparator) current_result = current_left in comparator
elif op == ast.NotIn: elif op == ast.NotIn:
result = result and (current_left not in comparator) current_result = current_left not in comparator
else: else:
raise InterpreterError(f"Operator not supported: {op}") raise InterpreterError(f"Operator not supported: {op}")
result = result & current_result
current_left = comparator current_left = comparator
if not result:
if isinstance(result, bool) and not result:
break break
return result return result if isinstance(result, (bool, pd.Series)) else result.all()
def evaluate_if(if_statement, state, static_tools, custom_tools):
    """
    Evaluate an `if` statement.

    The test is evaluated once; then either the `body` or the `orelse` statements are run in order.

    Args:
        if_statement (`ast.If`):
            The conditional node.
        state (`Dict[str, Any]`):
            Current variables.
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.

    Returns:
        The last non-`None` result produced by a statement of the executed branch, or `None` if there is none.
    """
    test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
    branch = if_statement.body if test_result else if_statement.orelse

    result = None
    for line in branch:
        line_result = evaluate_ast(line, state, static_tools, custom_tools)
        # Statements that yield nothing must not erase an earlier meaningful result.
        if line_result is not None:
            result = line_result
    return result
def evaluate_for(for_loop, state, tools): def evaluate_for(for_loop, state, static_tools, custom_tools):
result = None result = None
iterator = evaluate_ast(for_loop.iter, state, tools) iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
for counter in iterator: for counter in iterator:
if isinstance(for_loop.target, ast.Tuple): set_value(for_loop.target, counter, state, static_tools, custom_tools)
for i, elem in enumerate(for_loop.target.elts):
state[elem.id] = counter[i]
else:
state[for_loop.target.id] = counter
for node in for_loop.body: for node in for_loop.body:
try: try:
line_result = evaluate_ast(node, state, tools) line_result = evaluate_ast(node, state, static_tools, custom_tools)
if line_result is not None: if line_result is not None:
result = line_result result = line_result
except BreakException: except BreakException:
...@@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools): ...@@ -519,55 +560,60 @@ def evaluate_for(for_loop, state, tools):
return result return result
def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
    """
    Evaluate a list comprehension, including several nested `for` clauses.

    Each `for` clause is evaluated inside the scope of the previous ones, so later clauses and the element
    expression can use every loop variable. Loop variables live in copies of `state` and do not leak out.

    Args:
        listcomp (`ast.ListComp`):
            The comprehension node.
        state (`Dict[str, Any]`):
            Current variables (never modified here).
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.

    Returns:
        The list of evaluated elements.
    """

    def bind(target, value, scope):
        # Supports plain names as well as (possibly nested) tuple/list unpacking targets.
        if isinstance(target, (ast.Tuple, ast.List)):
            for idx, elem in enumerate(target.elts):
                bind(elem, value[idx], scope)
        else:
            scope[target.id] = value

    def inner_evaluate(generators, index, current_state):
        # All `for` clauses are bound: produce one element.
        if index >= len(generators):
            return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)]
        generator = generators[index]
        iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools)
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            bind(generator.target, value, new_state)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in generator.ifs):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
def evaluate_try(try_node, state, static_tools, custom_tools):
    """
    Evaluate a `try` / `except` / `else` / `finally` statement.

    The body is run first. If it raises, the first handler whose type matches (or a bare `except:`) is executed,
    with the exception bound to the handler's name if one is given; if no handler matches, the exception
    propagates. The `else` statements run only when the body raised nothing, and the `finally` statements
    always run.

    Args:
        try_node (`ast.Try`):
            The try statement node.
        state (`Dict[str, Any]`):
            Current variables; the caught exception is stored here under the handler's name.
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast`.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    except Exception as e:
        matched = False
        for handler in try_node.handlers:
            if handler.type is None or isinstance(e, evaluate_ast(handler.type, state, static_tools, custom_tools)):
                matched = True
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools)
                # Only the first matching handler runs.
                break
        if not matched:
            # Re-raise the active exception untouched.
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools)
def evaluate_raise(raise_node, state, tools): def evaluate_raise(raise_node, state, static_tools, custom_tools):
if raise_node.exc is not None: if raise_node.exc is not None:
exc = evaluate_ast(raise_node.exc, state, tools) exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
else: else:
exc = None exc = None
if raise_node.cause is not None: if raise_node.cause is not None:
cause = evaluate_ast(raise_node.cause, state, tools) cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
else: else:
cause = None cause = None
if exc is not None: if exc is not None:
...@@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools): ...@@ -579,11 +625,11 @@ def evaluate_raise(raise_node, state, tools):
raise InterpreterError("Re-raise is not supported without an active exception") raise InterpreterError("Re-raise is not supported without an active exception")
def evaluate_assert(assert_node, state, tools): def evaluate_assert(assert_node, state, static_tools, custom_tools):
test_result = evaluate_ast(assert_node.test, state, tools) test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
if not test_result: if not test_result:
if assert_node.msg: if assert_node.msg:
msg = evaluate_ast(assert_node.msg, state, tools) msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
raise AssertionError(msg) raise AssertionError(msg)
else: else:
# Include the failing condition in the assertion message # Include the failing condition in the assertion message
...@@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools): ...@@ -591,10 +637,10 @@ def evaluate_assert(assert_node, state, tools):
raise AssertionError(f"Assertion failed: {test_code}") raise AssertionError(f"Assertion failed: {test_code}")
def evaluate_with(with_node, state, tools): def evaluate_with(with_node, state, static_tools, custom_tools):
contexts = [] contexts = []
for item in with_node.items: for item in with_node.items:
context_expr = evaluate_ast(item.context_expr, state, tools) context_expr = evaluate_ast(item.context_expr, state, static_tools, custom_tools)
if item.optional_vars: if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__() state[item.optional_vars.id] = context_expr.__enter__()
contexts.append(state[item.optional_vars.id]) contexts.append(state[item.optional_vars.id])
...@@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools): ...@@ -604,7 +650,7 @@ def evaluate_with(with_node, state, tools):
try: try:
for stmt in with_node.body: for stmt in with_node.body:
evaluate_ast(stmt, state, tools) evaluate_ast(stmt, state, static_tools, custom_tools)
except Exception as e: except Exception as e:
for context in reversed(contexts): for context in reversed(contexts):
context.__exit__(type(e), e, e.__traceback__) context.__exit__(type(e), e, e.__traceback__)
...@@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools): ...@@ -614,10 +660,51 @@ def evaluate_with(with_node, state, tools):
context.__exit__(None, None, None) context.__exit__(None, None, None)
def import_modules(expression, state, authorized_imports):
    """
    Execute an `import ...` or `from ... import ...` statement if the module is authorized.

    A dotted module is authorized when it, or any of its parent packages, is listed in `authorized_imports`.

    Args:
        expression (`ast.Import` or `ast.ImportFrom`):
            The import statement node.
        state (`Dict[str, Any]`):
            Current variables; imported objects are stored here under their alias, or under their name when no
            alias is given.
        authorized_imports (`List[str]`):
            Names of the modules that may be imported.

    Returns:
        `None`.

    Raises:
        InterpreterError: If the module is not authorized, or if the import is relative.
    """

    def check_module_authorized(module_name):
        module_path = module_name.split(".")
        # "a.b.c" -> ["a", "a.b", "a.b.c"]: authorizing a package authorizes its submodules.
        module_subpaths = [".".join(module_path[:i]) for i in range(1, len(module_path) + 1)]
        return any(subpath in authorized_imports for subpath in module_subpaths)

    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if check_module_authorized(alias.name):
                module = import_module(alias.name)
                state[alias.asname or alias.name] = module
            else:
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
        return None
    elif isinstance(expression, ast.ImportFrom):
        # `from . import x` has no module name, and interpreted snippets have no enclosing package anyway.
        if expression.module is None or expression.level:
            raise InterpreterError("Relative imports are not supported.")
        if check_module_authorized(expression.module):
            # A non-empty `fromlist` makes `__import__` return the submodule itself and import requested names.
            module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
            for alias in expression.names:
                state[alias.asname or alias.name] = getattr(module, alias.name)
        else:
            raise InterpreterError(f"Import from {expression.module} is not allowed.")
        return None
def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
    """
    Evaluate a dict comprehension, including several nested `for` clauses.

    Each `for` clause is evaluated inside the scope of the previous ones, so `{k: v for k in a for v in b}`
    sees both `k` and `v` when the key and value are computed. Loop variables live in copies of `state`.

    Args:
        dictcomp (`ast.DictComp`):
            The comprehension node.
        state (`Dict[str, Any]`):
            Current variables (never modified here).
        static_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast` and `set_value`.
        custom_tools (`Dict[str, Callable]`):
            Forwarded to `evaluate_ast` and `set_value`.

    Returns:
        The resulting dictionary; later entries overwrite earlier ones with the same key.
    """
    result = {}
    generators = dictcomp.generators

    def fill(index, current_state):
        # All `for` clauses are bound: compute one entry (key first, then value).
        if index >= len(generators):
            key = evaluate_ast(dictcomp.key, current_state, static_tools, custom_tools)
            val = evaluate_ast(dictcomp.value, current_state, static_tools, custom_tools)
            result[key] = val
            return
        gen = generators[index]
        iter_value = evaluate_ast(gen.iter, current_state, static_tools, custom_tools)
        for value in iter_value:
            new_state = current_state.copy()
            set_value(gen.target, value, new_state, static_tools, custom_tools)
            if all(evaluate_ast(if_clause, new_state, static_tools, custom_tools) for if_clause in gen.ifs):
                fill(index + 1, new_state)

    fill(0, state)
    return result
def evaluate_ast( def evaluate_ast(
expression: ast.AST, expression: ast.AST,
state: Dict[str, Any], state: Dict[str, Any],
tools: Dict[str, Callable], static_tools: Dict[str, Callable],
custom_tools: Dict[str, Callable],
authorized_imports: List[str] = LIST_SAFE_MODULES, authorized_imports: List[str] = LIST_SAFE_MODULES,
): ):
""" """
...@@ -632,146 +719,128 @@ def evaluate_ast( ...@@ -632,146 +719,128 @@ def evaluate_ast(
state (`Dict[str, Any]`): state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
encounters assignements. encounters assignements.
tools (`Dict[str, Callable]`): static_tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
`InterpreterError`. custom_tools (`Dict[str, Callable]`):
Functions that may be called during the evaluation. These custom_tools can be overwritten.
authorized_imports (`List[str]`): authorized_imports (`List[str]`):
The list of modules that can be imported by the code. By default, only a few safe modules are allowed. The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
Add more at your own risk! Add more at your own risk!
""" """
global OPERATIONS_COUNT
if OPERATIONS_COUNT >= MAX_OPERATIONS:
raise InterpreterError(
f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
)
OPERATIONS_COUNT += 1
if isinstance(expression, ast.Assign): if isinstance(expression, ast.Assign):
# Assignement -> we evaluate the assignment which should update the state # Assignement -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result. # We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, tools) return evaluate_assign(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.AugAssign): elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, tools) return evaluate_augassign(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Call): elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call # Function call -> we return the value of the function call
return evaluate_call(expression, state, tools) return evaluate_call(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Constant): elif isinstance(expression, ast.Constant):
# Constant -> just return the value # Constant -> just return the value
return expression.value return expression.value
elif isinstance(expression, ast.Tuple): elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts) return tuple(evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts)
elif isinstance(expression, ast.ListComp): elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
return evaluate_listcomp(expression, state, tools) return evaluate_listcomp(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.UnaryOp): elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, tools) return evaluate_unaryop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Starred):
return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.BoolOp): elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation # Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, tools) return evaluate_boolop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Break): elif isinstance(expression, ast.Break):
raise BreakException() raise BreakException()
elif isinstance(expression, ast.Continue): elif isinstance(expression, ast.Continue):
raise ContinueException() raise ContinueException()
elif isinstance(expression, ast.BinOp): elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation # Binary operation -> execute operation
return evaluate_binop(expression, state, tools) return evaluate_binop(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Compare): elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison # Comparison -> evaluate the comparison
return evaluate_condition(expression, state, tools) return evaluate_condition(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Lambda): elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, tools) return evaluate_lambda(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.FunctionDef): elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, tools) return evaluate_function_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Dict): elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values # Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, tools) for k in expression.keys] keys = [evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys]
values = [evaluate_ast(v, state, tools) for v in expression.values] values = [evaluate_ast(v, state, static_tools, custom_tools) for v in expression.values]
return dict(zip(keys, values)) return dict(zip(keys, values))
elif isinstance(expression, ast.Expr): elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content # Expression -> evaluate the content
return evaluate_ast(expression.value, state, tools) return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.For): elif isinstance(expression, ast.For):
# For loop -> execute the loop # For loop -> execute the loop
return evaluate_for(expression, state, tools) return evaluate_for(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.FormattedValue): elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return # Formatted value (part of f-string) -> evaluate the content and return
return evaluate_ast(expression.value, state, tools) return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.If): elif isinstance(expression, ast.If):
# If -> execute the right branch # If -> execute the right branch
return evaluate_if(expression, state, tools) return evaluate_if(expression, state, static_tools, custom_tools)
elif hasattr(ast, "Index") and isinstance(expression, ast.Index): elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
return evaluate_ast(expression.value, state, tools) return evaluate_ast(expression.value, state, static_tools, custom_tools)
elif isinstance(expression, ast.JoinedStr): elif isinstance(expression, ast.JoinedStr):
return "".join([str(evaluate_ast(v, state, tools)) for v in expression.values]) return "".join([str(evaluate_ast(v, state, static_tools, custom_tools)) for v in expression.values])
elif isinstance(expression, ast.List): elif isinstance(expression, ast.List):
# List -> evaluate all elements # List -> evaluate all elements
return [evaluate_ast(elt, state, tools) for elt in expression.elts] return [evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts]
elif isinstance(expression, ast.Name): elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state # Name -> pick up the value in the state
return evaluate_name(expression, state, tools) return evaluate_name(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Subscript): elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing # Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, tools) return evaluate_subscript(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.IfExp): elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, tools) test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
if test_val: if test_val:
return evaluate_ast(expression.body, state, tools) return evaluate_ast(expression.body, state, static_tools, custom_tools)
else: else:
return evaluate_ast(expression.orelse, state, tools) return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
elif isinstance(expression, ast.Attribute): elif isinstance(expression, ast.Attribute):
obj = evaluate_ast(expression.value, state, tools) value = evaluate_ast(expression.value, state, static_tools, custom_tools)
return getattr(obj, expression.attr) return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice): elif isinstance(expression, ast.Slice):
return slice( return slice(
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None, evaluate_ast(expression.lower, state, static_tools, custom_tools)
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None, if expression.lower is not None
evaluate_ast(expression.step, state, tools) if expression.step is not None else None, else None,
evaluate_ast(expression.upper, state, static_tools, custom_tools)
if expression.upper is not None
else None,
evaluate_ast(expression.step, state, static_tools, custom_tools) if expression.step is not None else None,
) )
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
result = []
vars = {}
for generator in expression.generators:
var_name = generator.target.id
iter_value = evaluate_ast(generator.iter, state, tools)
for value in iter_value:
vars[var_name] = value
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
result.append(elem)
return result
elif isinstance(expression, ast.DictComp): elif isinstance(expression, ast.DictComp):
result = {} return evaluate_dictcomp(expression, state, static_tools, custom_tools)
for gen in expression.generators:
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
state[gen.target.id] = container
key = evaluate_ast(expression.key, state, tools)
value = evaluate_ast(expression.value, state, tools)
result[key] = value
return result
elif isinstance(expression, ast.Import):
for alias in expression.names:
if alias.name in authorized_imports:
module = __import__(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpreterError(f"Import of {alias.name} is not allowed.")
return None
elif isinstance(expression, ast.While): elif isinstance(expression, ast.While):
return evaluate_while(expression, state, tools) return evaluate_while(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.ImportFrom): elif isinstance(expression, (ast.Import, ast.ImportFrom)):
if expression.module in authorized_imports: return import_modules(expression, state, authorized_imports)
module = __import__(expression.module)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
else:
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None
elif isinstance(expression, ast.ClassDef): elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, tools) return evaluate_class_def(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Try): elif isinstance(expression, ast.Try):
return evaluate_try(expression, state, tools) return evaluate_try(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Raise): elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, tools) return evaluate_raise(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Assert): elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, tools) return evaluate_assert(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.With): elif isinstance(expression, ast.With):
return evaluate_with(expression, state, tools) return evaluate_with(expression, state, static_tools, custom_tools)
elif isinstance(expression, ast.Set): elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, tools) for elt in expression.elts} return {evaluate_ast(elt, state, static_tools, custom_tools) for elt in expression.elts}
elif isinstance(expression, ast.Return): elif isinstance(expression, ast.Return):
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None) raise ReturnException(
evaluate_ast(expression.value, state, static_tools, custom_tools) if expression.value else None
)
else: else:
# For now we refuse anything else. Let's add things as we need them. # For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.") raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
...@@ -779,7 +848,8 @@ def evaluate_ast( ...@@ -779,7 +848,8 @@ def evaluate_ast(
def evaluate_python_code( def evaluate_python_code(
code: str, code: str,
tools: Optional[Dict[str, Callable]] = None, static_tools: Optional[Dict[str, Callable]] = None,
custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None, state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = LIST_SAFE_MODULES, authorized_imports: List[str] = LIST_SAFE_MODULES,
): ):
...@@ -792,9 +862,12 @@ def evaluate_python_code( ...@@ -792,9 +862,12 @@ def evaluate_python_code(
Args: Args:
code (`str`): code (`str`):
The code to evaluate. The code to evaluate.
tools (`Dict[str, Callable]`): static_tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an The functions that may be called during the evaluation.
`InterpreterError`. These tools cannot be overwritten in the code: any assignment to their name will raise an error.
custom_tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation.
These tools can be overwritten in the code: any assignment to their name will overwrite them.
state (`Dict[str, Any]`): state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated. updated by this function to contain all variables as they are evaluated.
...@@ -806,20 +879,34 @@ def evaluate_python_code( ...@@ -806,20 +879,34 @@ def evaluate_python_code(
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
if state is None: if state is None:
state = {} state = {}
if tools is None: if static_tools is None:
tools = {} static_tools = {}
if custom_tools is None:
custom_tools = {}
result = None result = None
global PRINT_OUTPUTS global PRINT_OUTPUTS
PRINT_OUTPUTS = "" PRINT_OUTPUTS = ""
global OPERATIONS_COUNT
OPERATIONS_COUNT = 0
for node in expression.body: for node in expression.body:
try: try:
result = evaluate_ast(node, state, tools, authorized_imports) result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
except InterpreterError as e: except InterpreterError as e:
msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}" msg = ""
if len(PRINT_OUTPUTS) > 0: if len(PRINT_OUTPUTS) > 0:
msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n" if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
msg += f"Print outputs:\n{PRINT_OUTPUTS}\n====\n"
else:
msg += f"Print outputs:\n{PRINT_OUTPUTS[:MAX_LEN_OUTPUT]}\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._\n====\n"
msg += f"EXECUTION FAILED:\nEvaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg) raise InterpreterError(msg)
finally: finally:
if len(PRINT_OUTPUTS) < MAX_LEN_OUTPUT:
state["print_outputs"] = PRINT_OUTPUTS state["print_outputs"] = PRINT_OUTPUTS
else:
state["print_outputs"] = (
PRINT_OUTPUTS[:MAX_LEN_OUTPUT]
+ f"\n_Print outputs were over {MAX_LEN_OUTPUT} characters, so they have been truncated._"
)
return result return result
...@@ -223,7 +223,7 @@ Action: ...@@ -223,7 +223,7 @@ Action:
# check that add_base_tools will not interfere with existing tools # check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e: with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True) agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
assert "python_interpreter already exists in the toolbox" in str(e) assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents # check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np
import pytest import pytest
from transformers import load_tool from transformers import load_tool
...@@ -241,8 +242,41 @@ for block in text_block: ...@@ -241,8 +242,41 @@ for block in text_block:
code = """ code = """
digits, i = [1, 2, 3], 1 digits, i = [1, 2, 3], 1
digits[i], digits[i + 1] = digits[i + 1], digits[i]""" digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
code = """
def calculate_isbn_10_check_digit(number):
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
remainder = total % 11
check_digit = 11 - remainder
if check_digit == 10:
return 'X'
elif check_digit == 11:
return '0'
else:
return str(check_digit)
# Given 9-digit numbers
numbers = [
"478225952",
"643485613",
"739394228",
"291726859",
"875262394",
"542617795",
"031810713",
"957007669",
"871467426"
]
# Calculate check digits for each number
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
print(check_digits)
"""
state = {} state = {}
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state) evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
)
def test_listcomp(self): def test_listcomp(self):
code = "x = [i for i in range(3)]" code = "x = [i for i in range(3)]"
...@@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]""" ...@@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
result = evaluate_python_code(code, {"range": range}, state={}) result = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4} assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: "b"}
code = """
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
result = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self): def test_tuple_assignment(self):
code = "a, b = 0, 1\nb" code = "a, b = 0, 1\nb"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
...@@ -341,7 +386,7 @@ if char.isalpha(): ...@@ -341,7 +386,7 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose" assert result == "lose"
code = "import time\ntime.sleep(0.1)" code = "import time, re\ntime.sleep(0.1)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None assert result is None
...@@ -369,6 +414,23 @@ if char.isalpha(): ...@@ -369,6 +414,23 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A" assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
def test_additional_imports(self):
    """Check which `authorized_imports` entries allow importing a module and its submodule."""
    plain_import = "import numpy as np"
    evaluate_python_code(plain_import, authorized_imports=["numpy"], state={})

    submodule_import = "import numpy.random as rd"
    # Listing the submodule itself or its parent package must both be accepted.
    for allowed in (["numpy.random"], ["numpy"]):
        evaluate_python_code(submodule_import, authorized_imports=allowed, state={})

    # Listing only the last path component ("random") must not authorize "numpy.random".
    with pytest.raises(InterpreterError):
        evaluate_python_code(submodule_import, authorized_imports=["random"], state={})
def test_multiple_comparators(self): def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4" code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
...@@ -400,7 +462,7 @@ def function(): ...@@ -400,7 +462,7 @@ def function():
print("2") print("2")
function()""" function()"""
state = {} state = {}
evaluate_python_code(code, {"print": print}, state) evaluate_python_code(code, {"print": print}, state=state)
assert state["print_outputs"] == "1\n2\n" assert state["print_outputs"] == "1\n2\n"
def test_tuple_target_in_iterator(self): def test_tuple_target_in_iterator(self):
...@@ -612,7 +674,7 @@ assert lock.locked == False ...@@ -612,7 +674,7 @@ assert lock.locked == False
""" """
state = {} state = {}
tools = {} tools = {}
evaluate_python_code(code, tools, state) evaluate_python_code(code, tools, state=state)
def test_default_arg_in_function(self): def test_default_arg_in_function(self):
code = """ code = """
...@@ -672,3 +734,94 @@ returns_none(1) ...@@ -672,3 +734,94 @@ returns_none(1)
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
assert result is None assert result is None
def test_nested_for_loop(self):
    """Nested `for` statements followed by a two-clause list comprehension give the flattened prefix."""
    # The loop bodies in the snippet must be indented, otherwise it is not valid Python.
    code = """
all_res = []
for i in range(10):
    subres = []
    for j in range(i):
        subres.append(j)
    all_res.append(subres)
out = [i for sublist in all_res for i in sublist]
out[:10]
"""
    state = {}
    result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
    assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self):
    """Run three small pandas snippets through the interpreter and check the value each one returns."""
    # Boolean-mask filtering followed by column selection.
    filter_snippet = """
import pandas as pd
df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
    filtered_row = evaluate_python_code(filter_snippet, {}, state={}, authorized_imports=["pandas"])
    assert np.array_equal(filtered_row, [-1, 5])

    # `.loc` with an `isin` mask; the returned value is the assigned DataFrame.
    isin_snippet = """
import pandas as pd
df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
    filtered_frame = evaluate_python_code(
        isin_snippet, {"print": print}, state={}, authorized_imports=["pandas"]
    )
    assert np.array_equal(filtered_frame.values[0], [104, 1])

    # groupby + mean on a column.
    groupby_snippet = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
{"Pclass": 2, "Survived": 0},
{"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
    survival_rates = evaluate_python_code(groupby_snippet, {}, state={}, authorized_imports=["pandas"])
    assert survival_rates.values[1] == 0.5
def test_starred(self):
    """Starred tuples in a call (`f(*a, *b)`) are unpacked into positional arguments."""
    # The body of `haversine` must be indented, otherwise the snippet is not valid Python.
    code = """
from math import radians, sin, cos, sqrt, atan2
def haversine(lat1, lon1, lat2, lon2):
    R = 6371000  # Radius of the Earth in meters
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    c = 2 * atan2(sqrt(a), sqrt(1 - a))
    distance = R * c
    return distance
coords_geneva = (46.1978, 6.1342)
coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
    assert round(result, 1) == 622395.4
def test_for(self):
    """A `for` loop whose target contains a nested tuple (`worker, (start, end)`) unpacks each item."""
    # The loop body must be indented, otherwise the snippet is not valid Python.
    code = """
shifts = {
"Worker A": ("6:45 pm", "8:00 pm"),
"Worker B": ("10:00 am", "11:45 am")
}
shift_intervals = {}
for worker, (start, end) in shifts.items():
    shift_intervals[worker] = end
shift_intervals
"""
    result = evaluate_python_code(code, {"print": print, "map": map}, state={})
    assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment