"vscode:/vscode.git/clone" did not exist on "4ece3b9433ea0bedff0d64fe00623c35766d7d44"
Unverified Commit 12ae6d35 authored by Aymeric Roucher's avatar Aymeric Roucher Committed by GitHub
Browse files

Fix gradio tool demos (#31230)

* Fix gradio tool demos
parent dcdda532
...@@ -88,7 +88,8 @@ class AgentImage(AgentType, ImageType): ...@@ -88,7 +88,8 @@ class AgentImage(AgentType, ImageType):
""" """
def __init__(self, value): def __init__(self, value):
super().__init__(value) AgentType.__init__(self, value)
ImageType.__init__(self)
if not is_vision_available(): if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.") raise ImportError("PIL must be installed in order to handle images.")
...@@ -103,6 +104,8 @@ class AgentImage(AgentType, ImageType): ...@@ -103,6 +104,8 @@ class AgentImage(AgentType, ImageType):
self._path = value self._path = value
elif isinstance(value, torch.Tensor): elif isinstance(value, torch.Tensor):
self._tensor = value self._tensor = value
elif isinstance(value, np.ndarray):
self._tensor = torch.tensor(value)
else: else:
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}") raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
...@@ -125,6 +128,10 @@ class AgentImage(AgentType, ImageType): ...@@ -125,6 +128,10 @@ class AgentImage(AgentType, ImageType):
self._raw = Image.open(self._path) self._raw = Image.open(self._path)
return self._raw return self._raw
if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
return Image.fromarray((255 - array * 255).astype(np.uint8))
def to_string(self): def to_string(self):
""" """
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
...@@ -137,14 +144,13 @@ class AgentImage(AgentType, ImageType): ...@@ -137,14 +144,13 @@ class AgentImage(AgentType, ImageType):
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
self._raw.save(self._path) self._raw.save(self._path)
return self._path return self._path
if self._tensor is not None: if self._tensor is not None:
array = self._tensor.cpu().detach().numpy() array = self._tensor.cpu().detach().numpy()
# There is likely simpler than load into image into save # There is likely simpler than load into image into save
img = Image.fromarray((array * 255).astype(np.uint8)) img = Image.fromarray((255 - array * 255).astype(np.uint8))
directory = tempfile.mkdtemp() directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png") self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
...@@ -153,8 +159,19 @@ class AgentImage(AgentType, ImageType): ...@@ -153,8 +159,19 @@ class AgentImage(AgentType, ImageType):
return self._path return self._path
def save(self, output_bytes, format, **params):
"""
Saves the image to a file.
Args:
output_bytes (bytes): The output bytes to save the image to.
format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
**params: Additional parameters to pass to PIL.Image.save.
"""
img = self.to_raw()
img.save(output_bytes, format, **params)
class AgentAudio(AgentType):
class AgentAudio(AgentType, str):
""" """
Audio type returned by the agent. Audio type returned by the agent.
""" """
...@@ -169,11 +186,13 @@ class AgentAudio(AgentType): ...@@ -169,11 +186,13 @@ class AgentAudio(AgentType):
self._tensor = None self._tensor = None
self.samplerate = samplerate self.samplerate = samplerate
if isinstance(value, (str, pathlib.Path)): if isinstance(value, (str, pathlib.Path)):
self._path = value self._path = value
elif isinstance(value, torch.Tensor): elif isinstance(value, torch.Tensor):
self._tensor = value self._tensor = value
elif isinstance(value, tuple):
self.samplerate = value[0]
self._tensor = torch.tensor(value[1])
else: else:
raise ValueError(f"Unsupported audio type: {type(value)}") raise ValueError(f"Unsupported audio type: {type(value)}")
......
...@@ -133,7 +133,8 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000, ...@@ -133,7 +133,8 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
""" """
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. To do so, you have been given access to the following tools: <<tool_names>> DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
To do so, you have been given access to the following tools: <<tool_names>>
The way you use the tools is by specifying a json blob, ending with '<end_action>'. The way you use the tools is by specifying a json blob, ending with '<end_action>'.
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool). Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
...@@ -261,7 +262,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000, ...@@ -261,7 +262,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
""" """
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can. DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code. To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
......
...@@ -47,10 +47,6 @@ from .agent_types import handle_agent_inputs, handle_agent_outputs ...@@ -47,10 +47,6 @@ from .agent_types import handle_agent_inputs, handle_agent_outputs
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if is_torch_available(): if is_torch_available():
import torch import torch
...@@ -623,20 +619,20 @@ def launch_gradio_demo(tool_class: Tool): ...@@ -623,20 +619,20 @@ def launch_gradio_demo(tool_class: Tool):
return tool(*args, **kwargs) return tool(*args, **kwargs)
gradio_inputs = [] gradio_inputs = []
for input_type in [tool_input["type"] for tool_input in tool_class.inputs.values()]: for input_name, input_details in tool_class.inputs.items():
if input_type in [str, int, float]: input_type = input_details["type"]
gradio_inputs += "text" if input_type == "text":
elif is_vision_available() and input_type == PIL.Image.Image: gradio_inputs.append(gr.Textbox(label=input_name))
gradio_inputs += "image" elif input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
else: else:
gradio_inputs += "audio" error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)
if tool_class.output_type in [str, int, float]: gradio_output = tool_class.output_type
gradio_output = "text" assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported."
elif is_vision_available() and tool_class.output_type == PIL.Image.Image:
gradio_output = "image"
else:
gradio_output = "audio"
gr.Interface( gr.Interface(
fn=fn, fn=fn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment