Unverified Commit b3818805 authored by Aymeric Roucher's avatar Aymeric Roucher Committed by GitHub
Browse files

Agents planning (#31702)

* Allow planning for agents
parent 0fdea860
......@@ -25,7 +25,19 @@ from ..utils.import_utils import is_pygments_available
from .agent_types import AgentAudio, AgentImage, AgentText
from .default_tools import BASE_PYTHON_TOOLS, FinalAnswerTool, setup_default_tools
from .llm_engine import HfEngine, MessageRole
from .prompts import DEFAULT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_CODE_SYSTEM_PROMPT, DEFAULT_REACT_JSON_SYSTEM_PROMPT
from .prompts import (
DEFAULT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_CODE_SYSTEM_PROMPT,
DEFAULT_REACT_JSON_SYSTEM_PROMPT,
PLAN_UPDATE_FINAL_PLAN_REDACTION,
SYSTEM_PROMPT_FACTS,
SYSTEM_PROMPT_FACTS_UPDATE,
SYSTEM_PROMPT_PLAN,
SYSTEM_PROMPT_PLAN_UPDATE,
USER_PROMPT_FACTS_UPDATE,
USER_PROMPT_PLAN,
USER_PROMPT_PLAN_UPDATE,
)
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import (
DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
......@@ -99,12 +111,19 @@ def parse_json_blob(json_blob: str) -> Dict[str, str]:
def parse_code_blob(code_blob: str) -> str:
try:
pattern = r"```(?:py|python)?\n(.*?)```"
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code_blob, re.DOTALL)
return match.group(1).strip()
except Exception as e:
raise ValueError(
f"The code blob you used is invalid: due to the following error: {e}. This means that the regex pattern {pattern} was not respected. Make sure to correct its formatting. Code blob was: {code_blob}"
f"""
The code blob you used is invalid: due to the following error: {e}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```<end_action>"""
)
......@@ -113,6 +132,8 @@ def parse_json_tool_call(json_blob: str) -> Tuple[str, Dict[str, str]]:
tool_call = parse_json_blob(json_blob)
if "action" in tool_call and "action_input" in tool_call:
return tool_call["action"], tool_call["action_input"]
elif "action" in tool_call:
return tool_call["action"], None
else:
raise ValueError(
f"Missing keys: {[key for key in ['action', 'action_input'] if key not in tool_call]} in blob {tool_call}"
......@@ -208,7 +229,7 @@ class Toolbox:
The tool to add to the toolbox.
"""
if tool.name in self._tools:
raise KeyError(f"Error: tool {tool.name} already exists in the toolbox.")
raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
self._tools[tool.name] = tool
def remove_tool(self, tool_name: str):
......@@ -359,12 +380,8 @@ class Agent:
"""Get the toolbox currently available to the agent"""
return self._toolbox
def initialize_for_run(self, task: str, **kwargs):
def initialize_for_run(self):
self.token_count = 0
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.system_prompt = format_prompt_with_tools(
self._toolbox,
self.system_prompt_template,
......@@ -380,7 +397,7 @@ class Agent:
self.logger.debug("System prompt is as follows:")
self.logger.debug(self.system_prompt)
def write_inner_memory_from_logs(self) -> List[Dict[str, str]]:
def write_inner_memory_from_logs(self, summary_mode: Optional[bool] = False) -> List[Dict[str, str]]:
"""
Reads past llm_outputs, actions, and observations or errors from the logs into a series of messages
that can be used as input to the LLM.
......@@ -390,43 +407,51 @@ class Agent:
"role": MessageRole.USER,
"content": "Task: " + self.logs[0]["task"],
}
if summary_mode:
memory = [task_message]
else:
memory = [prompt_message, task_message]
for i, step_log in enumerate(self.logs[1:]):
if "llm_output" in step_log:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"] + "\n"}
if "llm_output" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": step_log["llm_output"].strip()}
memory.append(thought_message)
if "facts" in step_log:
thought_message = {
"role": MessageRole.ASSISTANT,
"content": "[FACTS LIST]:\n" + step_log["facts"].strip(),
}
memory.append(thought_message)
if "plan" in step_log and not summary_mode:
thought_message = {"role": MessageRole.ASSISTANT, "content": "[PLAN]:\n" + step_log["plan"].strip()}
memory.append(thought_message)
if "tool_call" in step_log and summary_mode:
tool_call_message = {
"role": MessageRole.ASSISTANT,
"content": f"[STEP {i} TOOL CALL]: " + str(step_log["tool_call"]).strip(),
}
memory.append(tool_call_message)
if "task" in step_log:
tool_call_message = {
"role": MessageRole.USER,
"content": "New task:\n" + step_log["task"],
}
memory.append(tool_call_message)
if "error" in step_log or "observation" in step_log:
if "error" in step_log:
message_content = (
"Error: "
f"[OUTPUT OF STEP {i}] Error: "
+ str(step_log["error"])
+ "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
)
elif "observation" in step_log:
message_content = f"Observation: {step_log['observation']}"
message_content = f"[OUTPUT OF STEP {i}] Observation:\n{step_log['observation']}"
tool_response_message = {"role": MessageRole.TOOL_RESPONSE, "content": message_content}
memory.append(tool_response_message)
if len(memory) % 3 == 0:
reminder_content = (
"Reminder: you are working towards solving the following task: " + self.logs[0]["task"]
)
reminder_content += "\nHere is a summary of your past tool calls and their results:"
for j in range(i + 1):
reminder_content += "\nStep " + str(j + 1)
if "tool_call" in self.logs[j]:
reminder_content += "\nTool call:" + str(self.logs[j]["tool_call"])
if self.memory_verbose:
if "observation" in self.logs[j]:
reminder_content += "\nObservation:" + str(self.logs[j]["observation"])
if "error" in self.logs[j]:
reminder_content += "\nError:" + str(self.logs[j]["error"])
memory.append(
{
"role": MessageRole.USER,
"content": reminder_content,
}
)
return memory
def get_succinct_logs(self):
......@@ -459,7 +484,7 @@ class Agent:
This method replaces arguments with the actual values from the state if they refer to state variables.
Args:
tool_name (`str`): Name of the Tool to execute (shoulde be one from self.toolbox).
tool_name (`str`): Name of the Tool to execute (should be one from self.toolbox).
arguments (Dict[str, str]): Arguments passed to the Tool.
"""
if tool_name not in self.toolbox.tools:
......@@ -559,7 +584,11 @@ class CodeAgent(Agent):
agent.run("What is the result of 2 power 3.7384?")
```
"""
self.initialize_for_run(task, **kwargs)
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
self.initialize_for_run()
# Run LLM
prompt_message = {"role": MessageRole.SYSTEM, "content": self.system_prompt}
......@@ -598,7 +627,8 @@ class CodeAgent(Agent):
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
output = self.python_evaluator(
code_action,
available_tools,
static_tools=available_tools,
custom_tools={},
state=self.state,
authorized_imports=self.authorized_imports,
)
......@@ -623,6 +653,7 @@ class ReactAgent(Agent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
......@@ -632,6 +663,7 @@ class ReactAgent(Agent):
tool_description_template=tool_description_template,
**kwargs,
)
self.planning_interval = planning_interval
def provide_final_answer(self, task) -> str:
"""
......@@ -655,11 +687,13 @@ class ReactAgent(Agent):
except Exception as e:
return f"Error in generating final llm output: {e}."
def run(self, task: str, stream: bool = False, **kwargs):
def run(self, task: str, stream: bool = False, reset: bool = True, **kwargs):
"""
Runs the agent for the given task.
Args:
task (`str`): The task to perform
Example:
```py
from transformers.agents import ReactCodeAgent
......@@ -667,14 +701,23 @@ class ReactAgent(Agent):
agent.run("What is the result of 2 power 3.7384?")
```
"""
self.task = task
if len(kwargs) > 0:
self.task += f"\nYou have been provided with these initial arguments: {str(kwargs)}."
self.state = kwargs.copy()
if reset:
self.initialize_for_run()
else:
self.logs.append({"task": task})
if stream:
return self.stream_run(task, **kwargs)
return self.stream_run(task)
else:
return self.direct_run(task, **kwargs)
def stream_run(self, task: str, **kwargs):
self.initialize_for_run(task, **kwargs)
return self.direct_run(task)
def stream_run(self, task: str):
"""
Runs the agent in streaming mode, yielding steps as they are executed: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
......@@ -700,13 +743,16 @@ class ReactAgent(Agent):
yield final_answer
def direct_run(self, task: str, **kwargs):
self.initialize_for_run(task, **kwargs)
def direct_run(self, task: str):
"""
Runs the agent in direct mode, returning outputs only at the end: should be launched only in the `run` method.
"""
final_answer = None
iteration = 0
while final_answer is None and iteration < self.max_iterations:
try:
if self.planning_interval is not None and iteration % self.planning_interval == 0:
self.planning_step(task, is_first_step=(iteration == 0), iteration=iteration)
step_logs = self.step()
if "final_answer" in step_logs:
final_answer = step_logs["final_answer"]
......@@ -726,6 +772,96 @@ class ReactAgent(Agent):
return final_answer
def planning_step(self, task, is_first_step: bool = False, iteration: int = None):
"""
Used periodically by the agent to plan the next steps to reach the objective.
Args:
task (`str`): The task to perform
is_first_step (`bool`): If this step is not the first one, the plan should be an update over a previous plan.
iteration (`int`): The number of the current step, used as an indication for the LLM.
"""
if is_first_step:
message_prompt_facts = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_FACTS}
message_prompt_task = {
"role": MessageRole.USER,
"content": f"""Here is the task:
```
{task}
```
Now begin!""",
}
answer_facts = self.llm_engine([message_prompt_facts, message_prompt_task])
message_system_prompt_plan = {"role": MessageRole.SYSTEM, "content": SYSTEM_PROMPT_PLAN}
message_user_prompt_plan = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
answer_facts=answer_facts,
),
}
answer_plan = self.llm_engine(
[message_system_prompt_plan, message_user_prompt_plan], stop_sequences=["<end_plan>"]
)
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
{answer_plan}
```"""
final_facts_redaction = f"""Here are the facts that I know so far:
```
{answer_facts}
```""".strip()
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Initial plan: =====")
self.logger.debug(final_plan_redaction)
else: # update plan
agent_memory = self.write_inner_memory_from_logs(
summary_mode=False
) # This will not log the plan but will log facts
# Redact updated facts
facts_update_system_prompt = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_FACTS_UPDATE,
}
facts_update_message = {
"role": MessageRole.USER,
"content": USER_PROMPT_FACTS_UPDATE,
}
facts_update = self.llm_engine([facts_update_system_prompt] + agent_memory + [facts_update_message])
# Redact updated plan
plan_update_message = {
"role": MessageRole.SYSTEM,
"content": SYSTEM_PROMPT_PLAN_UPDATE.format(task=task),
}
plan_update_message_user = {
"role": MessageRole.USER,
"content": USER_PROMPT_PLAN_UPDATE.format(
task=task,
tool_descriptions=self._toolbox.show_tool_descriptions(self.tool_description_template),
facts_update=facts_update,
remaining_steps=(self.max_iterations - iteration),
),
}
plan_update = self.llm_engine(
[plan_update_message] + agent_memory + [plan_update_message_user], stop_sequences=["<end_plan>"]
)
# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(task=task, plan_update=plan_update)
final_facts_redaction = f"""Here is the updated list of the facts that I know:
```
{facts_update}
```"""
self.logs.append({"plan": final_plan_redaction, "facts": final_facts_redaction})
self.logger.debug("===== Updated plan: =====")
self.logger.debug(final_plan_redaction)
class ReactJsonAgent(ReactAgent):
"""
......@@ -740,6 +876,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine: Callable = HfEngine(),
system_prompt: str = DEFAULT_REACT_JSON_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
......@@ -747,6 +884,7 @@ class ReactJsonAgent(ReactAgent):
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs,
)
......@@ -792,11 +930,16 @@ class ReactJsonAgent(ReactAgent):
self.logger.warning(f"Calling tool: '{tool_name}' with arguments: {arguments}")
if tool_name == "final_answer":
if isinstance(arguments, dict):
if "answer" in arguments:
answer = arguments["answer"]
if (
isinstance(answer, str) and answer in self.state.keys()
): # if the answer is a state variable, return the value
answer = self.state[answer]
else:
answer = arguments
else:
answer = arguments
if answer in self.state: # if the answer is a state variable, return the value
answer = self.state[answer]
current_step_logs["final_answer"] = answer
return current_step_logs
else:
......@@ -835,6 +978,7 @@ class ReactCodeAgent(ReactAgent):
system_prompt: str = DEFAULT_REACT_CODE_SYSTEM_PROMPT,
tool_description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE,
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
**kwargs,
):
super().__init__(
......@@ -842,6 +986,7 @@ class ReactCodeAgent(ReactAgent):
llm_engine=llm_engine,
system_prompt=system_prompt,
tool_description_template=tool_description_template,
planning_interval=planning_interval,
**kwargs,
)
......@@ -856,10 +1001,7 @@ class ReactCodeAgent(ReactAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.available_tools = {
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
} # This list can be augmented by the code agent creating some new functions
self.custom_tools = {}
def step(self):
"""
......@@ -911,7 +1053,11 @@ class ReactCodeAgent(ReactAgent):
try:
result = self.python_evaluator(
code_action,
tools=self.available_tools,
static_tools={
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
},
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
)
......@@ -920,7 +1066,7 @@ class ReactCodeAgent(ReactAgent):
self.logger.log(32, information)
current_step_logs["observation"] = information
except Exception as e:
error_msg = f"Failed while trying to execute the code below:\n{CustomFormatter.reset + code_action + CustomFormatter.reset}\nThis failed due to the following error:\n{str(e)}"
error_msg = f"Code execution failed due to the following error:\n{str(e)}"
if "'dict' object has no attribute 'read'" in str(e):
error_msg += "\nYou get this error because you passed a dict as input for one of the arguments instead of a string."
raise AgentExecutionError(error_msg)
......
......@@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
def forward(self, code):
output = str(
evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
)
return output
......
......@@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
"""
SYSTEM_PROMPT_FACTS = """Below I will present you a task.
You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
---
### 1. Facts given in the task
List here the specific facts given in the task that could help you (there might be nothing here).
### 2. Facts to look up
List here any facts that we may need to look up.
Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
### 3. Facts to derive
List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
### 1. Facts given in the task
### 2. Facts to look up
### 3. Facts to derive
Do not add anything else."""
SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
USER_PROMPT_PLAN = """
Here is your task:
Task:
```
{task}
```
Your plan can leverage any of these tools:
{tool_descriptions}
List of facts that you know:
```
{answer_facts}
```
Now begin! Write your plan below."""
SYSTEM_PROMPT_FACTS_UPDATE = """
You are a world expert at gathering known and unknown facts based on a conversation.
Below you will find a task, and ahistory of attempts made to solve the task. You will have to produce a list of these:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Find the task and history below."""
USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
But since in your previous steps you may have learned useful new facts or invalidated some false ones.
Please update your list of facts based on the previous history, and provide these headings:
### 1. Facts given in the task
### 2. Facts that we have learned
### 3. Facts still to look up
### 4. Facts still to derive
Now write your new list of facts below."""
SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
You have been given a task:
```
{task}
```
Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
If the previous tries so far have met some success, you can make an updated plan based on these actions.
If you are stalled, you can make a completely new plan starting from scratch.
"""
USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
```
{task}
```
You have access to these tools:
{tool_descriptions}
Here is the up to date list of facts that you know:
```
{facts_update}
```
Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
This plan should involve individual tasks based on the avilable tools, that if executed correctly will yield the correct answer.
Beware that you have {remaining_steps} steps remaining.
Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
Now write your new plan below."""
PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
```
{task}
```
Here is my new/updated plan of action to solve the task:
```
{plan_update}
```"""
......@@ -223,7 +223,7 @@ Action:
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
assert "python_interpreter already exists in the toolbox" in str(e)
assert "already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
import pytest
from transformers import load_tool
......@@ -241,8 +242,41 @@ for block in text_block:
code = """
digits, i = [1, 2, 3], 1
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
code = """
def calculate_isbn_10_check_digit(number):
total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
remainder = total % 11
check_digit = 11 - remainder
if check_digit == 10:
return 'X'
elif check_digit == 11:
return '0'
else:
return str(check_digit)
# Given 9-digit numbers
numbers = [
"478225952",
"643485613",
"739394228",
"291726859",
"875262394",
"542617795",
"031810713",
"957007669",
"871467426"
]
# Calculate check digits for each number
check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
print(check_digits)
"""
state = {}
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
evaluate_python_code(
code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state
)
def test_listcomp(self):
code = "x = [i for i in range(3)]"
......@@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
result = evaluate_python_code(code, {"range": range}, state={})
assert result == {0: 0, 1: 1, 2: 4}
code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert result == {102: "b"}
code = """
shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
"""
result = evaluate_python_code(code, {}, state={})
assert result == {"A": ("a", "b"), "B": ("a", "b")}
def test_tuple_assignment(self):
code = "a, b = 0, 1\nb"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
......@@ -341,7 +386,7 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "lose"
code = "import time\ntime.sleep(0.1)"
code = "import time, re\ntime.sleep(0.1)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result is None
......@@ -369,6 +414,23 @@ if char.isalpha():
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result == "LATIN CAPITAL LETTER A"
# Test submodules are handled properly, thus not raising error
code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
def test_additional_imports(self):
code = "import numpy as np"
evaluate_python_code(code, authorized_imports=["numpy"], state={})
code = "import numpy.random as rd"
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
evaluate_python_code(code, authorized_imports=["numpy"], state={})
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["random"], state={})
def test_multiple_comparators(self):
code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
......@@ -400,7 +462,7 @@ def function():
print("2")
function()"""
state = {}
evaluate_python_code(code, {"print": print}, state)
evaluate_python_code(code, {"print": print}, state=state)
assert state["print_outputs"] == "1\n2\n"
def test_tuple_target_in_iterator(self):
......@@ -612,7 +674,7 @@ assert lock.locked == False
"""
state = {}
tools = {}
evaluate_python_code(code, tools, state)
evaluate_python_code(code, tools, state=state)
def test_default_arg_in_function(self):
code = """
......@@ -672,3 +734,94 @@ returns_none(1)
state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
assert result is None
def test_nested_for_loop(self):
code = """
all_res = []
for i in range(10):
subres = []
for j in range(i):
subres.append(j)
all_res.append(subres)
out = [i for sublist in all_res for i in sublist]
out[:10]
"""
state = {}
result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
def test_pandas(self):
code = """
import pandas as pd
df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
parts_with_5_set_count = df[df['SetCount'] == 5.0]
parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
"""
state = {}
result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
assert np.array_equal(result, [-1, 5])
code = """
import pandas as pd
df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
print("HH0")
# Filter the DataFrame to get only the rows with outdated atomic numbers
filtered_df = df.loc[df['AtomicNumber'].isin([104])]
"""
result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
assert np.array_equal(result.values[0], [104, 1])
code = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
{"Pclass": 2, "Survived": 0},
{"Pclass": 2, "Survived": 1}
])
survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
"""
result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
assert result.values[1] == 0.5
def test_starred(self):
code = """
from math import radians, sin, cos, sqrt, atan2
def haversine(lat1, lon1, lat2, lon2):
R = 6371000 # Radius of the Earth in meters
lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
c = 2 * atan2(sqrt(a), sqrt(1 - a))
distance = R * c
return distance
coords_geneva = (46.1978, 6.1342)
coords_barcelona = (41.3869, 2.1660)
distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
assert round(result, 1) == 622395.4
def test_for(self):
code = """
shifts = {
"Worker A": ("6:45 pm", "8:00 pm"),
"Worker B": ("10:00 am", "11:45 am")
}
shift_intervals = {}
for worker, (start, end) in shifts.items():
shift_intervals[worker] = end
shift_intervals
"""
result = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment