"docs/source/pt/index.md" did not exist on "1ac698744c4dbdf1495d303246d08ffacdf4f5b8"
Unverified commit 0ba15ced, authored by Aymeric Roucher and committed by GitHub

Reboot Agents (#30387)



* Create CodeAgent and ReactAgent

* Fix formatting errors

* Update documentation for agents

* Add custom errors, improve logging

* Support variable usage in ReactAgent

* add messages

* Add message passing format

* Create React Code Agent

* Update

* Refactoring

* Fix errors

* Improve python interpreter

* Only non-tensor inputs should be sent to device

* Calculator tool slight refactor

* Improve docstrings

* Refactor

* Fix tests

* Fix more tests

* Fix even more tests

* Fix tests by replacing output and input types

* Fix operand type issue

* two small fixes

* EM TTS

* Fix agent running type errors

* Change text to speech tests to allow changed outputs

* Update doc with new agent types

* Improve code interpreter

* If max iterations reached, provide a real answer instead of an error

* Add edge case in interpreter

* Add safe imports to the interpreter

* Interpreter tweaks: tuples and listcomp

* Make style

* Make quality

* Add dictcomp to interpreter

* Rename ReactJSONAgent to ReactJsonAgent

* Misc changes

* ToolCollection

* Rename agent's logger to self.logger

* Add while loops to interpreter

* Update doc with new tools. still need to mention collections

* Add collections to the doc

* Small fixes on logs and interpretor

* Fix toolbox return type

* Docs + fixup

* Skip doctests

* Correct prompts with improved examples and formatting

* Update prompt

* Remove outdated docs

* Change agent to accept Toolbox object for tools

* Remove calculator tool

* Propagate removal of calculator in doc

* Fix 2 failing workflows

* Simplify additional argument passing

* AgentType audio

* Minor changes: function name, types

* Remove calculator tests

* Fix test

* Fix torch requirement

* Fix final answer tests

* Style fixes

* Fix tests

* Update docstrings with calculator removal

* Small type hint fixes

* Update tests/agents/test_translation.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/agents/test_python_interpreter.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/default_tools.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/tools.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/agents/test_agents.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/bert/configuration_bert.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/tools.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/speech_to_text.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/agents/test_speech_to_text.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update tests/agents/test_tools_common.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* pygments

* Answer comments

* Cleaning up

* Simplifying init for all agents

* Improving prompts and making code nicer

* Style fixes

* Add multiple comparator test in interpreter

* Style fixes

* Improve BERT example in documentation

* Add examples to doc

* Fix python interpreter quality

* Logging improvements

* Change test flag to agents

* Quality fix

* Add example for HfEngine

* Improve conversation example for HfEngine

* typo fix

* Verify doc

* Update docs/source/en/agents.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/agents.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/prompts.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/agents/python_interpreter.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/agents.md
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix style issues

* local s2t tool

---------
Co-authored-by: Cyril Kondratenko <kkn1993@gmail.com>
Co-authored-by: Lysandre <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
import ast
import difflib
from collections.abc import Mapping
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional
class InterpretorError(ValueError):
@@ -29,55 +29,317 @@ class InterpretorError(ValueError):
pass
LIST_SAFE_MODULES = ["random", "math", "time", "queue", "itertools", "re", "stat", "statistics", "unicodedata"]


def evaluate(code: str, tools: Dict[str, Callable], state=None, chat_mode=False):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
of functions.
This function will recurse through the nodes of the tree provided.
Args:
code (`str`):
The code to evaluate.
tools (`Dict[str, Callable]`):
The functions that may be called during the evaluation. Any call to another function will fail with an
`InterpretorError`.
state (`Dict[str, Any]`):
A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
updated by this function to contain all variables as they are evaluated.
chat_mode (`bool`, *optional*, defaults to `False`):
Whether or not the function is called from `Agent.chat`.
"""
try:
expression = ast.parse(code)
except SyntaxError as e:
print("The code generated by the agent is not valid.\n", e)
return
if state is None:
state = {}
class BreakException(Exception):
pass
class ContinueException(Exception):
pass
def get_iterable(obj):
if isinstance(obj, list):
return obj
elif hasattr(obj, "__iter__"):
return list(obj)
else:
raise InterpretorError("Object is not iterable")
def evaluate_unaryop(expression, state, tools):
operand = evaluate_ast(expression.operand, state, tools)
if isinstance(expression.op, ast.USub):
return -operand
elif isinstance(expression.op, ast.UAdd):
return operand
elif isinstance(expression.op, ast.Not):
return not operand
elif isinstance(expression.op, ast.Invert):
return ~operand
else:
raise InterpretorError(f"Unary operation {expression.op.__class__.__name__} is not supported.")
def evaluate_lambda(lambda_expression, state, tools):
args = [arg.arg for arg in lambda_expression.args.args]
def lambda_func(*values):
new_state = state.copy()
for arg, value in zip(args, values):
new_state[arg] = value
return evaluate_ast(lambda_expression.body, new_state, tools)
return lambda_func
def evaluate_while(while_loop, state, tools):
max_iterations = 1000
iterations = 0
while evaluate_ast(while_loop.test, state, tools):
for node in while_loop.body:
evaluate_ast(node, state, tools)
iterations += 1
if iterations > max_iterations:
raise InterpretorError(f"Maximum number of {max_iterations} iterations in While loop exceeded")
return None
def evaluate_function_def(function_def, state, tools):
def create_function(func_def, state, tools):
def new_func(*args):
new_state = state.copy()
for arg, val in zip(func_def.args.args, args):
new_state[arg.arg] = val
result = None
for node in func_def.body:
result = evaluate_ast(node, new_state, tools)
return result
return new_func
tools[function_def.name] = create_function(function_def, state, tools)
return None
def evaluate_augassign(expression: ast.AugAssign, state: Dict[str, Any], tools: Dict[str, Callable]):
# Extract the target variable name and the operation
if isinstance(expression.target, ast.Name):
var_name = expression.target.id
current_value = state.get(var_name, 0) # Assuming default of 0 if not in state
value_to_add = evaluate_ast(expression.value, state, tools)
# Determine the operation and apply it
if isinstance(expression.op, ast.Add):
updated_value = current_value + value_to_add
elif isinstance(expression.op, ast.Sub):
updated_value = current_value - value_to_add
elif isinstance(expression.op, ast.Mult):
updated_value = current_value * value_to_add
elif isinstance(expression.op, ast.Div):
updated_value = current_value / value_to_add
# Add other operations as needed
# Update the state
state[var_name] = updated_value
return updated_value
else:
raise InterpretorError("AugAssign not supported for non-simple variable targets.")
def evaluate_boolop(boolop, state, tools):
values = [evaluate_ast(val, state, tools) for val in boolop.values]
op = boolop.op
if isinstance(op, ast.And):
return all(values)
elif isinstance(op, ast.Or):
return any(values)
def evaluate_binop(binop, state, tools):
# Recursively evaluate the left and right operands
left_val = evaluate_ast(binop.left, state, tools)
right_val = evaluate_ast(binop.right, state, tools)
# Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add):
return left_val + right_val
elif isinstance(binop.op, ast.Sub):
return left_val - right_val
elif isinstance(binop.op, ast.Mult):
return left_val * right_val
elif isinstance(binop.op, ast.Div):
return left_val / right_val
elif isinstance(binop.op, ast.Mod):
return left_val % right_val
elif isinstance(binop.op, ast.Pow):
return left_val**right_val
elif isinstance(binop.op, ast.FloorDiv):
return left_val // right_val
elif isinstance(binop.op, ast.BitAnd):
return left_val & right_val
elif isinstance(binop.op, ast.BitOr):
return left_val | right_val
elif isinstance(binop.op, ast.BitXor):
return left_val ^ right_val
elif isinstance(binop.op, ast.LShift):
return left_val << right_val
elif isinstance(binop.op, ast.RShift):
return left_val >> right_val
else:
raise NotImplementedError(f"Binary operation {type(binop.op).__name__} is not implemented.")
def evaluate_assign(assign, state, tools):
var_names = assign.targets
result = evaluate_ast(assign.value, state, tools)
if len(var_names) == 1:
if isinstance(var_names[0], ast.Tuple):
for i, elem in enumerate(var_names[0].elts):
state[elem.id] = result[i]
else:
state[var_names[0].id] = result
else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
for var_name, r in zip(var_names, result):
state[var_name.id] = r
return result
def evaluate_call(call, state, tools):
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(call.func.value, state, tools)
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpretorError(f"Object {obj} has no attribute {func_name}")
func = getattr(obj, func_name)
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
return func(*args, **kwargs)
elif isinstance(call.func, ast.Name):
func_name = call.func.id
if func_name in state:
func = state[func_name]
elif func_name in tools:
func = tools[func_name]
else:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools or imported functions (tried to execute {call.func.id})."
)
# Todo deal with args
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
output = func(*args, **kwargs)
# store logs of print statements
if func_name == "print":
state["print_outputs"] += output + "\n"
return output
else:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func})."
)
def evaluate_subscript(subscript, state, tools):
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
if isinstance(index, slice):
return value[index]
elif isinstance(value, (list, tuple)):
return value[int(index)]
elif isinstance(value, str):
return value[index]
elif index in value:
return value[index]
elif isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0:
return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools):
if name.id in state:
return state[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
raise InterpretorError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools):
left = evaluate_ast(condition.left, state, tools)
comparators = [evaluate_ast(c, state, tools) for c in condition.comparators]
ops = [type(op) for op in condition.ops]
result = left
for op, comparator in zip(ops, comparators):
if op == ast.Eq:
result = result == comparator
elif op == ast.NotEq:
result = result != comparator
elif op == ast.Lt:
result = result < comparator
elif op == ast.LtE:
result = result <= comparator
elif op == ast.Gt:
result = result > comparator
elif op == ast.GtE:
result = result >= comparator
elif op == ast.Is:
result = result is comparator
elif op == ast.IsNot:
result = result is not comparator
elif op == ast.In:
result = result in comparator
elif op == ast.NotIn:
result = result not in comparator
else:
raise InterpretorError(f"Operator not supported: {op}")
return result
def evaluate_if(if_statement, state, tools):
    result = None
    test_result = evaluate_ast(if_statement.test, state, tools)
    if test_result:
        for line in if_statement.body:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    return result


    result = None
    for idx, node in enumerate(expression.body):
        try:
            line_result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            msg = f"Evaluation of the code stopped at line {idx} before the end because of the following error"
            if chat_mode:
                msg += (
                    f". Copy paste the following error message and send it back to the agent:\nI get an error: '{e}'"
                )
            else:
                msg += f":\n{e}"
            print(msg)
            break
        if line_result is not None:
            result = line_result
    return result
def evaluate_for(for_loop, state, tools):
result = None
iterator = evaluate_ast(for_loop.iter, state, tools)
for counter in iterator:
state[for_loop.target.id] = counter
for node in for_loop.body:
try:
line_result = evaluate_ast(node, state, tools)
if line_result is not None:
result = line_result
except BreakException:
break
except ContinueException:
continue
else:
continue
break
return result
def evaluate_listcomp(listcomp, state, tools):
result = []
vars = {}
for generator in listcomp.generators:
var_name = generator.target.id
iter_value = evaluate_ast(generator.iter, state, tools)
for value in iter_value:
vars[var_name] = value
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
elem = evaluate_ast(listcomp.elt, {**state, **vars}, tools)
result.append(elem)
return result
def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Callable]):
"""
Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
set of functions.
This function will recurse through the nodes of the tree provided.
@@ -96,12 +358,39 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
return evaluate_assign(expression, state, tools)
elif isinstance(expression, ast.AugAssign):
return evaluate_augassign(expression, state, tools)
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
return evaluate_call(expression, state, tools)
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(evaluate_ast(elt, state, tools) for elt in expression.elts)
elif isinstance(expression, ast.ListComp):
return evaluate_listcomp(expression, state, tools)
elif isinstance(expression, ast.UnaryOp):
return evaluate_unaryop(expression, state, tools)
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
return evaluate_boolop(expression, state, tools)
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
return evaluate_binop(expression, state, tools)
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
return evaluate_condition(expression, state, tools)
elif isinstance(expression, ast.Return):
return evaluate_ast(expression.value, state, tools)
elif isinstance(expression, ast.Lambda):
return evaluate_lambda(expression, state, tools)
elif isinstance(expression, ast.FunctionDef):
return evaluate_function_def(expression, state, tools)
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [evaluate_ast(k, state, tools) for k in expression.keys]
@@ -132,122 +421,100 @@ def evaluate_ast(expression: ast.AST, state: Dict[str, Any], tools: Dict[str, Ca
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
return evaluate_subscript(expression, state, tools)
elif isinstance(expression, ast.IfExp):
test_val = evaluate_ast(expression.test, state, tools)
if test_val:
return evaluate_ast(expression.body, state, tools)
else:
return evaluate_ast(expression.orelse, state, tools)
elif isinstance(expression, ast.Attribute):
obj = evaluate_ast(expression.value, state, tools)
return getattr(obj, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
evaluate_ast(expression.lower, state, tools) if expression.lower is not None else None,
evaluate_ast(expression.upper, state, tools) if expression.upper is not None else None,
evaluate_ast(expression.step, state, tools) if expression.step is not None else None,
)
elif isinstance(expression, ast.ListComp) or isinstance(expression, ast.GeneratorExp):
result = []
vars = {}
for generator in expression.generators:
var_name = generator.target.id
iter_value = evaluate_ast(generator.iter, state, tools)
for value in iter_value:
vars[var_name] = value
if all(evaluate_ast(if_clause, {**state, **vars}, tools) for if_clause in generator.ifs):
elem = evaluate_ast(expression.elt, {**state, **vars}, tools)
result.append(elem)
return result
elif isinstance(expression, ast.DictComp):
result = {}
for gen in expression.generators:
for container in get_iterable(evaluate_ast(gen.iter, state, tools)):
state[gen.target.id] = container
key = evaluate_ast(expression.key, state, tools)
value = evaluate_ast(expression.value, state, tools)
result[key] = value
return result
elif isinstance(expression, ast.Import):
for alias in expression.names:
if alias.name in LIST_SAFE_MODULES:
module = __import__(alias.name)
state[alias.asname or alias.name] = module
else:
raise InterpretorError(f"Import of {alias.name} is not allowed.")
return None
elif isinstance(expression, ast.While):
return evaluate_while(expression, state, tools)
elif isinstance(expression, ast.ImportFrom):
if expression.module in LIST_SAFE_MODULES:
module = __import__(expression.module)
for alias in expression.names:
state[alias.asname or alias.name] = getattr(module, alias.name)
else:
raise InterpretorError(f"Import from {expression.module} is not allowed.")
return None
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpretorError(f"{expression.__class__.__name__} is not supported.")
def evaluate_assign(assign, state, tools):
var_names = assign.targets
result = evaluate_ast(assign.value, state, tools)
if len(var_names) == 1:
state[var_names[0].id] = result
else:
if len(result) != len(var_names):
raise InterpretorError(f"Expected {len(var_names)} values but got {len(result)}.")
for var_name, r in zip(var_names, result):
state[var_name.id] = r
return result
def evaluate_call(call, state, tools):
if not isinstance(call.func, ast.Name):
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func} of "
f"type {type(call.func)}."
)
func_name = call.func.id
if func_name not in tools:
raise InterpretorError(
f"It is not permitted to evaluate other functions than the provided tools (tried to execute {call.func.id})."
)
func = tools[func_name]
# Todo deal with args
args = [evaluate_ast(arg, state, tools) for arg in call.args]
kwargs = {keyword.arg: evaluate_ast(keyword.value, state, tools) for keyword in call.keywords}
return func(*args, **kwargs)
def evaluate_subscript(subscript, state, tools):
index = evaluate_ast(subscript.slice, state, tools)
value = evaluate_ast(subscript.value, state, tools)
if isinstance(value, (list, tuple)):
return value[int(index)]
if index in value:
return value[index]
if isinstance(index, str) and isinstance(value, Mapping):
close_matches = difflib.get_close_matches(index, list(value.keys()))
if len(close_matches) > 0:
return value[close_matches[0]]
raise InterpretorError(f"Could not index {value} with '{index}'.")
def evaluate_name(name, state, tools):
if name.id in state:
return state[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
if len(close_matches) > 0:
return state[close_matches[0]]
raise InterpretorError(f"The variable `{name.id}` is not defined.")
def evaluate_condition(condition, state, tools):
if len(condition.ops) > 1:
raise InterpretorError("Cannot evaluate conditions with multiple operators")
left = evaluate_ast(condition.left, state, tools)
comparator = condition.ops[0]
right = evaluate_ast(condition.comparators[0], state, tools)
if isinstance(comparator, ast.Eq):
return left == right
elif isinstance(comparator, ast.NotEq):
return left != right
elif isinstance(comparator, ast.Lt):
return left < right
elif isinstance(comparator, ast.LtE):
return left <= right
elif isinstance(comparator, ast.Gt):
return left > right
elif isinstance(comparator, ast.GtE):
return left >= right
elif isinstance(comparator, ast.Is):
return left is right
elif isinstance(comparator, ast.IsNot):
return left is not right
elif isinstance(comparator, ast.In):
return left in right
elif isinstance(comparator, ast.NotIn):
return left not in right
else:
raise InterpretorError(f"Operator not supported: {comparator}")
def evaluate_if(if_statement, state, tools):
    result = None
    if evaluate_condition(if_statement.test, state, tools):
        for line in if_statement.body:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    else:
        for line in if_statement.orelse:
            line_result = evaluate_ast(line, state, tools)
            if line_result is not None:
                result = line_result
    return result


def evaluate_for(for_loop, state, tools):
    result = None
    iterator = evaluate_ast(for_loop.iter, state, tools)
    for counter in iterator:
        state[for_loop.target.id] = counter
        for expression in for_loop.body:
            line_result = evaluate_ast(expression, state, tools)
            if line_result is not None:
                result = line_result
    return result


def evaluate_python_code(code: str, tools: Optional[Dict[str, Callable]] = {}, state=None):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. Any call to another function will fail with an
            `InterpretorError`.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            The print outputs will be stored in the state under the key 'print_outputs'.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        raise SyntaxError(f"The code generated by the agent is not valid.\n{e}")
    if state is None:
        state = {}
    result = None
    state["print_outputs"] = ""
    for idx, node in enumerate(expression.body):
        try:
            line_result = evaluate_ast(node, state, tools)
        except InterpretorError as e:
            msg = f"You tried to execute the following code:\n{code}\n"
            msg += f"You got these outputs:\n{state['print_outputs']}\n"
            msg += f"Evaluation stopped at line '{node}' because of the following error:\n{e}"
            raise InterpretorError(msg)
        if line_result is not None:
            result = line_result
    return result
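A minimal sketch of the new entry point in use (assuming the module path introduced by this commit, `transformers.agents.python_interpreter`):

```py
from transformers.agents.python_interpreter import evaluate_python_code

state = {}
# The value of the last evaluated expression is returned; assignments update `state`.
result = evaluate_python_code("x = 3\ny = x ** 2\ny + 1", tools={}, state=state)
assert result == 10
assert state["x"] == 3 and state["y"] == 9
assert "print_outputs" in state  # initialized to "" and filled by calls to print
```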
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,28 +14,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.whisper import WhisperForConditionalGeneration, WhisperProcessor
from .base import PipelineTool
from .tools import PipelineTool
class SpeechToTextTool(PipelineTool):
default_checkpoint = "openai/whisper-base"
description = (
"This is a tool that transcribes an audio into text. It takes an input named `audio` and returns the "
"transcribed text."
)
default_checkpoint = "distil-whisper/distil-large-v3"
description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
name = "transcriber"
pre_processor_class = WhisperProcessor
model_class = WhisperForConditionalGeneration
inputs = ["audio"]
outputs = ["text"]
inputs = {"audio": {"type": "audio", "description": "The audio to transcribe"}}
output_type = "text"
def encode(self, audio):
return self.pre_processor(audio, return_tensors="pt").input_features
return self.pre_processor(audio, return_tensors="pt")
def forward(self, inputs):
return self.model.generate(inputs=inputs)
return self.model.generate(inputs["input_features"])
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0]
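A sketch of the updated calling convention, assuming the tool is exported from `transformers.agents` and given a 16 kHz float waveform (the format Whisper's processor expects):

```py
import numpy as np
from transformers.agents import SpeechToTextTool

transcriber = SpeechToTextTool()
waveform = np.zeros(16_000, dtype=np.float32)  # one second of silence as a stand-in
text = transcriber(waveform)  # downloads distil-whisper/distil-large-v3 on first use
```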
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.speecht5 import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
from ..utils import is_datasets_available
from .base import PipelineTool
from .tools import PipelineTool
if is_datasets_available():
@@ -28,16 +29,15 @@ if is_datasets_available():
class TextToSpeechTool(PipelineTool):
default_checkpoint = "microsoft/speecht5_tts"
description = (
"This is a tool that reads an English text out loud. It takes an input named `text` which should contain the "
"text to read (in English) and returns a waveform object containing the sound."
"This is a tool that reads an English text out loud. It returns a waveform object containing the sound."
)
name = "text_reader"
name = "text_to_speech"
pre_processor_class = SpeechT5Processor
model_class = SpeechT5ForTextToSpeech
post_processor_class = SpeechT5HifiGan
inputs = ["text"]
outputs = ["audio"]
inputs = {"text": {"type": "text", "description": "The text to read out loud (in English)"}}
output_type = "audio"
def setup(self):
if self.post_processor is None:
...
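Correspondingly, a sketch of the renamed tool in use (assuming it is exported from `transformers.agents` and that `datasets` is installed for the default speaker embeddings):

```py
from transformers.agents import TextToSpeechTool

reader = TextToSpeechTool()
speech = reader("Agents are rebooted.")  # returns an audio waveform object
```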
@@ -16,18 +16,22 @@
# limitations under the License.
import base64
import importlib
import inspect
import io
import json
import os
import tempfile
from functools import lru_cache
from typing import Any, Dict, List, Optional, Union
from huggingface_hub import create_repo, hf_hub_download, metadata_update, upload_folder
from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
from packaging import version
from ..dynamic_module_utils import custom_object_save, get_class_from_dynamic_module, get_imports
from ..image_utils import is_pil_image
from ..dynamic_module_utils import (
custom_object_save,
get_class_from_dynamic_module,
get_imports,
)
from ..models.auto import AutoProcessor
from ..utils import (
CONFIG_NAME,
@@ -42,6 +46,11 @@ from .agent_types import handle_agent_inputs, handle_agent_outputs
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if is_torch_available():
import torch
@@ -89,30 +98,46 @@ class Tool:
returns the text contained in the file'.
- **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
`"text-classifier"` or `"image_generator"`.
- **inputs** (`List[str]`) -- The list of modalities expected for the inputs (in the same order as in the call).
Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo` or to make a
nice space from your tool.
- **outputs** (`List[str]`) -- The list of modalities returned by the tool (in the same order as the return of the
call method). Modalities should be `"text"`, `"image"` or `"audio"`. This is only used by `launch_gradio_demo`
or to make a nice space from your tool.
- **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
It has one `type` key and one `description` key.
This is used by `launch_gradio_demo` or to make a nice space from your tool, and can also be used in the generated
description for your tool.
- **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
or to make a nice space from your tool, and can also be used in the generated description for your tool.
You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
instantiation.
"""
description: str = "This is a tool that ..."
name: str = ""
inputs: List[str]
outputs: List[str]
name: str
description: str
inputs: Dict[str, Dict[str, Union[str, type]]]
output_type: type
def __init__(self, *args, **kwargs):
self.is_initialized = False
def __call__(self, *args, **kwargs):
def validate_attributes(self):
required_attributes = {
"description": str,
"name": str,
"inputs": Dict,
"output_type": type,
}
for attr, expected_type in required_attributes.items():
attr_value = getattr(self, attr, None)
if not isinstance(attr_value, expected_type):
raise TypeError(f"Instance attribute {attr} must exist and be of type {expected_type.__name__}")
def forward(self, *args, **kwargs):
raise NotImplementedError("Write this method in your subclass of `Tool`.")
def __call__(self, *args, **kwargs):
args, kwargs = handle_agent_inputs(*args, **kwargs)
outputs = self.forward(*args, **kwargs)
return handle_agent_outputs(outputs, self.output_type)
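Under the new schema, a custom tool needs just these class attributes plus a `forward` method; an illustrative (hypothetical) subclass:

```py
class TextReverserTool(Tool):
    name = "text_reverser"
    description = "This is a tool that reverses a text. It returns the reversed text."
    inputs = {"text": {"type": "text", "description": "The text to reverse"}}
    output_type = "text"

    def forward(self, text):
        return text[::-1]


reverser = TextReverserTool()
reverser("agents")  # -> "stnega", wrapped in the agent text type
```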
def setup(self):
"""
Overwrite this method here for any operation that is expensive and needs to be executed before you start using
@@ -156,7 +181,13 @@ class Tool:
else:
tool_config = {}
tool_config = {"tool_class": full_name, "description": self.description, "name": self.name}
tool_config = {
"tool_class": full_name,
"description": self.description,
"name": self.name,
"inputs": str(self.inputs),
"output_type": str(self.output_type),
}
with open(config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
@@ -180,7 +211,6 @@ class Tool:
repo_id: str,
model_repo_id: Optional[str] = None,
token: Optional[str] = None,
remote: bool = False,
**kwargs,
):
"""
@@ -203,21 +233,11 @@
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
remote (`bool`, *optional*, defaults to `False`):
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
`cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
others will be passed along to its init.
"""
if remote and model_repo_id is None:
endpoints = get_default_endpoints()
if repo_id not in endpoints:
raise ValueError(
f"Could not infer a default endpoint for {repo_id}, you need to pass one using the "
"`model_repo_id` argument."
)
model_repo_id = endpoints[repo_id]
hub_kwargs_names = [
"cache_dir",
"force_download",
@@ -290,8 +310,11 @@ class Tool:
)
tool_class.description = custom_tool["description"]
if remote:
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
if tool_class.inputs != custom_tool["inputs"]:
tool_class.inputs = custom_tool["inputs"]
if tool_class.output_type != custom_tool["output_type"]:
tool_class.output_type = custom_tool["output_type"]
return tool_class(model_repo_id, token=token, **kwargs)
def push_to_hub(
@@ -305,6 +328,14 @@
"""
Upload the tool to the Hub.
For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
For instance:
```
from my_tool_module import MyTool
my_tool = MyTool()
my_tool.push_to_hub("my-username/my-space")
```
Parameters:
repo_id (`str`):
The name of the repository you want to push your tool to. It should contain your organization name when
@@ -320,7 +351,12 @@
Whether or not to create a PR with the uploaded files or directly commit.
"""
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="space", space_sdk="gradio"
repo_id=repo_id,
token=token,
private=private,
exist_ok=True,
repo_type="space",
space_sdk="gradio",
)
repo_id = repo_url.repo_id
metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
@@ -343,102 +379,81 @@ class Tool:
"""
Creates a [`Tool`] from a gradio tool.
"""
import inspect
class GradioToolWrapper(Tool):
def __init__(self, _gradio_tool):
super().__init__()
self.name = _gradio_tool.name
self.description = _gradio_tool.description
self.output_type = "text"
self._gradio_tool = _gradio_tool
func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
self.inputs = {key: "" for key in func_args}
GradioToolWrapper.__call__ = gradio_tool.run
return GradioToolWrapper(gradio_tool)
def forward(self, *args, **kwargs):
return self._gradio_tool.run(*args, **kwargs)
return GradioToolWrapper(gradio_tool)
    @staticmethod
    def from_langchain(langchain_tool):
        """
        Creates a [`Tool`] from a langchain tool.
        """

        class LangChainToolWrapper(Tool):
            def __init__(self, _langchain_tool):
                super().__init__()
                self.name = _langchain_tool.name.lower()
                self.description = _langchain_tool.description
                self.inputs = parse_langchain_args(_langchain_tool.args)
                self.output_type = "text"
                self.langchain_tool = _langchain_tool

            def forward(self, *args, **kwargs):
                tool_input = kwargs.copy()
                for index, argument in enumerate(args):
                    if index < len(self.inputs):
                        input_key = next(iter(self.inputs))
                        tool_input[input_key] = argument
                return self.langchain_tool.run(tool_input)

        return LangChainToolWrapper(langchain_tool)


class RemoteTool(Tool):
    """
    A [`Tool`] that will make requests to an inference endpoint.

    Args:
        endpoint_url (`str`, *optional*):
            The url of the endpoint to use.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
            running `huggingface-cli login` (stored in `~/.huggingface`).
        tool_class (`type`, *optional*):
            The corresponding `tool_class` if this is a remote version of an existing tool. Will help determine when
            the output should be converted to another type (like images).
    """

    def __init__(self, endpoint_url=None, token=None, tool_class=None):
        self.endpoint_url = endpoint_url
        self.client = EndpointClient(endpoint_url, token=token)
        self.tool_class = tool_class

    def prepare_inputs(self, *args, **kwargs):
        """
        Prepare the inputs received for the HTTP client sending data to the endpoint. Positional arguments will be
        matched with the signature of the `tool_class` if it was provided at instantiation. Images will be encoded into
        bytes.

        You can override this method in your custom class of [`RemoteTool`].
        """
        inputs = kwargs.copy()
        if len(args) > 0:
            if self.tool_class is not None:
                # Match args with the signature
                if issubclass(self.tool_class, PipelineTool):
                    call_method = self.tool_class.encode
                else:
                    call_method = self.tool_class.__call__
                signature = inspect.signature(call_method).parameters
                parameters = [
                    k
                    for k, p in signature.items()
                    if p.kind not in [inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD]
                ]
                if parameters[0] == "self":
                    parameters = parameters[1:]
                if len(args) > len(parameters):
                    raise ValueError(
                        f"{self.tool_class} only accepts {len(parameters)} arguments but {len(args)} were given."
                    )
                for arg, name in zip(args, parameters):
                    inputs[name] = arg
            elif len(args) > 1:
                raise ValueError("A `RemoteTool` can only accept one positional input.")
            elif len(args) == 1:
                if is_pil_image(args[0]):
                    return {"inputs": self.client.encode_image(args[0])}
                return {"inputs": args[0]}

        for key, value in inputs.items():
            if is_pil_image(value):
                inputs[key] = self.client.encode_image(value)

        return {"inputs": inputs}

    def extract_outputs(self, outputs):
        """
        You can override this method in your custom class of [`RemoteTool`] to apply some custom post-processing of the
        outputs of the endpoint.
        """
        return outputs

    def __call__(self, *args, **kwargs):
        args, kwargs = handle_agent_inputs(*args, **kwargs)
        output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
        inputs = self.prepare_inputs(*args, **kwargs)
        if isinstance(inputs, dict):
            outputs = self.client(**inputs, output_image=output_image)
        else:
            outputs = self.client(inputs, output_image=output_image)
        if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
            outputs = outputs[0]
        outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)
        return self.extract_outputs(outputs)


DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
- {{ tool.name }}: {{ tool.description }}
Takes inputs: {{tool.inputs}}
"""


def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
    compiled_template = compile_jinja_template(description_template)
    rendered = compiled_template.render(
        tool=tool,
    )
    return rendered


@lru_cache
def compile_jinja_template(template):
    try:
        import jinja2
        from jinja2.exceptions import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment
    except ImportError:
        raise ImportError("template requires jinja2 to be installed.")

    if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
        raise ImportError(f"template requires jinja2>=3.0.0 to be installed. Your version is {jinja2.__version__}.")

    def raise_exception(message):
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_env.globals["raise_exception"] = raise_exception
    return jinja_env.from_string(template)
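These helpers render the per-tool line used in agent prompts; with the illustrative reverser tool above, the output looks roughly like:

```py
print(get_tool_description_with_args(reverser))
# - text_reverser: This is a tool that reverses a text. It returns the reversed text.
# Takes inputs: {'text': {'type': 'text', 'description': 'The text to reverse'}}
```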
class PipelineTool(Tool):
@@ -483,6 +498,10 @@ class PipelineTool(Tool):
model_class = None
post_processor_class = AutoProcessor
default_checkpoint = None
description = "This is a pipeline tool"
name = "pipeline"
inputs = {"prompt": str}
output_type = str
def __init__(
self,
@@ -573,18 +592,22 @@ class PipelineTool(Tool):
self.setup()
encoded_inputs = self.encode(*args, **kwargs)
encoded_inputs = send_to_device(encoded_inputs, self.device)
outputs = self.forward(encoded_inputs)
tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
encoded_inputs = send_to_device(tensor_inputs, self.device)
outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
outputs = send_to_device(outputs, "cpu")
decoded_outputs = self.decode(outputs)
return handle_agent_outputs(decoded_outputs, self.outputs)
return handle_agent_outputs(decoded_outputs, self.output_type)
def launch_gradio_demo(tool_class: Tool):
"""
Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
`inputs` and `outputs`.
`inputs` and `output_type`.
Args:
tool_class (`type`): The class of the tool for which to launch the demo.
@@ -599,10 +622,26 @@ def launch_gradio_demo(tool_class: Tool):
def fn(*args, **kwargs):
return tool(*args, **kwargs)
gradio_inputs = []
for input_type in [tool_input["type"] for tool_input in tool_class.inputs.values()]:
if input_type in [str, int, float]:
gradio_inputs.append("text")
elif is_vision_available() and input_type == PIL.Image.Image:
gradio_inputs.append("image")
else:
gradio_inputs.append("audio")
if tool_class.output_type in [str, int, float]:
gradio_output = "text"
elif is_vision_available() and tool_class.output_type == PIL.Image.Image:
gradio_output = "image"
else:
gradio_output = "audio"
gr.Interface(
fn=fn,
inputs=tool_class.inputs,
outputs=tool_class.outputs,
inputs=gradio_inputs,
outputs=gradio_output,
title=tool_class.__name__,
article=tool.description,
).launch()
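For example, serving the illustrative tool class from above (requires `gradio` to be installed):

```py
launch_gradio_demo(TextReverserTool)
```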
@@ -610,31 +649,16 @@ def launch_gradio_demo(tool_class: Tool):
TASK_MAPPING = {
"document-question-answering": "DocumentQuestionAnsweringTool",
"image-captioning": "ImageCaptioningTool",
"image-question-answering": "ImageQuestionAnsweringTool",
"image-segmentation": "ImageSegmentationTool",
"speech-to-text": "SpeechToTextTool",
"summarization": "TextSummarizationTool",
"text-classification": "TextClassificationTool",
"text-question-answering": "TextQuestionAnsweringTool",
"text-to-speech": "TextToSpeechTool",
"translation": "TranslationTool",
"python_interpreter": "PythonInterpreterTool",
"final_answer": "FinalAnswerTool",
}
def get_default_endpoints():
endpoints_file = cached_file("huggingface-tools/default-endpoints", "default_endpoints.json", repo_type="dataset")
with open(endpoints_file, "r", encoding="utf-8") as f:
endpoints = json.load(f)
return endpoints
def supports_remote(task_or_repo_id):
endpoints = get_default_endpoints()
return task_or_repo_id in endpoints
def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **kwargs):
def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
"""
Main function to quickly load a tool, be it on the Hub or in the Transformers library.
@@ -652,20 +676,13 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
are:
- `"document-question-answering"`
- `"image-captioning"`
- `"image-question-answering"`
- `"image-segmentation"`
- `"speech-to-text"`
- `"summarization"`
- `"text-classification"`
- `"text-question-answering"`
- `"text-to-speech"`
- `"translation"`
model_repo_id (`str`, *optional*):
Use this argument to use a different model than the default one for the tool you selected.
remote (`bool`, *optional*, defaults to `False`):
Whether to use your tool by downloading the model or (if it is available) with an inference endpoint.
token (`str`, *optional*):
The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
login` (stored in `~/.huggingface`).
@@ -677,21 +694,9 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
if task_or_repo_id in TASK_MAPPING:
tool_class_name = TASK_MAPPING[task_or_repo_id]
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
tools_module = main_module.agents
tool_class = getattr(tools_module, tool_class_name)
if remote:
if model_repo_id is None:
endpoints = get_default_endpoints()
if task_or_repo_id not in endpoints:
raise ValueError(
f"Could not infer a default endpoint for {task_or_repo_id}, you need to pass one using the "
"`model_repo_id` argument."
)
model_repo_id = endpoints[task_or_repo_id]
return RemoteTool(model_repo_id, token=token, tool_class=tool_class)
else:
return tool_class(model_repo_id, token=token, **kwargs)
return tool_class(model_repo_id, token=token, **kwargs)
else:
logger.warning_once(
f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
@@ -699,7 +704,7 @@ def load_tool(task_or_repo_id, model_repo_id=None, remote=False, token=None, **k
f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
f"code that you have checked."
)
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, remote=remote, **kwargs)
return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
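A sketch of the task-name path (Hub repos go through `Tool.from_hub` instead):

```py
from transformers import load_tool

translator = load_tool("translation")  # a TASK_MAPPING key, resolves to TranslationTool
print(translator("Hello", src_lang="English", tgt_lang="French"))
```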
def add_description(description):
@@ -718,7 +723,10 @@ def add_description(description):
## Will move to the Hub
class EndpointClient:
def __init__(self, endpoint_url: str, token: Optional[str] = None):
self.headers = {**build_hf_headers(token=token), "Content-Type": "application/json"}
self.headers = {
**build_hf_headers(token=token),
"Content-Type": "application/json",
}
self.endpoint_url = endpoint_url
@staticmethod
@@ -763,3 +771,44 @@ class EndpointClient:
return self.decode_image(response.content)
else:
return response.json()
def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
"""Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
inputs = args.copy()
for arg_details in inputs.values():
if "title" in arg_details:
arg_details.pop("title")
return inputs
class ToolCollection:
"""
Tool collections enable loading all Spaces from a collection so that they can be added to the agent's toolbox.
> [!NOTE]
> Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
> like for this collection to showcase them.
Args:
collection_slug (str):
The collection slug referencing the collection.
token (str, *optional*):
The authentication token if the collection is private.
Example:
```py
>>> from transformers import ToolCollection, ReactCodeAgent
>>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
>>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
>>> agent.run("Please draw me a picture of rivers and lakes.")
```
"""
def __init__(self, collection_slug: str, token: Optional[str] = None):
self._collection = get_collection(collection_slug, token=token)
self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
from .tools import PipelineTool
LANGUAGE_CODES = {
@@ -231,27 +231,35 @@ class TranslationTool(PipelineTool):
Example:
```py
from transformers.tools import TranslationTool
from transformers.agents import TranslationTool
translator = TranslationTool()
translator("This is a super nice API!", src_lang="English", tgt_lang="French")
```
"""
lang_to_code = LANGUAGE_CODES
default_checkpoint = "facebook/nllb-200-distilled-600M"
description = (
"This is a tool that translates text from a language to another. It takes three inputs: `text`, which should "
"be the text to translate, `src_lang`, which should be the language of the text to translate and `tgt_lang`, "
"which should be the language for the desired ouput language. Both `src_lang` and `tgt_lang` are written in "
"plain English, such as 'Romanian', or 'Albanian'. It returns the text translated in `tgt_lang`."
"This is a tool that translates text from a language to another."
f"Both `src_lang`and `tgt_lang` should belong to this list of languages: {list(lang_to_code.keys())}."
)
name = "translator"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
lang_to_code = LANGUAGE_CODES
inputs = ["text", "text", "text"]
outputs = ["text"]
inputs = {
"text": {"type": "text", "description": "The text to translate"},
"src_lang": {
"type": "text",
"description": "The language of the text to translate. Written in plain English, such as 'Romanian', or 'Albanian'",
},
"tgt_lang": {
"type": "text",
"description": "The language for the desired ouput language. Written in plain English, such as 'Romanian', or 'Albanian'",
},
}
output_type = "text"
def encode(self, text, src_lang, tgt_lang):
if src_lang not in self.lang_to_code:
...
@@ -201,7 +201,7 @@ _run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=Fa
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
_run_tool_tests = parse_flag_from_env("RUN_TOOL_TESTS", default=False)
_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
@@ -276,19 +276,19 @@ def is_pipeline_test(test_case):
return pytest.mark.is_pipeline_test()(test_case)
def is_tool_test(test_case):
def is_agent_test(test_case):
"""
Decorator marking a test as a tool test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped.
Decorator marking a test as an agent test. If RUN_AGENT_TESTS is set to a falsy value, those tests will be skipped.
"""
if not _run_tool_tests:
return unittest.skip("test is a tool test")(test_case)
if not _run_agent_tests:
return unittest.skip("test is an agent test")(test_case)
else:
try:
import pytest # We don't need a hard dependency on pytest in the main library
except ImportError:
return test_case
else:
return pytest.mark.is_tool_test()(test_case)
return pytest.mark.is_agent_test()(test_case)
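A sketch of the renamed decorator in a test module (run with `RUN_AGENT_TESTS=1` to un-skip):

```py
import unittest

from transformers.testing_utils import is_agent_test


@is_agent_test
class ToolboxTests(unittest.TestCase):
    def test_reverse(self):
        self.assertEqual("agents"[::-1], "stnega")
```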
def slow(test_case):
...
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import json
import os
import time
from dataclasses import dataclass
from typing import Dict
import requests
from huggingface_hub import HfFolder, hf_hub_download, list_spaces
from ..models.auto import AutoTokenizer
from ..utils import is_offline_mode, is_openai_available, is_torch_available, logging
from .base import TASK_MAPPING, TOOL_CONFIG_FILE, Tool, load_tool, supports_remote
from .prompts import CHAT_MESSAGE_PROMPT, download_prompt
from .python_interpreter import evaluate
logger = logging.get_logger(__name__)
if is_openai_available():
import openai
if is_torch_available():
from ..generation import StoppingCriteria, StoppingCriteriaList
from ..models.auto import AutoModelForCausalLM
else:
StoppingCriteria = object
_tools_are_initialized = False
BASE_PYTHON_TOOLS = {
"print": print,
"range": range,
"float": float,
"int": int,
"bool": bool,
"str": str,
}
@dataclass
class PreTool:
task: str
description: str
repo_id: str
HUGGINGFACE_DEFAULT_TOOLS = {}
HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
"image-transformation",
"text-download",
"text-to-image",
"text-to-video",
]
def get_remote_tools(organization="huggingface-tools"):
if is_offline_mode():
logger.info("You are in offline mode, so remote tools are not available.")
return {}
spaces = list_spaces(author=organization)
tools = {}
for space_info in spaces:
repo_id = space_info.id
resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
with open(resolved_config_file, encoding="utf-8") as reader:
config = json.load(reader)
task = repo_id.split("/")[-1]
tools[config["name"]] = PreTool(task=task, description=config["description"], repo_id=repo_id)
return tools
def _setup_default_tools():
global HUGGINGFACE_DEFAULT_TOOLS
global _tools_are_initialized
if _tools_are_initialized:
return
main_module = importlib.import_module("transformers")
tools_module = main_module.tools
remote_tools = get_remote_tools()
for task_name, tool_class_name in TASK_MAPPING.items():
tool_class = getattr(tools_module, tool_class_name)
description = tool_class.description
HUGGINGFACE_DEFAULT_TOOLS[tool_class.name] = PreTool(task=task_name, description=description, repo_id=None)
if not is_offline_mode():
for task_name in HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB:
found = False
for tool_name, tool in remote_tools.items():
if tool.task == task_name:
HUGGINGFACE_DEFAULT_TOOLS[tool_name] = tool
found = True
break
if not found:
raise ValueError(f"{task_name} is not implemented on the Hub.")
_tools_are_initialized = True
def resolve_tools(code, toolbox, remote=False, cached_tools=None):
if cached_tools is None:
resolved_tools = BASE_PYTHON_TOOLS.copy()
else:
resolved_tools = cached_tools
for name, tool in toolbox.items():
if name not in code or name in resolved_tools:
continue
if isinstance(tool, Tool):
resolved_tools[name] = tool
else:
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
_remote = remote and supports_remote(task_or_repo_id)
resolved_tools[name] = load_tool(task_or_repo_id, remote=_remote)
return resolved_tools
def get_tool_creation_code(code, toolbox, remote=False):
code_lines = ["from transformers import load_tool", ""]
for name, tool in toolbox.items():
if name not in code or isinstance(tool, Tool):
continue
task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
line = f'{name} = load_tool("{task_or_repo_id}"'
if remote:
line += ", remote=True"
line += ")"
code_lines.append(line)
return "\n".join(code_lines) + "\n"
def clean_code_for_chat(result):
lines = result.split("\n")
idx = 0
while idx < len(lines) and not lines[idx].lstrip().startswith("```"):
idx += 1
explanation = "\n".join(lines[:idx]).strip()
if idx == len(lines):
return explanation, None
idx += 1
start_idx = idx
while not lines[idx].lstrip().startswith("```"):
idx += 1
code = "\n".join(lines[start_idx:idx]).strip()
return explanation, code
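A sketch of how a chat reply is split into its explanation and code block:

```py
fence = "`" * 3  # avoids nesting literal backtick fences in this example
reply = f"I will print a greeting.\n{fence}py\nprint('hi')\n{fence}"
explanation, code = clean_code_for_chat(reply)
assert explanation == "I will print a greeting."
assert code == "print('hi')"
```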
def clean_code_for_run(result):
result = f"I will use the following {result}"
explanation, code = result.split("Answer:")
explanation = explanation.strip()
code = code.strip()
code_lines = code.split("\n")
if code_lines[0] in ["```", "```py", "```python"]:
code_lines = code_lines[1:]
if code_lines[-1] == "```":
code_lines = code_lines[:-1]
code = "\n".join(code_lines)
return explanation, code
class Agent:
"""
Base class for all agents which contains the main API methods.
Args:
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
"""
def __init__(self, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
_setup_default_tools()
agent_name = self.__class__.__name__
self.chat_prompt_template = download_prompt(chat_prompt_template, agent_name, mode="chat")
self.run_prompt_template = download_prompt(run_prompt_template, agent_name, mode="run")
self._toolbox = HUGGINGFACE_DEFAULT_TOOLS.copy()
self.log = print
if additional_tools is not None:
if isinstance(additional_tools, (list, tuple)):
additional_tools = {t.name: t for t in additional_tools}
elif not isinstance(additional_tools, dict):
additional_tools = {additional_tools.name: additional_tools}
replacements = {name: tool for name, tool in additional_tools.items() if name in HUGGINGFACE_DEFAULT_TOOLS}
self._toolbox.update(additional_tools)
if len(replacements) > 1:
names = "\n".join([f"- {n}: {t}" for n, t in replacements.items()])
logger.warning(
f"The following tools have been replaced by the ones provided in `additional_tools`:\n{names}."
)
elif len(replacements) == 1:
name = list(replacements.keys())[0]
logger.warning(f"{name} has been replaced by {replacements[name]} as provided in `additional_tools`.")
self.prepare_for_new_chat()
@property
def toolbox(self) -> Dict[str, Tool]:
"""Get all tool currently available to the agent"""
return self._toolbox
def format_prompt(self, task, chat_mode=False):
description = "\n".join([f"- {name}: {tool.description}" for name, tool in self.toolbox.items()])
if chat_mode:
if self.chat_history is None:
prompt = self.chat_prompt_template.replace("<<all_tools>>", description)
else:
prompt = self.chat_history
prompt += CHAT_MESSAGE_PROMPT.replace("<<task>>", task)
else:
prompt = self.run_prompt_template.replace("<<all_tools>>", description)
prompt = prompt.replace("<<prompt>>", task)
return prompt
def set_stream(self, streamer):
"""
Set the function used to stream results (which is `print` by default).
Args:
streamer (`callable`): The function to call when streaming results from the LLM.
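Example (an illustrative sketch; collects the agent's logs in a list instead of printing them):
```py
logs = []
agent.set_stream(logs.append)
```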
"""
self.log = streamer
def chat(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a new request to the agent in a chat. Will use the previous requests in its history.
Args:
task (`str`): The task to perform
return_code (`bool`, *optional*, defaults to `False`):
Whether to just return code and not evaluate it.
remote (`bool`, *optional*, defaults to `False`):
Whether or not to use remote tools (inference endpoints) instead of local ones.
kwargs (additional keyword arguments, *optional*):
Any keyword argument to send to the agent when evaluating the code.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.chat("Draw me a picture of rivers and lakes")
agent.chat("Transform the picture so that there is a rock in there")
```
"""
prompt = self.format_prompt(task, chat_mode=True)
result = self.generate_one(prompt, stop=["Human:", "====="])
self.chat_history = prompt + result.strip() + "\n"
explanation, code = clean_code_for_chat(result)
self.log(f"==Explanation from the agent==\n{explanation}")
if code is not None:
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
self.chat_state.update(kwargs)
return evaluate(code, self.cached_tools, self.chat_state, chat_mode=True)
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
def prepare_for_new_chat(self):
"""
Clears the history of prior calls to [`~Agent.chat`].
"""
self.chat_history = None
self.chat_state = {}
self.cached_tools = None
def clean_code_for_run(self, result):
"""
Override this method if you want to change the way the code is
cleaned for the `run` method.
"""
return clean_code_for_run(result)
def run(self, task, *, return_code=False, remote=False, **kwargs):
"""
Sends a request to the agent.
Args:
task (`str`): The task to perform
return_code (`bool`, *optional*, defaults to `False`):
Whether to just return code and not evaluate it.
remote (`bool`, *optional*, defaults to `False`):
Whether or not to use remote tools (inference endpoints) instead of local ones.
kwargs (additional keyword arguments, *optional*):
Any keyword argument to send to the agent when evaluating the code.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Draw me a picture of rivers and lakes")
```
"""
prompt = self.format_prompt(task)
result = self.generate_one(prompt, stop=["Task:"])
explanation, code = self.clean_code_for_run(result)
self.log(f"==Explanation from the agent==\n{explanation}")
self.log(f"\n\n==Code generated by the agent==\n{code}")
if not return_code:
self.log("\n\n==Result==")
self.cached_tools = resolve_tools(code, self.toolbox, remote=remote, cached_tools=self.cached_tools)
return evaluate(code, self.cached_tools, state=kwargs.copy())
else:
tool_code = get_tool_creation_code(code, self.toolbox, remote=remote)
return f"{tool_code}\n{code}"
def generate_one(self, prompt, stop):
# This is the method to implement in your custom agent.
raise NotImplementedError
def generate_many(self, prompts, stop):
# Override if you have a way to do batch generation faster than one by one
return [self.generate_one(prompt, stop) for prompt in prompts]
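# Illustrative sketch of a minimal custom agent; `my_llm_call` is a hypothetical
# helper standing in for any text-generation backend:
#
#     class MyAgent(Agent):
#         def generate_one(self, prompt, stop):
#             text = my_llm_call(prompt)
#             for stop_seq in stop:
#                 text = text.split(stop_seq)[0]
#             return text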
class OpenAiAgent(Agent):
"""
Agent that uses the OpenAI API to generate code.
<Tip warning={true}>
The OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the ChatGPT variant. Proper support for ChatGPT models will come in a future version.
</Tip>
Args:
model (`str`, *optional*, defaults to `"text-davinci-003"`):
The name of the OpenAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"OPENAI_API_KEY"`.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import OpenAiAgent
agent = OpenAiAgent(model="text-davinci-003", api_key=xxx)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self,
model="text-davinci-003",
api_key=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an openai key to use `OpenAIAgent`. You can get one here: Get one here "
"https://openai.com/api/`. If you have one, set it in your env with `os.environ['OPENAI_API_KEY'] = "
"xxx."
)
else:
openai.api_key = api_key
self.model = model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_many(self, prompts, stop):
if "gpt" in self.model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)
def generate_one(self, prompt, stop):
if "gpt" in self.model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]
def _chat_generate(self, prompt, stop):
result = openai.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result.choices[0].message.content
def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
model=self.model,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]
class AzureOpenAiAgent(Agent):
"""
Agent that uses Azure OpenAI to generate code. See the [official
documentation](https://learn.microsoft.com/en-us/azure/cognitive-services/openai/) to learn how to deploy an OpenAI
model on Azure.
<Tip warning={true}>
The OpenAI models are used in generation mode, so even for the `chat()` API, it's better to use models like
`"text-davinci-003"` over the ChatGPT variant. Proper support for ChatGPT models will come in a future version.
</Tip>
Args:
deployment_id (`str`):
The name of the deployed Azure OpenAI model to use.
api_key (`str`, *optional*):
The API key to use. If unset, will look for the environment variable `"AZURE_OPENAI_API_KEY"`.
resource_name (`str`, *optional*):
The name of your Azure OpenAI Resource. If unset, will look for the environment variable
`"AZURE_OPENAI_RESOURCE_NAME"`.
api_version (`str`, *optional*, defaults to `"2022-12-01"`):
The API version to use for this agent.
is_chat_model (`bool`, *optional*):
Whether you are using a completion model or a chat model (see note above, chat models won't be as
efficient). Will default to `True` if `gpt` is in the `deployment_id`, `False` otherwise.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import AzureOpenAiAgent
agent = AzureOpenAiAgent(deployment_id="Davinci-003", api_key=xxx, resource_name=yyy)
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self,
deployment_id,
api_key=None,
resource_name=None,
api_version="2022-12-01",
is_chat_model=None,
chat_prompt_template=None,
run_prompt_template=None,
additional_tools=None,
):
if not is_openai_available():
raise ImportError("Using `OpenAiAgent` requires `openai`: `pip install openai`.")
self.deployment_id = deployment_id
openai.api_type = "azure"
if api_key is None:
api_key = os.environ.get("AZURE_OPENAI_API_KEY", None)
if api_key is None:
raise ValueError(
"You need an Azure openAI key to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_API_KEY'] = xxx."
)
else:
openai.api_key = api_key
if resource_name is None:
resource_name = os.environ.get("AZURE_OPENAI_RESOURCE_NAME", None)
if resource_name is None:
raise ValueError(
"You need a resource_name to use `AzureOpenAIAgent`. If you have one, set it in your env with "
"`os.environ['AZURE_OPENAI_RESOURCE_NAME'] = xxx."
)
else:
openai.api_base = f"https://{resource_name}.openai.azure.com"
openai.api_version = api_version
if is_chat_model is None:
is_chat_model = "gpt" in deployment_id.lower()
self.is_chat_model = is_chat_model
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_many(self, prompts, stop):
if self.is_chat_model:
return [self._chat_generate(prompt, stop) for prompt in prompts]
else:
return self._completion_generate(prompts, stop)
def generate_one(self, prompt, stop):
if self.is_chat_model:
return self._chat_generate(prompt, stop)
else:
return self._completion_generate([prompt], stop)[0]
def _chat_generate(self, prompt, stop):
result = openai.ChatCompletion.create(
engine=self.deployment_id,
messages=[{"role": "user", "content": prompt}],
temperature=0,
stop=stop,
)
return result["choices"][0]["message"]["content"]
def _completion_generate(self, prompts, stop):
result = openai.Completion.create(
engine=self.deployment_id,
prompt=prompts,
temperature=0,
stop=stop,
max_tokens=200,
)
return [answer["text"] for answer in result["choices"]]
class HfAgent(Agent):
"""
Agent that uses an inference endpoint to generate code.
Args:
url_endpoint (`str`):
The URL of the inference endpoint to use.
token (`str`, *optional*):
The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
running `huggingface-cli login` (stored in `~/.huggingface`).
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
from transformers import HfAgent
agent = HfAgent("https://api-inference.huggingface.co/models/bigcode/starcoder")
agent.run("Is the following `text` (in Spanish) positive or negative?", text="¡Este es un API muy agradable!")
```
"""
def __init__(
self, url_endpoint, token=None, chat_prompt_template=None, run_prompt_template=None, additional_tools=None
):
self.url_endpoint = url_endpoint
if token is None:
self.token = f"Bearer {HfFolder().get_token()}"
elif token.startswith("Bearer") or token.startswith("Basic"):
self.token = token
else:
self.token = f"Bearer {token}"
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
def generate_one(self, prompt, stop):
headers = {"Authorization": self.token}
inputs = {
"inputs": prompt,
"parameters": {"max_new_tokens": 200, "return_full_text": False, "stop": stop},
}
response = requests.post(self.url_endpoint, json=inputs, headers=headers)
if response.status_code == 429:
logger.info("Getting rate-limited, waiting a tiny bit before trying again.")
time.sleep(1)
return self.generate_one(prompt, stop)
elif response.status_code != 200:
raise ValueError(f"Error {response.status_code}: {response.json()}")
result = response.json()[0]["generated_text"]
# Inference API returns the stop sequence
for stop_seq in stop:
if result.endswith(stop_seq):
return result[: -len(stop_seq)]
return result
class LocalAgent(Agent):
"""
Agent that uses a local model and tokenizer to generate code.
Args:
model ([`PreTrainedModel`]):
The model to use for the agent.
tokenizer ([`PreTrainedTokenizer`]):
The tokenizer to use for the agent.
chat_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `chat` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`chat_prompt_template.txt` in this repo in this case.
run_prompt_template (`str`, *optional*):
Pass along your own prompt if you want to override the default template for the `run` method. Can be the
actual prompt template or a repo ID (on the Hugging Face Hub). The prompt should be in a file named
`run_prompt_template.txt` in this repo in this case.
additional_tools ([`Tool`], list of tools or dictionary with tool values, *optional*):
Any additional tools to include on top of the default ones. If you pass along a tool with the same name as
one of the default tools, that default tool will be overridden.
Example:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LocalAgent
checkpoint = "bigcode/starcoder"
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
agent = LocalAgent(model, tokenizer)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
def __init__(self, model, tokenizer, chat_prompt_template=None, run_prompt_template=None, additional_tools=None):
self.model = model
self.tokenizer = tokenizer
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
Convenience method to build a `LocalAgent` from a pretrained checkpoint.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The name of a repo on the Hub or a local path to a folder containing both model and tokenizer.
kwargs (`Dict[str, Any]`, *optional*):
Keyword arguments passed along to [`~PreTrainedModel.from_pretrained`].
Example:
```py
import torch
from transformers import LocalAgent
agent = LocalAgent.from_pretrained("bigcode/starcoder", device_map="auto", torch_dtype=torch.bfloat16)
agent.run("Draw me a picture of rivers and lakes.")
```
"""
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(model, tokenizer)
@property
def _model_device(self):
if hasattr(self.model, "hf_device_map"):
return list(self.model.hf_device_map.values())[0]
for param in self.model.parameters():
return param.device
def generate_one(self, prompt, stop):
encoded_inputs = self.tokenizer(prompt, return_tensors="pt").to(self._model_device)
src_len = encoded_inputs["input_ids"].shape[1]
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop, self.tokenizer)])
outputs = self.model.generate(
encoded_inputs["input_ids"], max_new_tokens=200, stopping_criteria=stopping_criteria
)
result = self.tokenizer.decode(outputs[0].tolist()[src_len:])
# The stopping criteria may leave the stop sequence at the end of the generated text
for stop_seq in stop:
if result.endswith(stop_seq):
result = result[: -len(stop_seq)]
return result
class StopSequenceCriteria(StoppingCriteria):
"""
This class can be used to stop generation whenever a sequence of tokens is encountered.
Args:
stop_sequences (`str` or `List[str]`):
The sequence (or list of sequences) on which to stop execution.
tokenizer:
The tokenizer used to decode the model outputs.
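Example (an illustrative sketch; assumes `model`, `tokenizer` and `input_ids` already exist):
```py
from transformers import StoppingCriteriaList

criteria = StoppingCriteriaList([StopSequenceCriteria("Human:", tokenizer)])
outputs = model.generate(input_ids, stopping_criteria=criteria)
```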
"""
def __init__(self, stop_sequences, tokenizer):
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
self.stop_sequences = stop_sequences
self.tokenizer = tokenizer
def __call__(self, input_ids, scores, **kwargs) -> bool:
decoded_output = self.tokenizer.decode(input_ids.tolist()[0])
return any(decoded_output.endswith(stop_sequence) for stop_sequence in self.stop_sequences)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ..models.auto import AutoModelForVision2Seq
from ..utils import requires_backends
from .base import PipelineTool
if TYPE_CHECKING:
from PIL import Image
class ImageCaptioningTool(PipelineTool):
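"""
Example (an illustrative sketch; assumes the `vision` extra is installed and `photo.png` exists locally):
```py
from PIL import Image
from transformers.tools import ImageCaptioningTool
captioner = ImageCaptioningTool()
caption = captioner(Image.open("photo.png"))
```
"""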
default_checkpoint = "Salesforce/blip-image-captioning-base"
description = (
"This is a tool that generates a description of an image. It takes an input named `image` which should be the "
"image to caption, and returns a text that contains the description in English."
)
name = "image_captioner"
model_class = AutoModelForVision2Seq
inputs = ["image"]
outputs = ["text"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image"):
return self.pre_processor(images=image, return_tensors="pt")
def forward(self, inputs):
return self.model.generate(**inputs)
def decode(self, outputs):
return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from ..models.clipseg import CLIPSegForImageSegmentation
from ..utils import is_vision_available, requires_backends
from .base import PipelineTool
if is_vision_available():
from PIL import Image
class ImageSegmentationTool(PipelineTool):
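"""
Example (an illustrative sketch; assumes the `vision` extra is installed and `photo.png` exists locally):
```py
from PIL import Image
from transformers.tools import ImageSegmentationTool
segmenter = ImageSegmentationTool()
mask = segmenter(Image.open("photo.png"), "cat")
```
"""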
description = (
"This is a tool that creates a segmentation mask of an image according to a label. It cannot create an image. "
"It takes two arguments named `image` which should be the original image, and `label` which should be a text "
"describing the elements that should be identified in the segmentation mask. The tool returns the mask."
)
default_checkpoint = "CIDAS/clipseg-rd64-refined"
name = "image_segmenter"
model_class = CLIPSegForImageSegmentation
inputs = ["image", "text"]
outputs = ["image"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
super().__init__(*args, **kwargs)
def encode(self, image: "Image", label: str):
return self.pre_processor(text=[label], images=[image], padding=True, return_tensors="pt")
def forward(self, inputs):
with torch.no_grad():
logits = self.model(**inputs).logits
return logits
def decode(self, outputs):
array = outputs.cpu().detach().numpy()
array[array <= 0] = 0
array[array > 0] = 1
return Image.fromarray((array * 255).astype(np.uint8))
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from ..utils import cached_file
# docstyle-ignore
CHAT_MESSAGE_PROMPT = """
Human: <<task>>
Assistant: """
DEFAULT_PROMPTS_REPO = "huggingface-tools/default-prompts"
PROMPT_FILES = {"chat": "chat_prompt_template.txt", "run": "run_prompt_template.txt"}
def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
"""
Downloads and caches the prompt from a repo (if necessary) and returns its contents.
"""
if prompt_or_repo_id is None:
prompt_or_repo_id = DEFAULT_PROMPTS_REPO
# prompt is considered a repo ID when it does not contain any kind of space
if re.search("\\s", prompt_or_repo_id) is not None:
return prompt_or_repo_id
prompt_file = cached_file(
prompt_or_repo_id, PROMPT_FILES[mode], repo_type="dataset", user_agent={"agent": agent_name}
)
with open(prompt_file, "r", encoding="utf-8") as f:
return f.read()
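# Illustrative sketch of the two accepted forms: a repo ID is downloaded, while a
# string containing whitespace is returned verbatim as the prompt itself.
#
#     prompt = download_prompt(None, "HfAgent", mode="run")  # default prompts repo
#     prompt = download_prompt("Answer using <<all_tools>>: <<prompt>>", "HfAgent", mode="run")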
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ..models.auto import AutoModelForSequenceClassification, AutoTokenizer
from .base import PipelineTool
class TextClassificationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextClassificationTool
classifier = TextClassificationTool()
classifier("This is a super nice API!", labels=["positive", "negative"])
```
"""
default_checkpoint = "facebook/bart-large-mnli"
description = (
"This is a tool that classifies an English text using provided labels. It takes two inputs: `text`, which "
"should be the text to classify, and `labels`, which should be the list of labels to use for classification. "
"It returns the most likely label in the list of provided `labels` for the input text."
)
name = "text_classifier"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSequenceClassification
inputs = ["text", ["text"]]
outputs = ["text"]
def setup(self):
super().setup()
config = self.model.config
self.entailment_id = -1
for idx, label in config.id2label.items():
if label.lower().startswith("entail"):
self.entailment_id = int(idx)
if self.entailment_id == -1:
raise ValueError("Could not determine the entailment ID from the model config, please pass it at init.")
def encode(self, text, labels):
self._labels = labels
return self.pre_processor(
[text] * len(labels),
[f"This example is {label}" for label in labels],
return_tensors="pt",
padding="max_length",
)
def decode(self, outputs):
logits = outputs.logits
label_id = torch.argmax(logits[:, self.entailment_id]).item()
return self._labels[label_id]
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
QA_PROMPT = """Here is a text containing a lot of information: '''{text}'''.
Can you answer this question about the text: '{question}'"""
class TextQuestionAnsweringTool(PipelineTool):
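"""
Example (an illustrative sketch; `long_text` is a placeholder for any input document):
```py
from transformers.tools import TextQuestionAnsweringTool
qa_tool = TextQuestionAnsweringTool()
answer = qa_tool(text=long_text, question="What is this text about?")
```
"""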
default_checkpoint = "google/flan-t5-base"
description = (
"This is a tool that answers questions related to a text. It takes two arguments named `text`, which is the "
"text where to find the answer, and `question`, which is the question, and returns the answer to the question."
)
name = "text_qa"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text", "text"]
outputs = ["text"]
def encode(self, text: str, question: str):
prompt = QA_PROMPT.format(text=text, question=question)
return self.pre_processor(prompt, return_tensors="pt")
def forward(self, inputs):
output_ids = self.model.generate(**inputs)
in_b, _ = inputs["input_ids"].shape
out_b = output_ids.shape[0]
return output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])[0][0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.auto import AutoModelForSeq2SeqLM, AutoTokenizer
from .base import PipelineTool
class TextSummarizationTool(PipelineTool):
"""
Example:
```py
from transformers.tools import TextSummarizationTool
summarizer = TextSummarizationTool()
summarizer(long_text)
```
"""
default_checkpoint = "philschmid/bart-large-cnn-samsum"
description = (
"This is a tool that summarizes an English text. It takes an input `text` containing the text to summarize, "
"and returns a summary of the text."
)
name = "summarizer"
pre_processor_class = AutoTokenizer
model_class = AutoModelForSeq2SeqLM
inputs = ["text"]
outputs = ["text"]
def encode(self, text):
return self.pre_processor(text, return_tensors="pt", truncation=True)
def forward(self, inputs):
return self.model.generate(**inputs)[0]
def decode(self, outputs):
return self.pre_processor.decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
@@ -142,6 +142,7 @@ _phonemizer_available = _is_package_available("phonemizer")
_psutil_available = _is_package_available("psutil")
_py3nvml_available = _is_package_available("py3nvml")
_pyctcdecode_available = _is_package_available("pyctcdecode")
+_pygments_available = _is_package_available("pygments")
_pytesseract_available = _is_package_available("pytesseract")
_pytest_available = _is_package_available("pytest")
_pytorch_quantization_available = _is_package_available("pytorch_quantization")
@@ -297,6 +298,10 @@ def is_hqq_available():
return _hqq_available
+def is_pygments_available():
+return _pygments_available
def get_torch_version():
return _torch_version
@@ -1294,6 +1299,11 @@ shi-labs.com/natten . You can also install it with pip (may take longer to build
`pip install natten`. Please note that you may need to restart your runtime after installation.
"""
+NUMEXPR_IMPORT_ERROR = """
+{0} requires the numexpr library but it was not found in your environment. You can install it by referring to:
+https://numexpr.readthedocs.io/en/latest/index.html.
+"""
# docstyle-ignore
NLTK_IMPORT_ERROR = """
@@ -18,8 +18,8 @@ import unittest
import uuid
from pathlib import Path
+from transformers.agents.agent_types import AgentAudio, AgentImage, AgentText
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
-from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import uuid
import pytest
from transformers.agents.agent_types import AgentText
from transformers.agents.agents import AgentMaxIterationsError, CodeAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from transformers.agents.default_tools import PythonInterpreterTool
from transformers.testing_utils import require_torch
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix)
def fake_react_json_llm(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Action:
{
"action": "python_interpreter",
"action_input": {"code": "2*3.6452"}
}
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": {"answer": "7.2904"}
}
"""
def fake_react_code_llm(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = 2**3.6452
print(result)
```<end_code>
"""
else: # We're at step 2
return """
Thought: I can now answer the initial question
Code:
```py
final_answer(7.2904)
```<end_code>
"""
def fake_code_llm_oneshot(messages, stop_sequences=None) -> str:
return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
```
"""
class AgentTests(unittest.TestCase):
def test_fake_code_agent(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_code_llm_oneshot)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
def test_fake_react_json_agent(self):
agent = ReactJsonAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_json_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, str)
assert output == "7.2904"
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert agent.logs[1]["observation"] == "7.2904"
assert agent.logs[1]["rationale"].strip() == "Thought: I should multiply 2 by 3.6452. special_marker"
assert (
agent.logs[2]["llm_output"]
== """
Thought: I can now answer the initial question
Action:
{
"action": "final_answer",
"action_input": {"answer": "7.2904"}
}
"""
)
def test_fake_react_code_agent(self):
agent = ReactCodeAgent(tools=[PythonInterpreterTool()], llm_engine=fake_react_code_llm)
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "7.2904"
assert agent.logs[0]["task"] == "What is 2 multiplied by 3.6452?"
assert abs(float(agent.logs[1]["observation"].strip()) - 12.511648) < 1e-6
assert agent.logs[2]["tool_call"] == {
"tool_arguments": "final_answer(7.2904)",
"tool_name": "code interpreter",
}
def test_setup_agent_with_empty_toolbox(self):
ReactJsonAgent(llm_engine=fake_react_json_llm, tools=[])
def test_react_fails_max_iterations(self):
agent = ReactCodeAgent(
tools=[PythonInterpreterTool()],
llm_engine=fake_code_llm_oneshot, # use this callable because it never ends
max_iterations=5,
)
agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.logs) == 7
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
@require_torch
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = ReactCodeAgent(tools=toolset_1, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 1 # contains only final_answer tool
toolset_2 = [PythonInterpreterTool(), PythonInterpreterTool()]
agent = ReactCodeAgent(tools=toolset_2, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
toolset_3 = Toolbox(toolset_2)
agent = ReactCodeAgent(tools=toolset_3, llm_engine=fake_react_code_llm)
assert len(agent.toolbox.tools) == 2 # added final_answer tool
# check that add_base_tools will not interfere with existing tools
with pytest.raises(KeyError) as e:
agent = ReactJsonAgent(tools=toolset_3, llm_engine=fake_react_json_llm, add_base_tools=True)
assert "python_interpreter already exists in the toolbox" in str(e)
# check that python_interpreter base tool does not get added to code agents
agent = ReactCodeAgent(tools=[], llm_engine=fake_react_code_llm, add_base_tools=True)
assert len(agent.toolbox.tools) == 6 # added final_answer tool + 5 base tools (excluding interpreter)
@@ -26,7 +26,6 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("document-question-answering")
self.tool.setup()
-self.remote_tool = load_tool("document-question-answering", remote=True)
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
@@ -35,22 +34,8 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
result = self.tool(document, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
-def test_exact_match_arg_remote(self):
-dataset = load_dataset("hf-internal-testing/example-documents", split="test")
-document = dataset[0]["image"]
-result = self.remote_tool(document, "When is the coffee break?")
-self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
document = dataset[0]["image"]
self.tool(document=document, question="When is the coffee break?")
-def test_exact_match_kwarg_remote(self):
-dataset = load_dataset("hf-internal-testing/example-documents", split="test")
-document = dataset[0]["image"]
-result = self.remote_tool(document=document, question="When is the coffee break?")
-self.assertEqual(result, "11-14 to 11:39 a.m.")
# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
+# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,38 +16,56 @@
import unittest
from pathlib import Path
-from transformers import is_vision_available, load_tool
-from transformers.testing_utils import get_tests_dir
+import numpy as np
+from PIL import Image
+from transformers import is_torch_available, load_tool
+from transformers.agents.agent_types import AGENT_TYPE_MAPPING
+from transformers.testing_utils import get_tests_dir, require_torch
from .test_tools_common import ToolTesterMixin
-if is_vision_available():
-from PIL import Image
+if is_torch_available():
+import torch
-class ImageCaptioningToolTester(unittest.TestCase, ToolTesterMixin):
+class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
-self.tool = load_tool("image-captioning")
+self.inputs = {"answer": "Final answer"}
+self.tool = load_tool("final_answer")
self.tool.setup()
-self.remote_tool = load_tool("image-captioning", remote=True)
def test_exact_match_arg(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.tool(image)
-self.assertEqual(result, "two cats sleeping on a couch")
-def test_exact_match_arg_remote(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.remote_tool(image)
-self.assertEqual(result, "two cats sleeping on a couch")
+result = self.tool("Final answer")
+self.assertEqual(result, "Final answer")
def test_exact_match_kwarg(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.tool(image=image)
-self.assertEqual(result, "two cats sleeping on a couch")
-def test_exact_match_kwarg_remote(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.remote_tool(image=image)
-self.assertEqual(result, "two cats sleeping on a couch")
+result = self.tool(answer=self.inputs["answer"])
+self.assertEqual(result, "Final answer")
+def create_inputs(self):
+inputs_text = {"answer": "Text input"}
+inputs_image = {
+"answer": Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png").resize(
+(512, 512)
+)
+}
+inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
+return {"text": inputs_text, "image": inputs_image, "audio": inputs_audio}
+@require_torch
+def test_agent_type_output(self):
+inputs = self.create_inputs()
+for input_type, input in inputs.items():
+output = self.tool(**input)
+agent_type = AGENT_TYPE_MAPPING[input_type]
+self.assertTrue(isinstance(output, agent_type))
+@require_torch
+def test_agent_types_inputs(self):
+inputs = self.create_inputs()
+for input_type, input in inputs.items():
+output = self.tool(**input)
+agent_type = AGENT_TYPE_MAPPING[input_type]
+self.assertTrue(isinstance(output, agent_type))
@@ -30,24 +30,13 @@ class ImageQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def setUp(self):
self.tool = load_tool("image-question-answering")
self.tool.setup()
-self.remote_tool = load_tool("image-question-answering", remote=True)
def test_exact_match_arg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image, "How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
-def test_exact_match_arg_remote(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.remote_tool(image, "How many cats are sleeping on the couch?")
-self.assertEqual(result, "2")
def test_exact_match_kwarg(self):
image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
result = self.tool(image=image, question="How many cats are sleeping on the couch?")
self.assertEqual(result, "2")
-def test_exact_match_kwarg_remote(self):
-image = Image.open(Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png")
-result = self.remote_tool(image=image, question="How many cats are sleeping on the couch?")
-self.assertEqual(result, "2")