chenpangpang / transformers
Commit b3818805 (Unverified)
Authored Jul 22, 2024 by Aymeric Roucher; committed by GitHub on Jul 22, 2024
Agents planning (#31702)
* Allow planning for agents
Parent: 0fdea860

Showing 6 changed files with 770 additions and 273 deletions:
src/transformers/agents/agents.py (+211 -65)
src/transformers/agents/default_tools.py (+1 -1)
src/transformers/agents/prompts.py (+112 -1)
src/transformers/agents/python_interpreter.py (+288 -201)
tests/agents/test_agents.py (+1 -1)
tests/agents/test_python_interpreter.py (+157 -4)
src/transformers/agents/agents.py @ b3818805
This diff is collapsed.
src/transformers/agents/default_tools.py @ b3818805
@@ -173,7 +173,7 @@ class PythonInterpreterTool(Tool):
     def forward(self, code):
         output = str(
-            evaluate_python_code(code, tools=self.available_tools, authorized_imports=self.authorized_imports)
+            evaluate_python_code(code, static_tools=self.available_tools, authorized_imports=self.authorized_imports)
         )
         return output
src/transformers/agents/prompts.py @ b3818805
@@ -365,7 +365,118 @@ Here are the rules you should always follow to solve your task:
 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
 7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
 8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
-9. Don't give up! You're in charge of solving the task, not providing directions to solve it.
+9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
+10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
 Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
 """
+SYSTEM_PROMPT_FACTS = """Below I will present you a task.
+You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need.
+To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it.
+Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey:
+---
+### 1. Facts given in the task
+List here the specific facts given in the task that could help you (there might be nothing here).
+### 2. Facts to look up
+List here any facts that we may need to look up.
+Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here.
+### 3. Facts to derive
+List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation.
+Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings:
+### 1. Facts given in the task
+### 2. Facts to look up
+### 3. Facts to derive
+Do not add anything else."""
+SYSTEM_PROMPT_PLAN = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
+Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+After writing the final step of the plan, write the '\n<end_plan>' tag and stop there."""
+USER_PROMPT_PLAN = """
+Here is your task:
+Task:
+```
+{task}
+```
+Your plan can leverage any of these tools:
+{tool_descriptions}
+List of facts that you know:
+```
+{answer_facts}
+```
+Now begin! Write your plan below."""
+SYSTEM_PROMPT_FACTS_UPDATE = """
+You are a world expert at gathering known and unknown facts based on a conversation.
+Below you will find a task, and a history of attempts made to solve the task. You will have to produce a list of these:
+### 1. Facts given in the task
+### 2. Facts that we have learned
+### 3. Facts still to look up
+### 4. Facts still to derive
+Find the task and history below."""
+USER_PROMPT_FACTS_UPDATE = """Earlier we've built a list of facts.
+But in your previous steps you may have learned useful new facts or invalidated some false ones.
+Please update your list of facts based on the previous history, and provide these headings:
+### 1. Facts given in the task
+### 2. Facts that we have learned
+### 3. Facts still to look up
+### 4. Facts still to derive
+Now write your new list of facts below."""
+SYSTEM_PROMPT_PLAN_UPDATE = """You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools.
+You have been given a task:
+```
+{task}
+```
+Find below the record of what has been tried so far to solve it. Then you will be asked to make an updated plan to solve the task.
+If the previous tries so far have met some success, you can make an updated plan based on these actions.
+If you are stalled, you can make a completely new plan starting from scratch.
+"""
+USER_PROMPT_PLAN_UPDATE = """You're still working towards solving this task:
+```
+{task}
+```
+You have access to these tools:
+{tool_descriptions}
+Here is the up to date list of facts that you know:
+```
+{facts_update}
+```
+Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts.
+This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer.
+Beware that you have {remaining_steps} steps remaining.
+Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS.
+After writing the final step of the plan, write the '\n<end_plan>' tag and stop there.
+Now write your new plan below."""
+PLAN_UPDATE_FINAL_PLAN_REDACTION = """I still need to solve the task I was given:
+```
+{task}
+```
+Here is my new/updated plan of action to solve the task:
+```
+{plan_update}
+```"""
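
Note: the agents.py diff is collapsed on this page, so how these templates are consumed is not visible here. As a minimal sketch, assuming the `{task}`, `{tool_descriptions}` and `{answer_facts}` placeholders are filled with `str.format` and that the model output is cut at the `<end_plan>` tag, the planning step might look like the following. The template below is an abridged copy of USER_PROMPT_PLAN (inner code fences omitted), and `build_plan_prompt` and `extract_plan` are hypothetical helper names, not functions from the commit.

```python
# Abridged copy of USER_PROMPT_PLAN from the diff (inner code fences omitted).
USER_PROMPT_PLAN = (
    "Here is your task:\n\nTask:\n{task}\n\n"
    "Your plan can leverage any of these tools:\n{tool_descriptions}\n\n"
    "List of facts that you know:\n{answer_facts}\n\n"
    "Now begin! Write your plan below."
)

END_PLAN_TAG = "<end_plan>"


def build_plan_prompt(task, tool_descriptions, answer_facts):
    """Hypothetical helper: fill the three placeholders of the template."""
    return USER_PROMPT_PLAN.format(
        task=task,
        tool_descriptions=tool_descriptions,
        answer_facts=answer_facts,
    )


def extract_plan(llm_output):
    """Hypothetical helper: keep only the text before the first <end_plan> tag."""
    return llm_output.split(END_PLAN_TAG)[0].strip()


prompt = build_plan_prompt(
    task="How far is Geneva from Barcelona?",
    tool_descriptions="- python_interpreter: runs Python code",
    answer_facts="### 1. Facts given in the task\nTwo city names.",
)

# Simulated model output; a real run would come from the LLM engine.
raw_output = "1. Search for X.\n2. Compute Y.\n<end_plan>\nextra text"
plan = extract_plan(raw_output)
print(plan)
```

Cutting at the tag discards anything the model writes after the plan; if the tag is missing, the whole output is kept.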
src/transformers/agents/python_interpreter.py @ b3818805
This diff is collapsed.
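
Note: rule 9 added to the system prompt says that variables and imports persist between code executions. The interpreter diff is collapsed here, so the sketch below only illustrates the idea with Python's built-in `exec` and one shared dictionary. It is not the restricted interpreter from this commit, `run_step` is a hypothetical name, and plain `exec` must not be used on untrusted code.

```python
def run_step(code, state):
    """Hypothetical helper: execute one code snippet against a shared state dict."""
    # Names defined or imported by the snippet are stored in `state`.
    exec(code, state)
    return state


state = {}

# Step 1 imports a module and defines a variable.
run_step("import math\nradius = 2", state)

# Step 2 reuses both without redefining them.
run_step("area = math.pi * radius ** 2", state)

print(round(state["area"], 2))  # 12.57
```

Passing the same dictionary to every step is what makes the second snippet see `math` and `radius`.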
tests/agents/test_agents.py @ b3818805
@@ -223,7 +223,7 @@ Action:
         # check that add_base_tools will not interfere with existing tools
         with pytest.raises(KeyError) as e:
             agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
-        assert "python_interpreter already exists in the toolbox" in str(e)
+        assert "already exists in the toolbox" in str(e)
         # check that python_interpreter base tool does not get added to code agents
         agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
tests/agents/test_python_interpreter.py @ b3818805
@@ -15,6 +15,7 @@
 import unittest
+import numpy as np
 import pytest
 from transformers import load_tool
@@ -241,8 +242,41 @@ for block in text_block:
         code = """
 digits, i = [1, 2, 3], 1
 digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
+        evaluate_python_code(code, {"range": range, "print": print, "int": int}, {})
+        code = """
+def calculate_isbn_10_check_digit(number):
+    total = sum((10 - i) * int(digit) for i, digit in enumerate(number))
+    remainder = total % 11
+    check_digit = 11 - remainder
+    if check_digit == 10:
+        return 'X'
+    elif check_digit == 11:
+        return '0'
+    else:
+        return str(check_digit)
+# Given 9-digit numbers
+numbers = [
+    "478225952",
+    "643485613",
+    "739394228",
+    "291726859",
+    "875262394",
+    "542617795",
+    "031810713",
+    "957007669",
+    "871467426"
+]
+# Calculate check digits for each number
+check_digits = [calculate_isbn_10_check_digit(number) for number in numbers]
+print(check_digits)
+"""
         state = {}
-        evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
+        evaluate_python_code(code, {"range": range, "print": print, "sum": sum, "enumerate": enumerate, "int": int, "str": str}, state)
     def test_listcomp(self):
         code = "x = [i for i in range(3)]"
@@ -273,6 +307,17 @@ digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
         result = evaluate_python_code(code, {"range": range}, state={})
         assert result == {0: 0, 1: 1, 2: 4}
+        code = "{num: name for num, name in {101: 'a', 102: 'b'}.items() if name not in ['a']}"
+        result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
+        assert result == {102: "b"}
+        code = """
+shifts = {'A': ('6:45', '8:00'), 'B': ('10:00', '11:45')}
+shift_minutes = {worker: ('a', 'b') for worker, (start, end) in shifts.items()}
+"""
+        result = evaluate_python_code(code, {}, state={})
+        assert result == {"A": ("a", "b"), "B": ("a", "b")}
     def test_tuple_assignment(self):
         code = "a, b = 0, 1\nb"
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
@@ -341,7 +386,7 @@ if char.isalpha():
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
         assert result == "lose"
-        code = "import time\ntime.sleep(0.1)"
+        code = "import time, re\ntime.sleep(0.1)"
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
         assert result is None
@@ -369,6 +414,23 @@ if char.isalpha():
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
         assert result == "LATIN CAPITAL LETTER A"
+        # Test submodules are handled properly, thus not raising error
+        code = "import numpy.random as rd\nrng = rd.default_rng(12345)\nrng.random()"
+        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
+        code = "from numpy.random import default_rng as d_rng\nrng = d_rng(12345)\nrng.random()"
+        result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}, authorized_imports=["numpy"])
+    def test_additional_imports(self):
+        code = "import numpy as np"
+        evaluate_python_code(code, authorized_imports=["numpy"], state={})
+        code = "import numpy.random as rd"
+        evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
+        evaluate_python_code(code, authorized_imports=["numpy"], state={})
+        with pytest.raises(InterpreterError):
+            evaluate_python_code(code, authorized_imports=["random"], state={})
     def test_multiple_comparators(self):
         code = "0 <= -1 < 4 and 0 <= -5 < 4"
         result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
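
Note: test_additional_imports above expects `import numpy.random` to pass when either `numpy.random` or its parent package `numpy` is authorized, and to fail when only `random` is authorized. The interpreter diff is collapsed on this page, so the real check is not visible. One way to get the same three outcomes, offered only as a sketch with a hypothetical function name, is to test every dotted prefix of the module name against the authorized list.

```python
def is_import_authorized(module_name, authorized_imports):
    """Hypothetical check: a module is allowed if it or any parent package is listed."""
    parts = module_name.split(".")
    # "numpy.random" yields the prefixes "numpy" and "numpy.random".
    prefixes = [".".join(parts[: i + 1]) for i in range(len(parts))]
    return any(prefix in authorized_imports for prefix in prefixes)


print(is_import_authorized("numpy.random", ["numpy.random"]))  # True
print(is_import_authorized("numpy.random", ["numpy"]))         # True
print(is_import_authorized("numpy.random", ["random"]))        # False
```

Matching on whole dotted prefixes, rather than on substrings, is what keeps `random` from authorizing `numpy.random`.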
@@ -400,7 +462,7 @@ def function():
     print("2")
 function()"""
         state = {}
-        evaluate_python_code(code, {"print": print}, state)
+        evaluate_python_code(code, {"print": print}, state=state)
         assert state["print_outputs"] == "1\n2\n"
     def test_tuple_target_in_iterator(self):
@@ -612,7 +674,7 @@ assert lock.locked == False
 """
         state = {}
         tools = {}
-        evaluate_python_code(code, tools, state)
+        evaluate_python_code(code, tools, state=state)
     def test_default_arg_in_function(self):
         code = """
@@ -672,3 +734,94 @@ returns_none(1)
         state = {}
         result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
         assert result is None
+    def test_nested_for_loop(self):
+        code = """
+all_res = []
+for i in range(10):
+    subres = []
+    for j in range(i):
+        subres.append(j)
+    all_res.append(subres)
+out = [i for sublist in all_res for i in sublist]
+out[:10]
+"""
+        state = {}
+        result = evaluate_python_code(code, {"print": print, "range": range}, state=state)
+        assert result == [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
+    def test_pandas(self):
+        code = """
+import pandas as pd
+df = pd.DataFrame.from_dict({'SetCount': ['5', '4', '5'], 'Quantity': [1, 0, -1]})
+df['SetCount'] = pd.to_numeric(df['SetCount'], errors='coerce')
+parts_with_5_set_count = df[df['SetCount'] == 5.0]
+parts_with_5_set_count[['Quantity', 'SetCount']].values[1]
+"""
+        state = {}
+        result = evaluate_python_code(code, {}, state=state, authorized_imports=["pandas"])
+        assert np.array_equal(result, [-1, 5])
+        code = """
+import pandas as pd
+df = pd.DataFrame.from_dict({"AtomicNumber": [111, 104, 105], "ok": [0, 1, 2]})
+print("HH0")
+# Filter the DataFrame to get only the rows with outdated atomic numbers
+filtered_df = df.loc[df['AtomicNumber'].isin([104])]
+"""
+        result = evaluate_python_code(code, {"print": print}, state={}, authorized_imports=["pandas"])
+        assert np.array_equal(result.values[0], [104, 1])
+        code = """import pandas as pd
+data = pd.DataFrame.from_dict([
+    {"Pclass": 1, "Survived": 1},
+    {"Pclass": 2, "Survived": 0},
+    {"Pclass": 2, "Survived": 1}
+])
+survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
+"""
+        result = evaluate_python_code(code, {}, state={}, authorized_imports=["pandas"])
+        assert result.values[1] == 0.5
+    def test_starred(self):
+        code = """
+from math import radians, sin, cos, sqrt, atan2
+def haversine(lat1, lon1, lat2, lon2):
+    R = 6371000  # Radius of the Earth in meters
+    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
+    dlat = lat2 - lat1
+    dlon = lon2 - lon1
+    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
+    c = 2 * atan2(sqrt(a), sqrt(1 - a))
+    distance = R * c
+    return distance
+coords_geneva = (46.1978, 6.1342)
+coords_barcelona = (41.3869, 2.1660)
+distance_geneva_barcelona = haversine(*coords_geneva, *coords_barcelona)
+"""
+        result = evaluate_python_code(code, {"print": print, "map": map}, state={}, authorized_imports=["math"])
+        assert round(result, 1) == 622395.4
+    def test_for(self):
+        code = """
+shifts = {
+    "Worker A": ("6:45 pm", "8:00 pm"),
+    "Worker B": ("10:00 am", "11:45 am")
+}
+shift_intervals = {}
+for worker, (start, end) in shifts.items():
+    shift_intervals[worker] = end
+shift_intervals
+"""
+        result = evaluate_python_code(code, {"print": print, "map": map}, state={})
+        assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}