Unverified Commit a24a9a66 authored by Aymeric Roucher's avatar Aymeric Roucher Committed by GitHub
Browse files

Add stream messages from agent run for gradio chatbot (#32142)

* Add stream_to_gradio method for running agent in gradio demo
parent 811a9caa
...@@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool]) ...@@ -509,3 +509,54 @@ agent = ReactCodeAgent(tools=[search_tool])
agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?") agent.run("How many more blocks (also denoted as layers) in BERT base encoder than the encoder from the architecture proposed in Attention is All You Need?")
``` ```
## Gradio interface
You can leverage `gradio.Chatbot`to display your agent's thoughts using `stream_to_gradio`, here is an example:
```py
import gradio as gr
from transformers import (
load_tool,
ReactCodeAgent,
HfEngine,
stream_to_gradio,
)
# Import tool from Hub
image_generation_tool = load_tool("m-ric/text-to-image")
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
# Initialize the agent with the image generation tool
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
def interact_with_agent(task):
messages = []
messages.append(gr.ChatMessage(role="user", content=task))
yield messages
for msg in stream_to_gradio(agent, task):
messages.append(msg)
yield messages + [
gr.ChatMessage(role="assistant", content="⏳ Task not finished yet!")
]
yield messages
with gr.Blocks() as demo:
text_input = gr.Textbox(lines=1, label="Chat Message", value="Make me a picture of the Statue of Liberty.")
submit = gr.Button("Run illustrator agent!")
chatbot = gr.Chatbot(
label="Agent",
type="messages",
avatar_images=(
None,
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
),
)
submit.click(interact_with_agent, [text_input], [chatbot])
if __name__ == "__main__":
demo.launch()
```
\ No newline at end of file
...@@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class: ...@@ -72,6 +72,10 @@ We provide two types of agents, based on the main [`Agent`] class:
[[autodoc]] launch_gradio_demo [[autodoc]] launch_gradio_demo
### stream_to_gradio
[[autodoc]] stream_to_gradio
### ToolCollection ### ToolCollection
[[autodoc]] ToolCollection [[autodoc]] ToolCollection
......
...@@ -67,6 +67,7 @@ _import_structure = { ...@@ -67,6 +67,7 @@ _import_structure = {
"ToolCollection", "ToolCollection",
"launch_gradio_demo", "launch_gradio_demo",
"load_tool", "load_tool",
"stream_to_gradio",
], ],
"audio_utils": [], "audio_utils": [],
"benchmark": [], "benchmark": [],
...@@ -4733,6 +4734,7 @@ if TYPE_CHECKING: ...@@ -4733,6 +4734,7 @@ if TYPE_CHECKING:
ToolCollection, ToolCollection,
launch_gradio_demo, launch_gradio_demo,
load_tool, load_tool,
stream_to_gradio,
) )
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
......
...@@ -26,6 +26,7 @@ from ..utils import ( ...@@ -26,6 +26,7 @@ from ..utils import (
_import_structure = { _import_structure = {
"agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"], "agents": ["Agent", "CodeAgent", "ReactAgent", "ReactCodeAgent", "ReactJsonAgent", "Toolbox"],
"llm_engine": ["HfEngine"], "llm_engine": ["HfEngine"],
"monitoring": ["stream_to_gradio"],
"tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"], "tools": ["PipelineTool", "Tool", "ToolCollection", "launch_gradio_demo", "load_tool"],
} }
...@@ -45,6 +46,7 @@ else: ...@@ -45,6 +46,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox from .agents import Agent, CodeAgent, ReactAgent, ReactCodeAgent, ReactJsonAgent, Toolbox
from .llm_engine import HfEngine from .llm_engine import HfEngine
from .monitoring import stream_to_gradio
from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool from .tools import PipelineTool, Tool, ToolCollection, launch_gradio_demo, load_tool
try: try:
......
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_types import AgentAudio, AgentImage, AgentText
from .agents import ReactAgent
def pull_message(step_log: dict):
try:
from gradio import ChatMessage
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
if step_log.get("rationale"):
yield ChatMessage(role="assistant", content=step_log["rationale"])
if step_log.get("tool_call"):
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
content = step_log["tool_call"]["tool_arguments"]
if used_code:
content = f"```py\n{content}\n```"
yield ChatMessage(
role="assistant",
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
content=content,
)
if step_log.get("observation"):
yield ChatMessage(role="assistant", content=f"```\n{step_log['observation']}\n```")
if step_log.get("error"):
yield ChatMessage(
role="assistant",
content=str(step_log["error"]),
metadata={"title": "💥 Error"},
)
def stream_to_gradio(agent: ReactAgent, task: str, **kwargs):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
try:
from gradio import ChatMessage
except ImportError:
raise ImportError("Gradio should be installed in order to launch a gradio demo.")
for step_log in agent.run(task, stream=True, **kwargs):
if isinstance(step_log, dict):
for message in pull_message(step_log):
yield message
if isinstance(step_log, AgentText):
yield ChatMessage(role="assistant", content=f"**Final answer:**\n```\n{step_log.to_string()}\n```")
elif isinstance(step_log, AgentImage):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "image/png"},
)
elif isinstance(step_log, AgentAudio):
yield ChatMessage(
role="assistant",
content={"path": step_log.to_string(), "mime_type": "audio/wav"},
)
else:
yield ChatMessage(role="assistant", content=step_log)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment