Unverified Commit 12ae6d35 authored by Aymeric Roucher's avatar Aymeric Roucher Committed by GitHub
Browse files

Fix gradio tool demos (#31230)

* Fix gradio tool demos
parent dcdda532
......@@ -88,7 +88,8 @@ class AgentImage(AgentType, ImageType):
"""
def __init__(self, value):
super().__init__(value)
AgentType.__init__(self, value)
ImageType.__init__(self)
if not is_vision_available():
raise ImportError("PIL must be installed in order to handle images.")
......@@ -103,6 +104,8 @@ class AgentImage(AgentType, ImageType):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, np.ndarray):
self._tensor = torch.tensor(value)
else:
raise ValueError(f"Unsupported type for {self.__class__.__name__}: {type(value)}")
......@@ -125,6 +128,10 @@ class AgentImage(AgentType, ImageType):
self._raw = Image.open(self._path)
return self._raw
if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
return Image.fromarray((255 - array * 255).astype(np.uint8))
def to_string(self):
"""
Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
......@@ -137,14 +144,13 @@ class AgentImage(AgentType, ImageType):
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
self._raw.save(self._path)
return self._path
if self._tensor is not None:
array = self._tensor.cpu().detach().numpy()
# There is likely a simpler way than loading the array into an image just to save it again
img = Image.fromarray((array * 255).astype(np.uint8))
img = Image.fromarray((255 - array * 255).astype(np.uint8))
directory = tempfile.mkdtemp()
self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
......@@ -153,8 +159,19 @@ class AgentImage(AgentType, ImageType):
return self._path
def save(self, output_bytes, format, **params):
    """
    Serializes this image to a destination via PIL.

    Args:
        output_bytes (bytes): The output bytes to save the image to.
            (NOTE(review): PIL's ``Image.save`` accepts a filename or a
            file-like object here — presumably a ``BytesIO`` in practice;
            confirm against callers.)
        format (str): The format to use for the output image. The format is
            the same as in PIL.Image.save.
        **params: Additional parameters to pass to PIL.Image.save.
    """
    # Materialize the PIL image lazily via to_raw(), then delegate
    # directly to PIL's own save machinery.
    self.to_raw().save(output_bytes, format, **params)
class AgentAudio(AgentType):
class AgentAudio(AgentType, str):
"""
Audio type returned by the agent.
"""
......@@ -169,11 +186,13 @@ class AgentAudio(AgentType):
self._tensor = None
self.samplerate = samplerate
if isinstance(value, (str, pathlib.Path)):
self._path = value
elif isinstance(value, torch.Tensor):
self._tensor = value
elif isinstance(value, tuple):
self.samplerate = value[0]
self._tensor = torch.tensor(value[1])
else:
raise ValueError(f"Unsupported audio type: {type(value)}")
......
......@@ -133,7 +133,8 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
"""
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You will be given a task to solve as best you can. To do so, you have been given access to the following tools: <<tool_names>>
DEFAULT_REACT_JSON_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using JSON tool calls. You will be given a task to solve as best you can.
To do so, you have been given access to the following tools: <<tool_names>>
The way you use the tools is by specifying a json blob, ending with '<end_action>'.
Specifically, this json should have an `action` key (name of the tool to use) and an `action_input` key (input to the tool).
......@@ -261,7 +262,7 @@ Now Begin! If you solve the task correctly, you will receive a reward of $1,000,
"""
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You will be given a task to solve as best you can.
DEFAULT_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
To do so, you have been given access to *tools*: these tools are basically Python functions which you can call with code.
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
......
......@@ -47,10 +47,6 @@ from .agent_types import handle_agent_inputs, handle_agent_outputs
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL.Image
import PIL.ImageOps
if is_torch_available():
import torch
......@@ -623,20 +619,20 @@ def launch_gradio_demo(tool_class: Tool):
return tool(*args, **kwargs)
gradio_inputs = []
for input_type in [tool_input["type"] for tool_input in tool_class.inputs.values()]:
if input_type in [str, int, float]:
gradio_inputs += "text"
elif is_vision_available() and input_type == PIL.Image.Image:
gradio_inputs += "image"
for input_name, input_details in tool_class.inputs.items():
input_type = input_details["type"]
if input_type == "text":
gradio_inputs.append(gr.Textbox(label=input_name))
elif input_type == "image":
gradio_inputs.append(gr.Image(label=input_name))
elif input_type == "audio":
gradio_inputs.append(gr.Audio(label=input_name))
else:
gradio_inputs += "audio"
error_message = f"Input type '{input_type}' not supported."
raise ValueError(error_message)
if tool_class.output_type in [str, int, float]:
gradio_output = "text"
elif is_vision_available() and tool_class.output_type == PIL.Image.Image:
gradio_output = "image"
else:
gradio_output = "audio"
gradio_output = tool_class.output_type
assert gradio_output in ["text", "image", "audio"], f"Output type '{gradio_output}' not supported."
gr.Interface(
fn=fn,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment