Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
15560252
Unverified
Commit
15560252
authored
Jul 05, 2024
by
Aymeric Roucher
Committed by
GitHub
Jul 05, 2024
Browse files
Code agent: allow function persistence between steps (#31769)
* Code agent: allow function persistence between steps
parent
eef0507f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
11 deletions
+63
-11
src/transformers/agents/agent_types.py
src/transformers/agents/agent_types.py
+6
-3
src/transformers/agents/agents.py
src/transformers/agents/agents.py
+5
-2
src/transformers/agents/python_interpreter.py
src/transformers/agents/python_interpreter.py
+6
-1
tests/agents/test_agents.py
tests/agents/test_agents.py
+46
-3
tests/agents/test_python_interpreter.py
tests/agents/test_python_interpreter.py
+0
-2
No files found.
src/transformers/agents/agent_types.py
View file @
15560252
...
...
@@ -188,7 +188,7 @@ class AgentAudio(AgentType, str):
self
.
samplerate
=
samplerate
if
isinstance
(
value
,
(
str
,
pathlib
.
Path
)):
self
.
_path
=
value
elif
isinstance
(
value
,
torch
.
Tensor
):
elif
is_torch_available
()
and
isinstance
(
value
,
torch
.
Tensor
):
self
.
_tensor
=
value
elif
isinstance
(
value
,
tuple
):
self
.
samplerate
=
value
[
0
]
...
...
@@ -232,7 +232,10 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING
=
{
"text"
:
AgentText
,
"image"
:
AgentImage
,
"audio"
:
AgentAudio
}
INSTANCE_TYPE_MAPPING
=
{
str
:
AgentText
,
float
:
AgentText
,
int
:
AgentText
,
Tensor
:
AgentAudio
,
ImageType
:
AgentImage
}
INSTANCE_TYPE_MAPPING
=
{
str
:
AgentText
,
ImageType
:
AgentImage
}
if
is_torch_available
():
INSTANCE_TYPE_MAPPING
[
Tensor
]
=
AgentAudio
def
handle_agent_inputs
(
*
args
,
**
kwargs
):
...
...
@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
for
_k
,
_v
in
INSTANCE_TYPE_MAPPING
.
items
():
if
isinstance
(
output
,
_k
):
return
_v
(
output
)
return
AgentType
(
output
)
return
output
src/transformers/agents/agents.py
View file @
15560252
...
...
@@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent):
self
.
additional_authorized_imports
=
additional_authorized_imports
if
additional_authorized_imports
else
[]
self
.
authorized_imports
=
list
(
set
(
LIST_SAFE_MODULES
)
|
set
(
self
.
additional_authorized_imports
))
self
.
system_prompt
=
self
.
system_prompt
.
replace
(
"<<authorized_imports>>"
,
str
(
self
.
authorized_imports
))
self
.
available_tools
=
{
**
BASE_PYTHON_TOOLS
.
copy
(),
**
self
.
toolbox
.
tools
,
}
# This list can be augmented by the code agent creating some new functions
def
step
(
self
):
"""
...
...
@@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent):
# Execute
self
.
log_code_action
(
code_action
)
try
:
available_tools
=
{
**
BASE_PYTHON_TOOLS
.
copy
(),
**
self
.
toolbox
.
tools
}
result
=
self
.
python_evaluator
(
code_action
,
available_tools
,
tools
=
self
.
available_tools
,
state
=
self
.
state
,
authorized_imports
=
self
.
authorized_imports
,
)
...
...
src/transformers/agents/python_interpreter.py
View file @
15560252
...
...
@@ -778,7 +778,10 @@ def evaluate_ast(
def
evaluate_python_code
(
code
:
str
,
tools
:
Optional
[
Dict
[
str
,
Callable
]]
=
{},
state
=
None
,
authorized_imports
:
List
[
str
]
=
LIST_SAFE_MODULES
code
:
str
,
tools
:
Optional
[
Dict
[
str
,
Callable
]]
=
None
,
state
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
authorized_imports
:
List
[
str
]
=
LIST_SAFE_MODULES
,
):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
...
...
@@ -803,6 +806,8 @@ def evaluate_python_code(
raise
SyntaxError
(
f
"The code generated by the agent is not valid.
\n
{
e
}
"
)
if
state
is
None
:
state
=
{}
if
tools
is
None
:
tools
=
{}
result
=
None
global
PRINT_OUTPUTS
PRINT_OUTPUTS
=
""
...
...
tests/agents/test_agents.py
View file @
15560252
...
...
@@ -94,12 +94,48 @@ final_answer("got an error")
"""
def
fake_react_code_functiondef
(
messages
,
stop_sequences
=
None
)
->
str
:
prompt
=
str
(
messages
)
if
"special_marker"
not
in
prompt
:
return
"""
Thought: Let's define the function. special_marker
Code:
```py
import numpy as np
def moving_average(x, w):
return np.convolve(x, np.ones(w), 'valid') / w
```<end_code>
"""
else
:
# We're at step 2
return
"""
Thought: I can now answer the initial question
Code:
```py
x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```<end_code>
"""
def
fake_code_llm_oneshot
(
messages
,
stop_sequences
=
None
)
->
str
:
return
"""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
def
fake_code_llm_no_return
(
messages
,
stop_sequences
=
None
)
->
str
:
return
"""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
```
"""
...
...
@@ -135,8 +171,8 @@ Action:
def
test_fake_react_code_agent
(
self
):
agent
=
ReactCodeAgent
(
tools
=
[
PythonInterpreterTool
()],
llm_engine
=
fake_react_code_llm
)
output
=
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
assert
isinstance
(
output
,
AgentTex
t
)
assert
output
==
"
7.2904
"
assert
isinstance
(
output
,
floa
t
)
assert
output
==
7.2904
assert
agent
.
logs
[
0
][
"task"
]
==
"What is 2 multiplied by 3.6452?"
assert
float
(
agent
.
logs
[
1
][
"observation"
].
strip
())
-
12.511648
<
1e-6
assert
agent
.
logs
[
2
][
"tool_call"
]
==
{
...
...
@@ -157,7 +193,7 @@ Action:
def
test_react_fails_max_iterations
(
self
):
agent
=
ReactCodeAgent
(
tools
=
[
PythonInterpreterTool
()],
llm_engine
=
fake_code_llm_
oneshot
,
# use this callable because it never ends
llm_engine
=
fake_code_llm_
no_return
,
# use this callable because it never ends
max_iterations
=
5
,
)
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
...
...
@@ -192,3 +228,10 @@ Action:
# check that python_interpreter base tool does not get added to code agents
agent
=
ReactCodeAgent
(
tools
=
[],
llm_engine
=
fake_react_code_llm
,
add_base_tools
=
True
)
assert
len
(
agent
.
toolbox
.
tools
)
==
6
# added final_answer tool + 5 base tools (excluding interpreter)
def
test_function_persistence_across_steps
(
self
):
agent
=
ReactCodeAgent
(
tools
=
[],
llm_engine
=
fake_react_code_functiondef
,
max_iterations
=
2
,
additional_authorized_imports
=
[
"numpy"
]
)
res
=
agent
.
run
(
"ok"
)
assert
res
[
0
]
==
0.5
tests/agents/test_python_interpreter.py
View file @
15560252
...
...
@@ -660,7 +660,6 @@ add_one(1, 1)
"""
state
=
{}
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
print
(
state
)
assert
result
==
2
# test returning None
...
...
@@ -672,5 +671,4 @@ returns_none(1)
"""
state
=
{}
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
print
(
state
)
assert
result
is
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment