"tests/modeling_tf_common_test.py" did not exist on "bb04edb45b73d06bc8674408df193d94c3bf77a1"
Unverified Commit deff5979 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Tool types (#24032)

* Tool types

* Tests + fixes

* Isolate types

* Oops

* Review comments + docs

* Tests + docs

* soundfile -> vision
parent 061580c8
......@@ -70,3 +70,32 @@ We provide three types of agents: [`HfAgent`] uses inference endpoints for opens
### launch_gradio_demo
[[autodoc]] launch_gradio_demo
## Agent Types
Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return
text, image, audio, video, among other types. In order to increase compatibility between tools, as well as to
correctly render these returns in ipython (jupyter, colab, ipython notebooks, ...), we implement wrapper classes
around these types.
The wrapped objects should continue behaving as initially; a text object should still behave as a string, an image
object should still behave as a `PIL.Image`.
These types have three specific purposes:
- Calling `to_raw` on the type should return the underlying object
- Calling `to_string` on the type should return the object as a string: that can be the string in case of an `AgentText`
but will be the path of the serialized version of the object in other instances
- Displaying it in an ipython kernel should display the object correctly
### AgentText
[[autodoc]] transformers.tools.agent_types.AgentText
### AgentImage
[[autodoc]] transformers.tools.agent_types.AgentImage
### AgentAudio
[[autodoc]] transformers.tools.agent_types.AgentAudio
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
import tempfile
import uuid
import numpy as np
from ..utils import is_soundfile_availble, is_torch_available, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL.Image
from PIL import Image
from PIL.Image import Image as ImageType
else:
ImageType = object
if is_torch_available():
import torch
if is_soundfile_availble():
import soundfile as sf
class AgentType:
"""
Abstract class to be reimplemented to define types that can be returned by agents.
These objects serve three purposes:
- They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
- They can be stringified: str(object) in order to return a string defining the object
- They should be displayed correctly in ipython notebooks/colab/jupyter
"""
def __init__(self, value):
self._value = value
def __str__(self):
return self.to_string()
def to_raw(self):
logger.error(
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
)
return self._value
def to_string(self) -> str:
logger.error(
"This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
)
return str(self._value)
class AgentText(AgentType, str):
"""
Text type returned by the agent. Behaves as a string.
"""
def to_raw(self):
return self._value
def to_string(self):
return self._value
class AgentImage(AgentType, ImageType):
"""
Image type returned by the agent. Behaves as a PIL.Image.
"""
def __init__(self, value):
super().__init__(value)
if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.")
self._path = None
self._raw = None
self._tensor = None
if isinstance(value, ImageType):
self._raw = value
elif isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
else:
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
"""
from IPython.display import Image, display
display(Image(self.to_string()))
def to_raw(self):
"""
Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
"""
if self._raw is not None:
return self._raw
if self._path is not None:
self._raw = Image.open(self._path)
return self._raw
def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
version of the image.
"""
if self._path is not None:
return self._path
if self._raw is not None:
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
self._raw.save(self._path)
return self._path
if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
# There is likely simpler than load into image into save
img = Image.fromarray((array * 255).astype(np.uint8))
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
img.save(self._path)
return self._path
class AgentAudio(AgentType):
"""
Audio type returned by the agent.
"""
def __init__(self, value, samplerate=16_000):
super().__init__(value)
if not is_soundfile_availble():
raise ImportError("soundfile must be installed in order to handle audio.")
self._path = None
self._tensor = None
self.samplerate = samplerate
if isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
else:
raise ValueError(f"Unsupported audio type: {type(value)}")
def _ipython_display_(self, include=None, exclude=None):
"""
Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
"""
from IPython.display import Audio, display
display(Audio(self.to_string(), rate=self.samplerate))
def to_raw(self):
"""
Returns the "raw" version of that object. It is a `torch.Tensor` object.
"""
if self._tensor is not None:
return self._tensor
if self._path is not None:
tensor, self.samplerate = sf.read(self._path)
self._tensor = torch.tensor(tensor)
return self._tensor
def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
version of the audio.
"""
if self._path is not None:
return self._path
if self._tensor is not None:
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
sf.write(self._path, self._tensor, samplerate=self.samplerate)
return self._path
AGENT_TYPE_MAPPING = {"text": AgentText, "image": AgentImage, "audio": AgentAudio}
INSTANCE_TYPE_MAPPING = {str: AgentText}
if is_vision_available():
INSTANCE_TYPE_MAPPING[PIL.Image] = AgentImage
def handle_agent_inputs(*args, **kwargs):
args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
kwargs = {k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()}
return args, kwargs
def handle_agent_outputs(outputs, output_types=None):
if isinstance(outputs, dict):
decoded_outputs = {}
for i, (k, v) in enumerate(outputs.items()):
if output_types is not None:
# If the class has defined outputs, we can map directly according to the class definition
if output_types[i] in AGENT_TYPE_MAPPING:
decoded_outputs[k] = AGENT_TYPE_MAPPING[output_types[i]](v)
else:
decoded_outputs[k] = AgentType(v)
else:
# If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(v, _k):
decoded_outputs[k] = _v(v)
if k not in decoded_outputs:
decoded_outputs[k] = AgentType[v]
elif isinstance(outputs, (list, tuple)):
decoded_outputs = type(outputs)()
for i, v in enumerate(outputs):
if output_types is not None:
# If the class has defined outputs, we can map directly according to the class definition
if output_types[i] in AGENT_TYPE_MAPPING:
decoded_outputs.append(AGENT_TYPE_MAPPING[output_types[i]](v))
else:
decoded_outputs.append(AgentType(v))
else:
# If the class does not have defined output, then we map according to the type
found = False
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(v, _k):
decoded_outputs.append(_v(v))
found = True
if not found:
decoded_outputs.append(AgentType(v))
else:
if output_types[0] in AGENT_TYPE_MAPPING:
# If the class has defined outputs, we can map directly according to the class definition
decoded_outputs = AGENT_TYPE_MAPPING[output_types[0]](outputs)
else:
# If the class does not have defined output, then we map according to the type
for _k, _v in INSTANCE_TYPE_MAPPING.items():
if isinstance(outputs, _k):
return _v(outputs)
return AgentType(outputs)
return decoded_outputs
......@@ -37,6 +37,7 @@ from ..utils import (
is_vision_available,
logging,
)
from .agent_types import handle_agent_inputs, handle_agent_outputs
logger = logging.get_logger(__name__)
......@@ -413,6 +414,8 @@ class RemoteTool(Tool):
return outputs
def __call__(self, *args, **kwargs):
args, kwargs = handle_agent_inputs(*args, **kwargs)
output_image = self.tool_class is not None and self.tool_class.outputs == ["image"]
inputs = self.prepare_inputs(*args, **kwargs)
if isinstance(inputs, dict):
......@@ -421,6 +424,9 @@ class RemoteTool(Tool):
outputs = self.client(inputs, output_image=output_image)
if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
outputs = outputs[0]
outputs = handle_agent_outputs(outputs, self.tool_class.outputs if self.tool_class is not None else None)
return self.extract_outputs(outputs)
......@@ -550,6 +556,8 @@ class PipelineTool(Tool):
return self.post_processor(outputs)
def __call__(self, *args, **kwargs):
args, kwargs = handle_agent_inputs(*args, **kwargs)
if not self.is_initialized:
self.setup()
......@@ -557,7 +565,9 @@ class PipelineTool(Tool):
encoded_inputs = send_to_device(encoded_inputs, self.device)
outputs = self.forward(encoded_inputs)
outputs = send_to_device(outputs, "cpu")
return self.decode(outputs)
decoded_outputs = self.decode(outputs)
return handle_agent_outputs(decoded_outputs, self.outputs)
def launch_gradio_demo(tool_class: Tool):
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import uuid
from pathlib import Path
from transformers.testing_utils import get_tests_dir, require_soundfile, require_torch, require_vision
from transformers.tools.agent_types import AgentAudio, AgentImage, AgentText
from transformers.utils import is_soundfile_availble, is_torch_available, is_vision_available
if is_torch_available():
import torch
if is_soundfile_availble():
import soundfile as sf
if is_vision_available():
from PIL import Image
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
return os.path.join(directory, str(uuid.uuid4()) + suffix)
@require_soundfile
@require_torch
class AgentAudioTests(unittest.TestCase):
def test_from_tensor(self):
tensor = torch.rand(12, dtype=torch.float64) - 0.5
agent_type = AgentAudio(tensor)
path = str(agent_type.to_string())
# Ensure that the tensor and the agent_type's tensor are the same
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
del agent_type
# Ensure the path remains even after the object deletion
self.assertTrue(os.path.exists(path))
# Ensure that the file contains the same value as the original tensor
new_tensor, _ = sf.read(path)
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
def test_from_string(self):
tensor = torch.rand(12, dtype=torch.float64) - 0.5
path = get_new_path(suffix=".wav")
sf.write(path, tensor, 16000)
agent_type = AgentAudio(path)
self.assertTrue(torch.allclose(tensor, agent_type.to_raw(), atol=1e-4))
self.assertEqual(agent_type.to_string(), path)
@require_vision
@require_torch
class AgentImageTests(unittest.TestCase):
def test_from_tensor(self):
tensor = torch.randint(0, 256, (64, 64, 3))
agent_type = AgentImage(tensor)
path = str(agent_type.to_string())
# Ensure that the tensor and the agent_type's tensor are the same
self.assertTrue(torch.allclose(tensor, agent_type._tensor, atol=1e-4))
self.assertIsInstance(agent_type.to_raw(), Image.Image)
# Ensure the path remains even after the object deletion
del agent_type
self.assertTrue(os.path.exists(path))
def test_from_string(self):
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
image = Image.open(path)
agent_type = AgentImage(path)
self.assertTrue(path.samefile(agent_type.to_string()))
self.assertTrue(image == agent_type.to_raw())
# Ensure the path remains even after the object deletion
del agent_type
self.assertTrue(os.path.exists(path))
def test_from_image(self):
path = Path(get_tests_dir("fixtures/tests_samples/COCO")) / "000000039769.png"
image = Image.open(path)
agent_type = AgentImage(image)
self.assertFalse(path.samefile(agent_type.to_string()))
self.assertTrue(image == agent_type.to_raw())
# Ensure the path remains even after the object deletion
del agent_type
self.assertTrue(os.path.exists(path))
class AgentTextTests(unittest.TestCase):
def test_from_string(self):
string = "Hey!"
agent_type = AgentText(string)
self.assertEqual(string, agent_type.to_string())
self.assertEqual(string, agent_type.to_raw())
self.assertEqual(string, agent_type)
......@@ -30,28 +30,27 @@ class DocumentQuestionAnsweringToolTester(unittest.TestCase, ToolTesterMixin):
def test_exact_match_arg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
document = dataset[0]["image"]
result = self.tool(image, "When is the coffee break?")
result = self.tool(document, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_arg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
document = dataset[0]["image"]
result = self.remote_tool(image, "When is the coffee break?")
result = self.remote_tool(document, "When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
def test_exact_match_kwarg(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
document = dataset[0]["image"]
result = self.tool(image=image, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
self.tool(document=document, question="When is the coffee break?")
def test_exact_match_kwarg_remote(self):
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
image = dataset[0]["image"]
document = dataset[0]["image"]
result = self.remote_tool(image=image, question="When is the coffee break?")
result = self.remote_tool(document=document, question="When is the coffee break?")
self.assertEqual(result, "11-14 to 11:39 a.m.")
......@@ -37,9 +37,11 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
resulting_tensor[:3],
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
)
)
......@@ -47,8 +49,10 @@ class TextToSpeechToolTester(unittest.TestCase, ToolTesterMixin):
# SpeechT5 isn't deterministic
torch.manual_seed(0)
result = self.tool("hey")
resulting_tensor = result.to_raw()
self.assertTrue(
torch.allclose(
result[:3], torch.tensor([-0.00040140701457858086, -0.0002551682700868696, -0.00010294507956132293])
resulting_tensor[:3],
torch.tensor([-0.0005966668832115829, -0.0003657640190795064, -0.00013439502799883485]),
)
)
......@@ -18,6 +18,7 @@ from typing import List
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir, is_tool_test
from transformers.tools.agent_types import AGENT_TYPE_MAPPING, AgentAudio, AgentImage, AgentText
if is_torch_available():
......@@ -54,11 +55,11 @@ def output_types(outputs: List):
output_types = []
for output in outputs:
if isinstance(output, str):
if isinstance(output, (str, AgentText)):
output_types.append("text")
elif isinstance(output, Image.Image):
elif isinstance(output, (Image.Image, AgentImage)):
output_types.append("image")
elif isinstance(output, torch.Tensor):
elif isinstance(output, (torch.Tensor, AgentAudio)):
output_types.append("audio")
else:
raise ValueError(f"Invalid output: {output}")
......@@ -98,3 +99,35 @@ class ToolTesterMixin:
self.assertTrue(hasattr(self.tool, "description"))
self.assertTrue(hasattr(self.tool, "default_checkpoint"))
self.assertTrue(self.tool.description.startswith("This is a tool that"))
def test_agent_types_outputs(self):
inputs = create_inputs(self.tool.inputs)
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))
for output, output_type in zip(outputs, self.tool.outputs):
agent_type = AGENT_TYPE_MAPPING[output_type]
self.assertTrue(isinstance(output, agent_type))
def test_agent_types_inputs(self):
inputs = create_inputs(self.tool.inputs)
_inputs = []
for _input, input_type in zip(inputs, self.tool.inputs):
if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))
......@@ -16,6 +16,7 @@
import unittest
from transformers import load_tool
from transformers.tools.agent_types import AGENT_TYPE_MAPPING
from .test_tools_common import ToolTesterMixin, output_types
......@@ -51,3 +52,35 @@ class TranslationToolTester(unittest.TestCase, ToolTesterMixin):
outputs = [outputs]
self.assertListEqual(output_types(outputs), self.tool.outputs)
def test_agent_types_outputs(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))
for output, output_type in zip(outputs, self.tool.outputs):
agent_type = AGENT_TYPE_MAPPING[output_type]
self.assertTrue(isinstance(output, agent_type))
def test_agent_types_inputs(self):
inputs = ["Hey, what's up?", "English", "Spanish"]
_inputs = []
for _input, input_type in zip(inputs, self.tool.inputs):
if isinstance(input_type, list):
_inputs.append([AGENT_TYPE_MAPPING[_input_type](_input) for _input_type in input_type])
else:
_inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
# Should not raise an error
outputs = self.tool(*inputs)
if not isinstance(outputs, list):
outputs = [outputs]
self.assertEqual(len(outputs), len(self.tool.outputs))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment