Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
15560252
Unverified
Commit
15560252
authored
Jul 05, 2024
by
Aymeric Roucher
Committed by
GitHub
Jul 05, 2024
Browse files
Code agent: allow function persistence between steps (#31769)
* Code agent: allow function persistence between steps
parent
eef0507f
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
63 additions
and
11 deletions
+63
-11
src/transformers/agents/agent_types.py
src/transformers/agents/agent_types.py
+6
-3
src/transformers/agents/agents.py
src/transformers/agents/agents.py
+5
-2
src/transformers/agents/python_interpreter.py
src/transformers/agents/python_interpreter.py
+6
-1
tests/agents/test_agents.py
tests/agents/test_agents.py
+46
-3
tests/agents/test_python_interpreter.py
tests/agents/test_python_interpreter.py
+0
-2
No files found.
src/transformers/agents/agent_types.py
View file @
15560252
...
@@ -188,7 +188,7 @@ class AgentAudio(AgentType, str):
...
@@ -188,7 +188,7 @@ class AgentAudio(AgentType, str):
self
.
samplerate
=
samplerate
self
.
samplerate
=
samplerate
if
isinstance
(
value
,
(
str
,
pathlib
.
Path
)):
if
isinstance
(
value
,
(
str
,
pathlib
.
Path
)):
self
.
_path
=
value
self
.
_path
=
value
elif
isinstance
(
value
,
torch
.
Tensor
):
elif
is_torch_available
()
and
isinstance
(
value
,
torch
.
Tensor
):
self
.
_tensor
=
value
self
.
_tensor
=
value
elif
isinstance
(
value
,
tuple
):
elif
isinstance
(
value
,
tuple
):
self
.
samplerate
=
value
[
0
]
self
.
samplerate
=
value
[
0
]
...
@@ -232,7 +232,10 @@ class AgentAudio(AgentType, str):
...
@@ -232,7 +232,10 @@ class AgentAudio(AgentType, str):
AGENT_TYPE_MAPPING
=
{
"text"
:
AgentText
,
"image"
:
AgentImage
,
"audio"
:
AgentAudio
}
AGENT_TYPE_MAPPING
=
{
"text"
:
AgentText
,
"image"
:
AgentImage
,
"audio"
:
AgentAudio
}
INSTANCE_TYPE_MAPPING
=
{
str
:
AgentText
,
float
:
AgentText
,
int
:
AgentText
,
Tensor
:
AgentAudio
,
ImageType
:
AgentImage
}
INSTANCE_TYPE_MAPPING
=
{
str
:
AgentText
,
ImageType
:
AgentImage
}
if
is_torch_available
():
INSTANCE_TYPE_MAPPING
[
Tensor
]
=
AgentAudio
def
handle_agent_inputs
(
*
args
,
**
kwargs
):
def
handle_agent_inputs
(
*
args
,
**
kwargs
):
...
@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
...
@@ -251,4 +254,4 @@ def handle_agent_outputs(output, output_type=None):
for
_k
,
_v
in
INSTANCE_TYPE_MAPPING
.
items
():
for
_k
,
_v
in
INSTANCE_TYPE_MAPPING
.
items
():
if
isinstance
(
output
,
_k
):
if
isinstance
(
output
,
_k
):
return
_v
(
output
)
return
_v
(
output
)
return
AgentType
(
output
)
return
output
src/transformers/agents/agents.py
View file @
15560252
...
@@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent):
...
@@ -856,6 +856,10 @@ class ReactCodeAgent(ReactAgent):
self
.
additional_authorized_imports
=
additional_authorized_imports
if
additional_authorized_imports
else
[]
self
.
additional_authorized_imports
=
additional_authorized_imports
if
additional_authorized_imports
else
[]
self
.
authorized_imports
=
list
(
set
(
LIST_SAFE_MODULES
)
|
set
(
self
.
additional_authorized_imports
))
self
.
authorized_imports
=
list
(
set
(
LIST_SAFE_MODULES
)
|
set
(
self
.
additional_authorized_imports
))
self
.
system_prompt
=
self
.
system_prompt
.
replace
(
"<<authorized_imports>>"
,
str
(
self
.
authorized_imports
))
self
.
system_prompt
=
self
.
system_prompt
.
replace
(
"<<authorized_imports>>"
,
str
(
self
.
authorized_imports
))
self
.
available_tools
=
{
**
BASE_PYTHON_TOOLS
.
copy
(),
**
self
.
toolbox
.
tools
,
}
# This list can be augmented by the code agent creating some new functions
def
step
(
self
):
def
step
(
self
):
"""
"""
...
@@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent):
...
@@ -905,10 +909,9 @@ class ReactCodeAgent(ReactAgent):
# Execute
# Execute
self
.
log_code_action
(
code_action
)
self
.
log_code_action
(
code_action
)
try
:
try
:
available_tools
=
{
**
BASE_PYTHON_TOOLS
.
copy
(),
**
self
.
toolbox
.
tools
}
result
=
self
.
python_evaluator
(
result
=
self
.
python_evaluator
(
code_action
,
code_action
,
available_tools
,
tools
=
self
.
available_tools
,
state
=
self
.
state
,
state
=
self
.
state
,
authorized_imports
=
self
.
authorized_imports
,
authorized_imports
=
self
.
authorized_imports
,
)
)
...
...
src/transformers/agents/python_interpreter.py
View file @
15560252
...
@@ -778,7 +778,10 @@ def evaluate_ast(
...
@@ -778,7 +778,10 @@ def evaluate_ast(
def
evaluate_python_code
(
def
evaluate_python_code
(
code
:
str
,
tools
:
Optional
[
Dict
[
str
,
Callable
]]
=
{},
state
=
None
,
authorized_imports
:
List
[
str
]
=
LIST_SAFE_MODULES
code
:
str
,
tools
:
Optional
[
Dict
[
str
,
Callable
]]
=
None
,
state
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
authorized_imports
:
List
[
str
]
=
LIST_SAFE_MODULES
,
):
):
"""
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
...
@@ -803,6 +806,8 @@ def evaluate_python_code(
...
@@ -803,6 +806,8 @@ def evaluate_python_code(
raise
SyntaxError
(
f
"The code generated by the agent is not valid.
\n
{
e
}
"
)
raise
SyntaxError
(
f
"The code generated by the agent is not valid.
\n
{
e
}
"
)
if
state
is
None
:
if
state
is
None
:
state
=
{}
state
=
{}
if
tools
is
None
:
tools
=
{}
result
=
None
result
=
None
global
PRINT_OUTPUTS
global
PRINT_OUTPUTS
PRINT_OUTPUTS
=
""
PRINT_OUTPUTS
=
""
...
...
tests/agents/test_agents.py
View file @
15560252
...
@@ -94,12 +94,48 @@ final_answer("got an error")
...
@@ -94,12 +94,48 @@ final_answer("got an error")
"""
"""
def
fake_react_code_functiondef
(
messages
,
stop_sequences
=
None
)
->
str
:
prompt
=
str
(
messages
)
if
"special_marker"
not
in
prompt
:
return
"""
Thought: Let's define the function. special_marker
Code:
```py
import numpy as np
def moving_average(x, w):
return np.convolve(x, np.ones(w), 'valid') / w
```<end_code>
"""
else
:
# We're at step 2
return
"""
Thought: I can now answer the initial question
Code:
```py
x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```<end_code>
"""
def
fake_code_llm_oneshot
(
messages
,
stop_sequences
=
None
)
->
str
:
def
fake_code_llm_oneshot
(
messages
,
stop_sequences
=
None
)
->
str
:
return
"""
return
"""
Thought: I should multiply 2 by 3.6452. special_marker
Thought: I should multiply 2 by 3.6452. special_marker
Code:
Code:
```py
```py
result = python_interpreter(code="2*3.6452")
result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
def
fake_code_llm_no_return
(
messages
,
stop_sequences
=
None
)
->
str
:
return
"""
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
print(result)
```
```
"""
"""
...
@@ -135,8 +171,8 @@ Action:
...
@@ -135,8 +171,8 @@ Action:
def
test_fake_react_code_agent
(
self
):
def
test_fake_react_code_agent
(
self
):
agent
=
ReactCodeAgent
(
tools
=
[
PythonInterpreterTool
()],
llm_engine
=
fake_react_code_llm
)
agent
=
ReactCodeAgent
(
tools
=
[
PythonInterpreterTool
()],
llm_engine
=
fake_react_code_llm
)
output
=
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
output
=
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
assert
isinstance
(
output
,
AgentTex
t
)
assert
isinstance
(
output
,
floa
t
)
assert
output
==
"
7.2904
"
assert
output
==
7.2904
assert
agent
.
logs
[
0
][
"task"
]
==
"What is 2 multiplied by 3.6452?"
assert
agent
.
logs
[
0
][
"task"
]
==
"What is 2 multiplied by 3.6452?"
assert
float
(
agent
.
logs
[
1
][
"observation"
].
strip
())
-
12.511648
<
1e-6
assert
float
(
agent
.
logs
[
1
][
"observation"
].
strip
())
-
12.511648
<
1e-6
assert
agent
.
logs
[
2
][
"tool_call"
]
==
{
assert
agent
.
logs
[
2
][
"tool_call"
]
==
{
...
@@ -157,7 +193,7 @@ Action:
...
@@ -157,7 +193,7 @@ Action:
def
test_react_fails_max_iterations
(
self
):
def
test_react_fails_max_iterations
(
self
):
agent
=
ReactCodeAgent
(
agent
=
ReactCodeAgent
(
tools
=
[
PythonInterpreterTool
()],
tools
=
[
PythonInterpreterTool
()],
llm_engine
=
fake_code_llm_
oneshot
,
# use this callable because it never ends
llm_engine
=
fake_code_llm_
no_return
,
# use this callable because it never ends
max_iterations
=
5
,
max_iterations
=
5
,
)
)
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
agent
.
run
(
"What is 2 multiplied by 3.6452?"
)
...
@@ -192,3 +228,10 @@ Action:
...
@@ -192,3 +228,10 @@ Action:
# check that python_interpreter base tool does not get added to code agents
# check that python_interpreter base tool does not get added to code agents
agent
=
ReactCodeAgent
(
tools
=
[],
llm_engine
=
fake_react_code_llm
,
add_base_tools
=
True
)
agent
=
ReactCodeAgent
(
tools
=
[],
llm_engine
=
fake_react_code_llm
,
add_base_tools
=
True
)
assert
len
(
agent
.
toolbox
.
tools
)
==
6
# added final_answer tool + 5 base tools (excluding interpreter)
assert
len
(
agent
.
toolbox
.
tools
)
==
6
# added final_answer tool + 5 base tools (excluding interpreter)
def
test_function_persistence_across_steps
(
self
):
agent
=
ReactCodeAgent
(
tools
=
[],
llm_engine
=
fake_react_code_functiondef
,
max_iterations
=
2
,
additional_authorized_imports
=
[
"numpy"
]
)
res
=
agent
.
run
(
"ok"
)
assert
res
[
0
]
==
0.5
tests/agents/test_python_interpreter.py
View file @
15560252
...
@@ -660,7 +660,6 @@ add_one(1, 1)
...
@@ -660,7 +660,6 @@ add_one(1, 1)
"""
"""
state
=
{}
state
=
{}
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
print
(
state
)
assert
result
==
2
assert
result
==
2
# test returning None
# test returning None
...
@@ -672,5 +671,4 @@ returns_none(1)
...
@@ -672,5 +671,4 @@ returns_none(1)
"""
"""
state
=
{}
state
=
{}
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
result
=
evaluate_python_code
(
code
,
{
"print"
:
print
,
"range"
:
range
,
"ord"
:
ord
,
"chr"
:
chr
},
state
=
state
)
print
(
state
)
assert
result
is
None
assert
result
is
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment