Unverified commit 15560252, authored by Aymeric Roucher and committed by GitHub
Browse files

Code agent: allow function persistence between steps (#31769)

* Code agent: allow function persistence between steps
parent eef0507f
...@@ -188,7 +188,7 @@ class AgentAudio(AgentType, str): ...@@ -188,7 +188,7 @@ class AgentAudio(AgentType, str):
self.samplerate = samplerate self.samplerate = samplerate
if isinstance(value, (str, pathlib.Path)): if isinstance(value, (str, pathlib.Path)):
self._path = value self._path = value
elif isinstance(value, torch.Tensor): elif is_torch_available() and isinstance(value, torch.Tensor):
self._tensor = value self._tensor = value
elif isinstance(value, tuple): elif isinstance(value, tuple):
self.samplerate = value[0] self.samplerate = value[0]
...@@ -232,7 +232,10 @@ class AgentAudio(AgentType, str): ...@@ -232,7 +232,10 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio} AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText, float: AgentText, int: AgentText, Tensor: AgentAudio, ImageType: AgentImage} INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
if is_torch_available():
INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
def handle_agent_inputs(*args, **kwargs): def handle_agent_inputs(*args, **kwargs):
...@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None): ...@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
for _k, _v in INSTANCE_TYPE_MAPPING.items(): for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(output, _k): if isinstance(output, _k):
return _v(output) return _v(output)
return AgentType(output) return output
...@@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent): ...@@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent):
self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else [] self.additional_authorized_imports = additional_authorized_imports if additional_authorized_imports else []
self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(self.additional_authorized_imports))
self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports)) self.system_prompt = self.system_prompt.replace("<<authorized_imports>>", str(self.authorized_imports))
self.available_tools = {
**BASE_PYTHON_TOOLS.copy(),
**self.toolbox.tools,
} # This list can be augmented by the code agent creating some new functions
def step(self): def step(self):
""" """
...@@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent): ...@@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent):
# Execute # Execute
self.log_code_action(code_action) self.log_code_action(code_action)
try: try:
available_tools = {**BASE_PYTHON_TOOLS.copy(), **self.toolbox.tools}
result = self.python_evaluator( result = self.python_evaluator(
code_action, code_action,
available_tools, tools=self.available_tools,
state=self.state, state=self.state,
authorized_imports=self.authorized_imports, authorized_imports=self.authorized_imports,
) )
......
...@@ -778,7 +778,10 @@ def evaluate_ast( ...@@ -778,7 +778,10 @@ def evaluate_ast(
def evaluate_python_code( def evaluate_python_code(
code: str, tools: Optional[Dict[str, Callable]] = {}, state=None, authorized_imports: List[str] = LIST_SAFE_MODULES code: str,
tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = LIST_SAFE_MODULES,
): ):
""" """
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
...@@ -803,6 +806,8 @@ def evaluate_python_code( ...@@ -803,6 +806,8 @@ def evaluate_python_code(
raise SyntaxError(f"The code generated by the agent is not valid.\n{e}") raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
if state is None: if state is None:
state = {} state = {}
if tools is None:
tools = {}
result = None result = None
global PRINT_OUTPUTS global PRINT_OUTPUTS
PRINT_OUTPUTS = "" PRINT_OUTPUTS = ""
......
...@@ -94,12 +94,48 @@ final_answer("got an error") ...@@ -94,12 +94,48 @@ final_answer("got an error")
""" """
def fake_react_code_functiondef(messages, stop_sequences=None) -> str:
    """Fake LLM engine that defines a function in one step and calls it in the next.

    The step is inferred from the stringified conversation: the first reply
    contains the text ``special_marker``, so once that marker shows up in
    ``messages`` the engine switches to the second reply.

    Args:
        messages: Conversation so far; only ``str(messages)`` is inspected.
        stop_sequences: Accepted for interface compatibility and ignored.

    Returns:
        The step-1 reply (defines ``moving_average``) when the marker is
        absent, otherwise the step-2 reply (calls ``moving_average`` and
        passes the result to ``final_answer``).
    """
    prompt = str(messages)
    if "special_marker" not in prompt:
        # Step 1: the body of `moving_average` must stay indented inside the
        # literal, otherwise the emitted snippet is not valid Python.
        return """
Thought: Let's define the function. special_marker
Code:
```py
import numpy as np
def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w
```<end_code>
"""
    else:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```<end_code>
"""
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str: def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
return """ return """
Thought: I should multiply 2 by 3.6452. special_marker Thought: I should multiply 2 by 3.6452. special_marker
Code: Code:
```py ```py
result = python_interpreter(code="2*3.6452") result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
def fake_code_llm_no_return(messages, stop_sequences=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result) print(result)
``` ```
""" """
...@@ -135,8 +171,8 @@ Action: ...@@ -135,8 +171,8 @@ Action:
def test_fake_react_code_agent(self): def test_fake_react_code_agent(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm) agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?") output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText) assert isinstance(output, float)
assert output == "7.2904" assert output == 7.2904
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?" assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6 assert float(agent.logs[1]["observation"].strip()) - 12.511648 < 1e-6
assert agent.logs[2]["tool_call"] == { assert agent.logs[2]["tool_call"] == {
...@@ -157,7 +193,7 @@ Action: ...@@ -157,7 +193,7 @@ Action:
def test_react_fails_max_iterations(self): def test_react_fails_max_iterations(self):
agent = ReactCodeAgent( agent = ReactCodeAgent(
tools=[PythonInterpreterTool()], tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends llm_engine=fake_code_llm_no_return, # use this callable because it never ends
max_iterations=5, max_iterations=5,
) )
agent.run("What is 2 multiplied by 3.6452?") agent.run("What is 2 multiplied by 3.6452?")
...@@ -192,3 +228,10 @@ Action: ...@@ -192,3 +228,10 @@ Action:
# check that python_interpreter base tool does not get added to code agents # check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True) agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter) assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
def test_function_persistence_across_steps(self):
    """Check that a function defined in one agent step is callable in the next.

    The fake engine `fake_react_code_functiondef` defines `moving_average`
    in its first reply and calls it on [0, 1, 2, 3, 4, 5] with a window of 2
    in its second reply, so the first element of the final answer is
    (0 + 1) / 2 = 0.5.
    """
    agent = ReactCodeAgent(
        tools=[],
        llm_engine=fake_react_code_functiondef,
        max_iterations=2,  # one step to define the function, one to use it
        additional_authorized_imports=["numpy"],
    )
    res = agent.run("ok")
    assert res[0] == 0.5
...@@ -660,7 +660,6 @@ add_one(1, 1) ...@@ -660,7 +660,6 @@ add_one(1, 1)
""" """
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
print(state)
assert result == 2 assert result == 2
# test returning None # test returning None
...@@ -672,5 +671,4 @@ returns_none(1) ...@@ -672,5 +671,4 @@ returns_none(1)
""" """
state = {} state = {}
result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state) result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
print(state)
assert result is None assert result is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment