Unverified Commit b38612d3 authored by Aymeric Roucher's avatar Aymeric Roucher Committed by GitHub
Browse files

Agents: Improve python interpreter (#31409)

* Improve Python interpreter
* Add `with` and `assert` statements
* Prevent overwriting existing tools
* Check interpreter errors are well logged in code agent
* Add lazy (short-circuit) evaluation for `and` and `or`
* Improve variable assignment
* Fix early return statements in functions
* Add small import fix on interpreter tool
parent 1f9387d3
...@@ -34,11 +34,16 @@ def custom_print(*args): ...@@ -34,11 +34,16 @@ def custom_print(*args):
BASE_PYTHON_TOOLS = { BASE_PYTHON_TOOLS = {
"print": custom_print, "print": custom_print,
"isinstance": isinstance,
"range": range, "range": range,
"float": float, "float": float,
"int": int, "int": int,
"bool": bool, "bool": bool,
"str": str, "str": str,
"set": set,
"list": list,
"dict": dict,
"tuple": tuple,
"round": round, "round": round,
"ceil": math.ceil, "ceil": math.ceil,
"floor": math.floor, "floor": math.floor,
...@@ -60,10 +65,6 @@ BASE_PYTHON_TOOLS = { ...@@ -60,10 +65,6 @@ BASE_PYTHON_TOOLS = {
"max": max, "max": max,
"min": min, "min": min,
"abs": abs, "abs": abs,
"list": list,
"dict": dict,
"tuple": tuple,
"set": set,
"enumerate": enumerate, "enumerate": enumerate,
"zip": zip, "zip": zip,
"reversed": reversed, "reversed": reversed,
...@@ -74,6 +75,15 @@ BASE_PYTHON_TOOLS = { ...@@ -74,6 +75,15 @@ BASE_PYTHON_TOOLS = {
"filter": filter, "filter": filter,
"ord": ord, "ord": ord,
"chr": chr, "chr": chr,
"next": next,
"iter": iter,
"divmod": divmod,
"callable": callable,
"getattr": getattr,
"hasattr": hasattr,
"setattr": setattr,
"issubclass": issubclass,
"type": type,
} }
...@@ -147,9 +157,9 @@ class PythonInterpreterTool(Tool): ...@@ -147,9 +157,9 @@ class PythonInterpreterTool(Tool):
def __init__(self, *args, authorized_imports=None, **kwargs): def __init__(self, *args, authorized_imports=None, **kwargs):
if authorized_imports is None: if authorized_imports is None:
authorized_imports = list(set(LIST_SAFE_MODULES)) self.authorized_imports = list(set(LIST_SAFE_MODULES))
else: else:
authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports)) self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
self.inputs = { self.inputs = {
"code": { "code": {
"type": "text", "type": "text",
...@@ -162,7 +172,9 @@ class PythonInterpreterTool(Tool): ...@@ -162,7 +172,9 @@ class PythonInterpreterTool(Tool):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def forward(self, code):
    """Evaluate `code` and return the result converted to a string.

    The evaluation may call `self.available_tools` and is passed
    `self.authorized_imports` as the list of importable modules.
    """
    return str(
        evaluate_python_code(
            code,
            tools=self.available_tools,
            authorized_imports=self.authorized_imports,
        )
    )
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .agents import BASE_PYTHON_TOOLS from .agents import BASE_PYTHON_TOOLS
from .python_interpreter import InterpretorError, evaluate from .python_interpreter import InterpreterError, evaluate
### Fake tools for test ### Fake tools for test
...@@ -256,7 +256,7 @@ def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpret ...@@ -256,7 +256,7 @@ def evaluate_code(code, inputs=None, state=None, verbose=False, return_interpret
try: try:
return evaluate(code, tools, state) return evaluate(code, tools, state)
except InterpretorError as e: except InterpreterError as e:
return str(e) return str(e)
except Exception as e: except Exception as e:
if verbose: if verbose:
......
...@@ -54,7 +54,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions: ...@@ -54,7 +54,7 @@ def get_clean_message_list(message_list: List[Dict[str, str]], role_conversions:
message["role"] = role_conversions[role] message["role"] = role_conversions[role]
if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]: if len(final_message_list) > 0 and message["role"] == final_message_list[-1]["role"]:
final_message_list[-1]["content"] += "\n===\n" + message["content"] final_message_list[-1]["content"] += "\n=======\n" + message["content"]
else: else:
final_message_list.append(message) final_message_list.append(message)
return final_message_list return final_message_list
......
...@@ -21,7 +21,7 @@ from collections.abc import Mapping ...@@ -21,7 +21,7 @@ from collections.abc import Mapping
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
class InterpretorError(ValueError): class InterpreterError(ValueError):
""" """
An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported An error raised when the interpretor cannot evaluate a Python expression, due to syntax error or unsupported
operations. operations.
...@@ -50,6 +50,8 @@ LIST_SAFE_MODULES = [ ...@@ -50,6 +50,8 @@ LIST_SAFE_MODULES = [
"unicodedata", "unicodedata",
] ]
PRINT_OUTPUTS = ""
class BreakException(Exception):
    """Control-flow signal for `break`; `evaluate_while` catches it and stops the loop."""

    pass
...@@ -59,13 +61,18 @@ class ContinueException(Exception): ...@@ -59,13 +61,18 @@ class ContinueException(Exception):
pass pass
class ReturnException(Exception):
    """Control-flow signal used to implement `return`.

    `evaluate_ast` raises it for `ast.Return` nodes, and the wrapper built by
    `create_function` catches it to recover the returned value.
    """

    def __init__(self, value):
        # Already-evaluated return value (None for a bare `return`).
        self.value = value
def get_iterable(obj):
    """Return `obj` as a list.

    A list is returned as-is (same object); any other object exposing `__iter__`
    is materialized into a new list.

    Raises:
        InterpreterError: if `obj` is neither a list nor exposes `__iter__`.
    """
    if isinstance(obj, list):
        return obj
    if not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return list(obj)
def evaluate_unaryop(expression, state, tools): def evaluate_unaryop(expression, state, tools):
...@@ -79,7 +86,7 @@ def evaluate_unaryop(expression, state, tools): ...@@ -79,7 +86,7 @@ def evaluate_unaryop(expression, state, tools):
elif isinstance(expression.op, ast.Invert): elif isinstance(expression.op, ast.Invert):
return ~operand return ~operand
else: else:
raise InterpretorError(f"Unary operation {expression.op.__class__.__name__} is not supported.") raise InterpreterError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
def evaluate_lambda(lambda_expression, state, tools): def evaluate_lambda(lambda_expression, state, tools):
...@@ -99,10 +106,15 @@ def evaluate_while(while_loop, state, tools): ...@@ -99,10 +106,15 @@ def evaluate_while(while_loop, state, tools):
iterations = 0 iterations = 0
while evaluate_ast(while_loop.test, state, tools): while evaluate_ast(while_loop.test, state, tools):
for node in while_loop.body: for node in while_loop.body:
try:
evaluate_ast(node, state, tools) evaluate_ast(node, state, tools)
except BreakException:
return None
except ContinueException:
break
iterations += 1 iterations += 1
if iterations > max_iterations: if iterations > max_iterations:
raise InterpretorError(f"Maximum number of {max_iterations} iterations in While loop exceeded") raise InterpreterError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
return None return None
...@@ -110,15 +122,33 @@ def create_function(func_def, state, tools): ...@@ -110,15 +122,33 @@ def create_function(func_def, state, tools):
def new_func(*args, **kwargs): def new_func(*args, **kwargs):
func_state = state.copy() func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args] arg_names = [arg.arg for arg in func_def.args.args]
default_values = [evaluate_ast(d, state, tools) for d in func_def.args.defaults]
# Apply default values
defaults = dict(zip(arg_names[-len(default_values) :], default_values))
# Set positional arguments
for name, value in zip(arg_names, args): for name, value in zip(arg_names, args):
func_state[name] = value func_state[name] = value
# # Set keyword arguments
for name, value in kwargs.items():
func_state[name] = value
# Handle variable arguments
if func_def.args.vararg: if func_def.args.vararg:
vararg_name = func_def.args.vararg.arg vararg_name = func_def.args.vararg.arg
func_state[vararg_name] = args func_state[vararg_name] = args
if func_def.args.kwarg: if func_def.args.kwarg:
kwarg_name = func_def.args.kwarg.arg kwarg_name = func_def.args.kwarg.arg
func_state[kwarg_name] = kwargs func_state[kwarg_name] = kwargs
# Set default values for arguments that were not provided
for name, value in defaults.items():
if name not in func_state:
func_state[name] = value
# Update function state with self and __class__ # Update function state with self and __class__
if func_def.args.args and func_def.args.args[0].arg == "self": if func_def.args.args and func_def.args.args[0].arg == "self":
if args: if args:
...@@ -126,8 +156,11 @@ def create_function(func_def, state, tools): ...@@ -126,8 +156,11 @@ def create_function(func_def, state, tools):
func_state["__class__"] = args[0].__class__ func_state["__class__"] = args[0].__class__
result = None result = None
try:
for stmt in func_def.body: for stmt in func_def.body:
result = evaluate_ast(stmt, func_state, tools) result = evaluate_ast(stmt, func_state, tools)
except ReturnException as e:
result = e.value
return result return result
return new_func return new_func
...@@ -155,9 +188,12 @@ def evaluate_class_def(class_def, state, tools): ...@@ -155,9 +188,12 @@ def evaluate_class_def(class_def, state, tools):
class_dict[stmt.name] = evaluate_function_def(stmt, state, tools) class_dict[stmt.name] = evaluate_function_def(stmt, state, tools)
elif isinstance(stmt, ast.Assign): elif isinstance(stmt, ast.Assign):
for target in stmt.targets: for target in stmt.targets:
if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(stmt.value, state, tools) class_dict[target.id] = evaluate_ast(stmt.value, state, tools)
elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(stmt.value, state, tools)
else: else:
raise InterpretorError(f"Unsupported statement in class body: {stmt.__class__.__name__}") raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")
new_class = type(class_name, tuple(bases), class_dict) new_class = type(class_name, tuple(bases), class_dict)
state[class_name] = new_class state[class_name] = new_class
...@@ -165,14 +201,34 @@ def evaluate_class_def(class_def, state, tools): ...@@ -165,14 +201,34 @@ def evaluate_class_def(class_def, state, tools):
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]): def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
# Extract the target variable name and the operation # Helper function to get current value and set new value based on the target type
def get_current_value(target):
    """Return the value currently held by an augmented-assignment target.

    Nested helper of `evaluate_augassign`: `state` and `tools` come from the
    enclosing call.

    Raises:
        InterpreterError: if the target node type is not handled.
    """
    if isinstance(target, ast.Name):
        # A name missing from the state is treated as 0 instead of raising.
        return state.get(target.id, 0)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, tools)
        key = evaluate_ast(target.slice, state, tools)
        return obj[key]
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, tools)
        return getattr(obj, target.attr)
    elif isinstance(target, ast.Tuple):
        return tuple(get_current_value(elt) for elt in target.elts)
    elif isinstance(target, ast.List):
        return [get_current_value(elt) for elt in target.elts]
    else:
        # The f-prefix is required here: without it the message contained the
        # literal text "{type(target)}" instead of the offending node type.
        raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")
current_value = get_current_value(expression.target)
value_to_add = evaluate_ast(expression.value, state, tools) value_to_add = evaluate_ast(expression.value, state, tools)
# Determine the operation and apply it # Determine the operation and apply it
if isinstance(expression.op, ast.Add): if isinstance(expression.op, ast.Add):
if isinstance(current_value, list):
if not isinstance(value_to_add, list):
raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
updated_value = current_value + value_to_add
else:
updated_value = current_value + value_to_add updated_value = current_value + value_to_add
elif isinstance(expression.op, ast.Sub): elif isinstance(expression.op, ast.Sub):
updated_value = current_value - value_to_add updated_value = current_value - value_to_add
...@@ -180,22 +236,42 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: ...@@ -180,22 +236,42 @@ def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools:
updated_value = current_value * value_to_add updated_value = current_value * value_to_add
elif isinstance(expression.op, ast.Div): elif isinstance(expression.op, ast.Div):
updated_value = current_value / value_to_add updated_value = current_value / value_to_add
# Add other operations as needed elif isinstance(expression.op, ast.Mod):
updated_value = current_value % value_to_add
elif isinstance(expression.op, ast.Pow):
updated_value = current_value**value_to_add
elif isinstance(expression.op, ast.FloorDiv):
updated_value = current_value // value_to_add
elif isinstance(expression.op, ast.BitAnd):
updated_value = current_value & value_to_add
elif isinstance(expression.op, ast.BitOr):
updated_value = current_value | value_to_add
elif isinstance(expression.op, ast.BitXor):
updated_value = current_value ^ value_to_add
elif isinstance(expression.op, ast.LShift):
updated_value = current_value << value_to_add
elif isinstance(expression.op, ast.RShift):
updated_value = current_value >> value_to_add
else:
raise InterpreterError(f"Operation {type(expression.op).__name__} is not supported.")
# Update the state # Update the state
state[var_name] = updated_value set_value(expression.target, updated_value, state, tools)
return updated_value return updated_value
else:
raise InterpretorError("AugAssign not supported for non-simple variable targets.")
def evaluate_boolop(node, state, tools):
    """Evaluate an `and` / `or` expression with short-circuiting.

    Operands are evaluated left to right and evaluation stops at the first
    operand that decides the outcome (first falsy one for `and`, first truthy
    one for `or`). As in Python, the result is that operand itself (or the last
    operand if none short-circuits), not a coerced boolean, so idioms such as
    `value = maybe_empty or default` yield `default` rather than `True`.

    Args:
        node: the `ast.BoolOp` node to evaluate.
        state: mapping of variable names to values.
        tools: callables available to the evaluated code.

    Returns:
        The operand that determined the result.
    """
    stop_on_truthy = isinstance(node.op, ast.Or)
    result = None
    for operand in node.values:
        result = evaluate_ast(operand, state, tools)
        # `or` stops on the first truthy operand, `and` on the first falsy one.
        if bool(result) == stop_on_truthy:
            return result
    return result
def evaluate_binop(binop, state, tools): def evaluate_binop(binop, state, tools):
...@@ -233,41 +309,49 @@ def evaluate_binop(binop, state, tools): ...@@ -233,41 +309,49 @@ def evaluate_binop(binop, state, tools):
def evaluate_assign(assign, state, tools):
    """Evaluate an assignment statement and return the assigned value.

    The right-hand side is evaluated once and bound to every target. An
    `ast.Assign` node only has several targets for chained assignments
    (`a = b = value`); tuple unpacking (`a, b = value`) is a single `ast.Tuple`
    target and is handled by `set_value`. The value must therefore not be split
    across the targets.

    Args:
        assign: the `ast.Assign` node to evaluate.
        state: mapping of variable names to values, updated in place.
        tools: callables available to the evaluated code.

    Returns:
        The evaluated right-hand side.
    """
    result = evaluate_ast(assign.value, state, tools)
    for target in assign.targets:
        set_value(target, result, state, tools)
    return result
def set_value(target, value, state, tools):
    """Bind an already-evaluated `value` to an assignment `target` node.

    Args:
        target: left-hand-side AST node. Names, tuple/list patterns (possibly
            nested), subscripts and attributes are supported.
        value: the value to bind.
        state: mapping of variable names to values, updated in place.
        tools: callables available to the evaluated code; their names cannot be
            reassigned.

    Raises:
        InterpreterError: when the name belongs to a tool, when the value cannot
            be unpacked into a tuple/list pattern, or when the target node type
            is not supported.
    """
    if isinstance(target, ast.Name):
        if target.id in tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            # Python unpacks any iterable (lists, strings, generators, ...),
            # not only tuples.
            try:
                value = tuple(value)
            except TypeError:
                raise InterpreterError("Cannot unpack non-tuple value")
        if len(target.elts) != len(value):
            raise InterpreterError("Cannot unpack tuple of wrong size")
        for elem, item in zip(target.elts, value):
            set_value(elem, item, state, tools)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, tools)
        key = evaluate_ast(target.slice, state, tools)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, tools)
        setattr(obj, target.attr, value)
    else:
        # Fail loudly instead of silently dropping the assignment.
        raise InterpreterError(f"Assignment to {type(target).__name__} targets is not supported.")
else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
for var_name, r in zip(var_names, result):
state[var_name.id] = r
return result
def evaluate_call(call, state, tools): def evaluate_call(call, state, tools):
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)): if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpretorError( raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})." f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
) )
if isinstance(call.func, ast.Attribute): if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, tools) obj = evaluate_ast(call.func.value, state, tools)
func_name = call.func.attr func_name = call.func.attr
if not hasattr(obj, func_name): if not hasattr(obj, func_name):
raise InterpretorError(f"Object {obj} has no attribute {func_name}") raise InterpreterError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name) func = getattr(obj, func_name)
elif isinstance(call.func, ast.Name): elif isinstance(call.func, ast.Name):
func_name = call.func.id func_name = call.func.id
...@@ -278,7 +362,7 @@ def evaluate_call(call, state, tools): ...@@ -278,7 +362,7 @@ def evaluate_call(call, state, tools):
elif func_name in ERRORS: elif func_name in ERRORS:
func = ERRORS[func_name] func = ERRORS[func_name]
else: else:
raise InterpretorError( raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})." f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
) )
...@@ -297,22 +381,22 @@ def evaluate_call(call, state, tools): ...@@ -297,22 +381,22 @@ def evaluate_call(call, state, tools):
if "__class__" in state and "self" in state: if "__class__" in state and "self" in state:
return super(state["__class__"], state["self"]) return super(state["__class__"], state["self"])
else: else:
raise InterpretorError("super() needs at least one argument") raise InterpreterError("super() needs at least one argument")
cls = args[0] cls = args[0]
if not isinstance(cls, type): if not isinstance(cls, type):
raise InterpretorError("super() argument 1 must be type") raise InterpreterError("super() argument 1 must be type")
if len(args) == 1: if len(args) == 1:
return super(cls) return super(cls)
elif len(args) == 2: elif len(args) == 2:
instance = args[1] instance = args[1]
return super(cls, instance) return super(cls, instance)
else: else:
raise InterpretorError("super() takes at most 2 arguments") raise InterpreterError("super() takes at most 2 arguments")
else: else:
if func_name == "print": if func_name == "print":
output = " ".join(map(str, args)) output = " ".join(map(str, args))
state["print_outputs"] += output + "\n" global PRINT_OUTPUTS
PRINT_OUTPUTS += output + "\n"
return output return output
else: # Assume it's a callable object else: # Assume it's a callable object
output = func(*args, **kwargs) output = func(*args, **kwargs)
...@@ -325,8 +409,14 @@ def evaluate_subscript(subscript, state, tools): ...@@ -325,8 +409,14 @@ def evaluate_subscript(subscript, state, tools):
if isinstance(index, slice): if isinstance(index, slice):
return value[index] return value[index]
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
return value[int(index)] return value[int(index)]
elif isinstance(value, str): elif isinstance(value, str):
# Ensure the index is within bounds
if not (-len(value) <= index < len(value)):
raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
return value[index] return value[index]
elif index in value: elif index in value:
return value[index] return value[index]
...@@ -334,7 +424,7 @@ def evaluate_subscript(subscript, state, tools): ...@@ -334,7 +424,7 @@ def evaluate_subscript(subscript, state, tools):
close_matches = difflib.get_close_matches(index, list(value.keys())) close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0: if len(close_matches) > 0:
return value[close_matches[0]] return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.") raise InterpreterError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools): def evaluate_name(name, state, tools):
...@@ -347,7 +437,7 @@ def evaluate_name(name, state, tools): ...@@ -347,7 +437,7 @@ def evaluate_name(name, state, tools):
close_matches = difflib.get_close_matches(name.id, list(state.keys())) close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0: if len(close_matches) > 0:
return state[close_matches[0]] return state[close_matches[0]]
raise InterpretorError(f"The variable `{name.id}` is not defined.") raise InterpreterError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools): def evaluate_condition(condition, state, tools):
...@@ -355,30 +445,36 @@ def evaluate_condition(condition, state, tools): ...@@ -355,30 +445,36 @@ def evaluate_condition(condition, state, tools):
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators] comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
ops = [type(op) for op in condition.ops] ops = [type(op) for op in condition.ops]
result = left result = True
current_left = left
for op, comparator in zip(ops, comparators): for op, comparator in zip(ops, comparators):
if op == ast.Eq: if op == ast.Eq:
result = result == comparator result = result and (current_left == comparator)
elif op == ast.NotEq: elif op == ast.NotEq:
result = result != comparator result = result and (current_left != comparator)
elif op == ast.Lt: elif op == ast.Lt:
result = result < comparator result = result and (current_left < comparator)
elif op == ast.LtE: elif op == ast.LtE:
result = result <= comparator result = result and (current_left <= comparator)
elif op == ast.Gt: elif op == ast.Gt:
result = result > comparator result = result and (current_left > comparator)
elif op == ast.GtE: elif op == ast.GtE:
result = result >= comparator result = result and (current_left >= comparator)
elif op == ast.Is: elif op == ast.Is:
result = result is comparator result = result and (current_left is comparator)
elif op == ast.IsNot: elif op == ast.IsNot:
result = result is not comparator result = result and (current_left is not comparator)
elif op == ast.In: elif op == ast.In:
result = result in comparator result = result and (current_left in comparator)
elif op == ast.NotIn: elif op == ast.NotIn:
result = result not in comparator result = result and (current_left not in comparator)
else: else:
raise InterpretorError(f"Operator not supported: {op}") raise InterpreterError(f"Operator not supported: {op}")
current_left = comparator
if not result:
break
return result return result
...@@ -425,15 +521,17 @@ def evaluate_for(for_loop, state, tools): ...@@ -425,15 +521,17 @@ def evaluate_for(for_loop, state, tools):
def evaluate_listcomp(listcomp, state, tools):
    """Evaluate a list comprehension and return the resulting list.

    Successive `for` clauses are nested, as in Python: each clause is evaluated
    inside the scope of the previous ones, so `[x * y for x in a for y in b]`
    iterates `b` once per element of `a` and later clauses can use earlier loop
    variables. Loop variables live in copies of `state`, so they do not leak
    into the caller's state.

    Args:
        listcomp: the `ast.ListComp` node to evaluate.
        state: mapping of variable names to values (not modified).
        tools: callables available to the evaluated code.
    """

    def expand(index, current_state):
        # Past the last clause: every loop variable is bound, emit one element.
        if index == len(listcomp.generators):
            return [evaluate_ast(listcomp.elt, current_state, tools)]

        generator = listcomp.generators[index]
        collected = []
        for value in evaluate_ast(generator.iter, current_state, tools):
            new_state = current_state.copy()
            if isinstance(generator.target, ast.Tuple):
                for idx, elem in enumerate(generator.target.elts):
                    new_state[elem.id] = value[idx]
            else:
                new_state[generator.target.id] = value
            if all(evaluate_ast(if_clause, new_state, tools) for if_clause in generator.ifs):
                collected.extend(expand(index + 1, new_state))
        return collected

    return expand(0, state)
...@@ -478,7 +576,42 @@ def evaluate_raise(raise_node, state, tools): ...@@ -478,7 +576,42 @@ def evaluate_raise(raise_node, state, tools):
else: else:
raise exc raise exc
else: else:
raise InterpretorError("Re-raise is not supported without an active exception") raise InterpreterError("Re-raise is not supported without an active exception")
def evaluate_assert(assert_node, state, tools):
    """Evaluate an `assert` statement.

    Returns None when the condition is truthy.

    Raises:
        AssertionError: when the condition is falsy. The message is the evaluated
            `msg` expression if one was given, otherwise the source of the
            failing condition.
    """
    if evaluate_ast(assert_node.test, state, tools):
        return
    if assert_node.msg:
        raise AssertionError(evaluate_ast(assert_node.msg, state, tools))
    # No explicit message: report the failing condition itself.
    condition_source = ast.unparse(assert_node.test)
    raise AssertionError(f"Assertion failed: {condition_source}")
def evaluate_with(with_node, state, tools):
    """Evaluate a `with` statement.

    Every context manager is entered in order; when an `as` name is given, the
    value returned by `__enter__` is stored under that name in `state`. The body
    is then evaluated and the managers are exited in reverse order.

    `__exit__` is called on the context manager itself, not on the value
    returned by `__enter__`; the two differ for any manager whose `__enter__`
    does not return `self`.

    Args:
        with_node: the `ast.With` node to evaluate.
        state: mapping of variable names to values, updated in place.
        tools: callables available to the evaluated code.

    Raises:
        Exception: any exception raised by the body is re-raised after the
            managers have been exited. The return value of `__exit__` is
            ignored, so exceptions are never suppressed.
    """
    managers = []
    for item in with_node.items:
        manager = evaluate_ast(item.context_expr, state, tools)
        entered = manager.__enter__()
        if item.optional_vars:
            state[item.optional_vars.id] = entered
        managers.append(manager)

    try:
        for stmt in with_node.body:
            evaluate_ast(stmt, state, tools)
    except Exception as e:
        for manager in reversed(managers):
            manager.__exit__(type(e), e, e.__traceback__)
        raise
    else:
        for manager in reversed(managers):
            manager.__exit__(None, None, None)
def evaluate_ast( def evaluate_ast(
...@@ -501,7 +634,7 @@ def evaluate_ast( ...@@ -501,7 +634,7 @@ def evaluate_ast(
encounters assignements. encounters assignements.
tools (`Dict[str, Callable]`): tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`. `InterpreterError`.
authorized_imports (`List[str]`): authorized_imports (`List[str]`):
The list of modules that can be imported by the code. By default, only a few safe modules are allowed. The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
Add more at your own risk! Add more at your own risk!
...@@ -537,8 +670,6 @@ def evaluate_ast( ...@@ -537,8 +670,6 @@ def evaluate_ast(
elif isinstance(expression, ast.Compare): elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison # Comparison -> evaluate the comparison
return evaluate_condition(expression, state, tools) return evaluate_condition(expression, state, tools)
elif isinstance(expression, ast.Return):
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.Lambda): elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, tools) return evaluate_lambda(expression, state, tools)
elif isinstance(expression, ast.FunctionDef): elif isinstance(expression, ast.FunctionDef):
...@@ -615,7 +746,7 @@ def evaluate_ast( ...@@ -615,7 +746,7 @@ def evaluate_ast(
module = __import__(alias.name) module = __import__(alias.name)
state[alias.asname or alias.name] = module state[alias.asname or alias.name] = module
else: else:
raise InterpretorError(f"Import of {alias.name} is not allowed.") raise InterpreterError(f"Import of {alias.name} is not allowed.")
return None return None
elif isinstance(expression, ast.While): elif isinstance(expression, ast.While):
return evaluate_while(expression, state, tools) return evaluate_while(expression, state, tools)
...@@ -625,7 +756,7 @@ def evaluate_ast( ...@@ -625,7 +756,7 @@ def evaluate_ast(
for alias in expression.names: for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name) state[alias.asname or alias.name] = getattr(module, alias.name)
else: else:
raise InterpretorError(f"Import from {expression.module} is not allowed.") raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None return None
elif isinstance(expression, ast.ClassDef): elif isinstance(expression, ast.ClassDef):
return evaluate_class_def(expression, state, tools) return evaluate_class_def(expression, state, tools)
...@@ -633,9 +764,17 @@ def evaluate_ast( ...@@ -633,9 +764,17 @@ def evaluate_ast(
return evaluate_try(expression, state, tools) return evaluate_try(expression, state, tools)
elif isinstance(expression, ast.Raise): elif isinstance(expression, ast.Raise):
return evaluate_raise(expression, state, tools) return evaluate_raise(expression, state, tools)
elif isinstance(expression, ast.Assert):
return evaluate_assert(expression, state, tools)
elif isinstance(expression, ast.With):
return evaluate_with(expression, state, tools)
elif isinstance(expression, ast.Set):
return {evaluate_ast(elt, state, tools) for elt in expression.elts}
elif isinstance(expression, ast.Return):
raise ReturnException(evaluate_ast(expression.value, state, tools) if expression.value else None)
else: else:
# For now we refuse anything else. Let's add things as we need them. # For now we refuse anything else. Let's add things as we need them.
raise InterpretorError(f"{expression.__class__.__name__} is not supported.") raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
def evaluate_python_code( def evaluate_python_code(
...@@ -652,7 +791,7 @@ def evaluate_python_code( ...@@ -652,7 +791,7 @@ def evaluate_python_code(
The code to evaluate. The code to evaluate.
tools (`Dict[str, Callable]`): tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`. `InterpreterError`.
state (`Dict[str, Any]`): state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated. updated by this function to contain all variables as they are evaluated.
...@@ -665,17 +804,17 @@ def evaluate_python_code( ...@@ -665,17 +804,17 @@ def evaluate_python_code(
if state is None: if state is None:
state = {} state = {}
result = None result = None
state["print_outputs"] = "" global PRINT_OUTPUTS
PRINT_OUTPUTS = ""
for idx, node in enumerate(expression.body): for node in expression.body:
try: try:
line_result = evaluate_ast(node, state, tools, authorized_imports) result = evaluate_ast(node, state, tools, authorized_imports)
except InterpretorError as e: except InterpreterError as e:
msg = f"You tried to execute the following code:\n{code}\n" msg = f"Evaluation stopped at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
msg += f"You got these outputs:\n{state['print_outputs']}\n" if len(PRINT_OUTPUTS) > 0:
msg += f"Evaluation stopped at line '{node}' because of the following error:\n{e}" msg += f"Executing code yielded these outputs:\n{PRINT_OUTPUTS}\n====\n"
raise InterpretorError(msg) raise InterpreterError(msg)
if line_result is not None: finally:
result = line_result state["print_outputs"] = PRINT_OUTPUTS
return result return result
...@@ -74,6 +74,26 @@ final_answer(7.2904) ...@@ -74,6 +74,26 @@ final_answer(7.2904)
""" """
def fake_react_code_llm_error(messages, stop_sequences=None) -> str:
    """Fake LLM engine whose first answer contains code that assigns to `print`.

    The function is stateless: it decides which step it is at by looking for
    the string "special_marker" in the stringified conversation. The first
    answer embeds that marker, so once the answer has been appended to the
    messages, the next call returns the final answer instead.

    Args:
        messages: Conversation so far; only its `str()` form is inspected.
        stop_sequences: Accepted for interface compatibility and ignored.

    Returns:
        The raw text of the fake LLM completion (thought + code block).
    """
    prompt = str(messages)
    if "special_marker" in prompt:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```<end_code>
"""
    # Step 1: emit code that rebinds `print` (the accompanying test expects
    # the interpreter to report an error for this line).
    return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
print = 2
```<end_code>
"""
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str: def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
return """ return """
Thought: I should multiply 2 by 3.6452. special_marker Thought: I should multiply 2 by 3.6452. special_marker
...@@ -124,6 +144,13 @@ Action: ...@@ -124,6 +144,13 @@ Action:
"tool_name": "code interpreter", "tool_name": "code interpreter",
} }
def test_react_code_agent_code_errors_show_offending_lines(self):
    """The agent logs must quote the line of generated code that failed."""
    interpreter_tool = PythonInterpreterTool()
    agent = ReactCodeAgent(tools=[interpreter_tool], llm_engine=fake_react_code_llm_error)
    answer = agent.run("What is 2 multiplied by 3.6452?")
    # The run still completes: the fake engine answers on its second step.
    assert isinstance(answer, AgentText)
    assert answer == "got an error"
    logged = str(agent.logs)
    assert "Evaluation stopped at line 'print = 2' because of" in logged
def test_setup_agent_with_empty_toolbox(self): def test_setup_agent_with_empty_toolbox(self):
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[]) ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
......
...@@ -20,7 +20,7 @@ import pytest ...@@ -20,7 +20,7 @@ import pytest
from transformers import load_tool from transformers import load_tool
from transformers.agents.agent_types import AGENT_TYPE_MAPPING from transformers.agents.agent_types import AGENT_TYPE_MAPPING
from transformers.agents.default_tools import BASE_PYTHON_TOOLS from transformers.agents.default_tools import BASE_PYTHON_TOOLS
from transformers.agents.python_interpreter import InterpretorError, evaluate_python_code from transformers.agents.python_interpreter import InterpreterError, evaluate_python_code
from .test_tools_common import ToolTesterMixin from .test_tools_common import ToolTesterMixin
...@@ -35,16 +35,6 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin): ...@@ -35,16 +35,6 @@ class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"]) self.tool = load_tool("python_interpreter", authorized_imports=["sqlite3"])
self.tool.setup() self.tool.setup()
def test_exact_match_input_spec(self):
    """Check the `code` input description without depending on module order.

    The tool builds its list of authorized imports through a `set`, so the
    order in which module names are rendered is not stable between runs.
    Instead of comparing against one hard-coded ordering, verify the fixed
    prefix, the presence of every expected module, and the closing suffix.
    """
    description = self.tool.inputs["code"]["description"]
    expected_prefix = (
        "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
        "else you will get an error. This code can only import the following python libraries: "
    )
    expected_modules = [
        "math",
        "statistics",
        "time",
        "itertools",
        "stat",
        "unicodedata",
        "sqlite3",
        "queue",
        "collections",
        "random",
        "re",
    ]
    self.assertTrue(description.startswith(expected_prefix), description)
    for module in expected_modules:
        # Quotes are part of the needle so that 'stat' does not match 'statistics'.
        self.assertIn(f"'{module}'", description)
    self.assertTrue(description.endswith("]."), description)
def test_exact_match_arg(self): def test_exact_match_arg(self):
result = self.tool("(2 / 2) * 4") result = self.tool("(2 / 2) * 4")
self.assertEqual(result, "4.0") self.assertEqual(result, "4.0")
...@@ -91,6 +81,17 @@ class PythonInterpreterTester(unittest.TestCase): ...@@ -91,6 +81,17 @@ class PythonInterpreterTester(unittest.TestCase):
assert result == 5 assert result == 5
self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 5, "y": 5, "print_outputs": ""})
code = "a=1;b=None"
result = evaluate_python_code(code, {}, state={})
# evaluate returns the value of the last assignment.
assert result is None
def test_assignment_cannot_overwrite_tool(self):
    """Assigning to a name that is bound to a tool must raise `InterpreterError`."""
    code = "print = '3'"
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, {"print": print}, state={})
    # Check the exception message itself (`e.value`): `str(e)` renders the
    # pytest ExceptionInfo wrapper, not the raised error.
    assert "Cannot assign to name 'print': doing this would erase the existing tool!" in str(e.value)
def test_evaluate_call(self): def test_evaluate_call(self):
code = "y = add_two(x)" code = "y = add_two(x)"
state = {"x": 3} state = {"x": 3}
...@@ -99,7 +100,7 @@ class PythonInterpreterTester(unittest.TestCase): ...@@ -99,7 +100,7 @@ class PythonInterpreterTester(unittest.TestCase):
self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""}) self.assertDictEqual(state, {"x": 3, "y": 5, "print_outputs": ""})
# Should not work without the tool # Should not work without the tool
with pytest.raises(InterpretorError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, {}, state=state) evaluate_python_code(code, {}, state=state)
assert "tried to execute add_two" in str(e.value) assert "tried to execute add_two" in str(e.value)
...@@ -237,6 +238,12 @@ for block in text_block: ...@@ -237,6 +238,12 @@ for block in text_block:
result = evaluate_python_code(code, {}, state={}) result = evaluate_python_code(code, {}, state={})
assert result == 2 assert result == 2
code = """
digits, i = [1, 2, 3], 1
digits[i], digits[i + 1] = digits[i + 1], digits[i]"""
state = {}
evaluate_python_code(code, {"range": range, "print": print, "int": int}, state)
def test_listcomp(self): def test_listcomp(self):
code = "x = [i for i in range(3)]" code = "x = [i for i in range(3)]"
result = evaluate_python_code(code, {"range": range}, state={}) result = evaluate_python_code(code, {"range": range}, state={})
...@@ -278,10 +285,20 @@ for block in text_block: ...@@ -278,10 +285,20 @@ for block in text_block:
# test infinite loop # test infinite loop
code = "i = 0\nwhile i < 3:\n i -= 1\ni" code = "i = 0\nwhile i < 3:\n i -= 1\ni"
with pytest.raises(InterpretorError) as e: with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "iterations in While loop exceeded" in str(e) assert "iterations in While loop exceeded" in str(e)
# test lazy evaluation
code = """
house_positions = [0, 7, 10, 15, 18, 22, 22]
i, n, loc = 0, 7, 30
while i < n and house_positions[i] <= loc:
i += 1
"""
state = {}
evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
def test_generator(self): def test_generator(self):
code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)" code = "a = [1, 2, 3, 4, 5]; b = (i**2 for i in a); list(b)"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
...@@ -353,7 +370,19 @@ if char.isalpha(): ...@@ -353,7 +370,19 @@ if char.isalpha():
assert result == "LATIN CAPITAL LETTER A" assert result == "LATIN CAPITAL LETTER A"
def test_multiple_comparators(self): def test_multiple_comparators(self):
code = "0x30A0 <= ord('a') <= 0x30FF" code = "0 <= -1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 1 < 4 and 0 <= -5 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 4 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert not result
code = "0 <= 3 < 4 and 0 <= 3 < 4"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert result assert result
...@@ -364,6 +393,16 @@ if char.isalpha(): ...@@ -364,6 +393,16 @@ if char.isalpha():
assert result == "Ok no one cares" assert result == "Ok no one cares"
assert state["print_outputs"] == "Hello world!\nOk no one cares\n" assert state["print_outputs"] == "Hello world!\nOk no one cares\n"
# test print in function
code = """
print("1")
def function():
print("2")
function()"""
state = {}
evaluate_python_code(code, {"print": print}, state)
assert state["print_outputs"] == "1\n2\n"
def test_tuple_target_in_iterator(self): def test_tuple_target_in_iterator(self):
code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]" code = "for a, b in [('Ralf Weikert', 'Austria'), ('Samuel Seungwon Lee', 'South Korea')]:res = a.split()[0]"
result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={}) result = evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
...@@ -491,3 +530,147 @@ except ValueError as e: ...@@ -491,3 +530,147 @@ except ValueError as e:
state = {} state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state) result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result == int assert result == int
def test_tuple_id(self):
    """A tuple target (`item, count`) can be unpacked inside a list comprehension."""
    # The comprehension iterates over the dict defined on the previous line;
    # the snippet previously referred to it under a different, undefined name.
    code = """
food_items = {"apple": 2, "banana": 3, "orange": 1, "pear": 1}
unique_food_items = [item for item, count in food_items.items() if count == 1]
"""
    state = {}
    result = evaluate_python_code(code, {}, state=state)
    assert result == ["orange", "pear"]
def test_nonsimple_augassign(self):
    """Augmented assignment works on subscript, list and attribute targets."""
    # NOTE(review): `self.count = 0` directly in a class body is not valid in
    # regular Python; it is kept as-is on the assumption that the interpreter
    # under test accepts it — confirm this is intentional.
    code = """
counts_dict = {'a': 0}
counts_dict['a'] += 1
counts_list = [1, 2, 3]
counts_list += [4, 5, 6]
class Counter:
    self.count = 0
a = Counter()
a.count += 1
"""
    state = {}
    evaluate_python_code(code, {}, state=state)
    assert state["counts_dict"] == {"a": 1}
    assert state["counts_list"] == [1, 2, 3, 4, 5, 6]
    assert state["a"].count == 1
def test_adding_int_to_list_raises_error(self):
    """`list += int` in evaluated code must raise an `InterpreterError`."""
    code = """
counts = [1, 2, 3]
counts += 1"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # Match against the raised exception's message rather than the
    # string form of the pytest ExceptionInfo wrapper.
    assert "Cannot add non-list value 1 to a list." in str(e.value)
def test_error_highlights_correct_line_of_code(self):
    """The error message must quote the failing statement, not a neighbouring line."""
    code = """# Ok this is a very long code
# It has many commented lines
a = 1
b = 2
# Here is another piece
counts = [1, 2, 3]
counts += 1
b += 1"""
    with pytest.raises(InterpreterError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # Use the exception message itself; `str(e)` is the ExceptionInfo wrapper.
    assert "Evaluation stopped at line 'counts += 1" in str(e.value)
def test_assert(self):
    """A failing `assert` in evaluated code raises `AssertionError` naming the failed condition."""
    code = """
assert 1 == 1
assert 1 == 2
"""
    with pytest.raises(AssertionError) as e:
        evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
    # Inspect only the exception message (`e.value`), so the negative check
    # is not affected by whatever else the ExceptionInfo wrapper renders.
    message = str(e.value)
    assert "1 == 2" in message and "1 == 1" not in message
def test_with_context_manager(self):
    """`with ... as ...` calls `__enter__` on entry and `__exit__` on exit.

    The checks live inside the evaluated snippet itself: the first `assert`
    runs inside the `with` body, the second one after it.
    """
    code = """
class SimpleLock:
    def __init__(self):
        self.locked = False
    def __enter__(self):
        self.locked = True
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        self.locked = False
lock = SimpleLock()
with lock as l:
    assert l.locked == True
assert lock.locked == False
"""
    state = {}
    tools = {}
    evaluate_python_code(code, tools, state)
def test_default_arg_in_function(self):
    """Defaults apply to omitted parameters and keyword arguments override them."""
    # b keeps its default (333) while n is overridden to 667, so f returns 1000.
    code = """
def f(a, b=333, n=1000):
    return b + n
n = f(1, n=667)
"""
    res = evaluate_python_code(code, {}, {})
    assert res == 1000
def test_set(self):
    """Set literals and the `difference` / `intersection` methods evaluate correctly."""
    snippet = """
S1 = {'a', 'b', 'c'}
S2 = {'b', 'c', 'd'}
S3 = S1.difference(S2)
S4 = S1.intersection(S2)
"""
    variables = {}
    evaluate_python_code(snippet, {}, state=variables)
    expected = {"S3": {"a"}, "S4": {"b", "c"}}
    for name, value in expected.items():
        assert variables[name] == value
def test_break(self):
    """`break` leaves a `while True` loop as soon as the counter reaches 3."""
    code = """
i = 0
while True:
    i+= 1
    if i==3:
        break
i"""
    result = evaluate_python_code(code, {"print": print, "round": round}, state={})
    assert result == 3
def test_return(self):
    """Functions defined in evaluated code honour early and bare `return` statements."""
    # test early returns: the first `return` must win over the trailing one
    code = """
def add_one(n, shift):
    if True:
        return n + shift
    return n
add_one(1, 1)
"""
    state = {}
    result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
    assert result == 2

    # test returning None: a bare `return` yields None
    code = """
def returns_none(a):
    return
returns_none(1)
"""
    state = {}
    result = evaluate_python_code(code, {"print": print, "range": range, "ord": ord, "chr": chr}, state=state)
    assert result is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment