"docs/source/en/api/diffusion_pipeline.md" did not exist on "a6e2c1fe5c02cae8a9f077f5d4e11b73d5791723"
Commit 67ca83cf authored by Rayyyyy

Support GLM-4-0414

parent 78ba9d16
......@@ -172,13 +172,20 @@ abstract class BaseBrowser {
logger.debug(`Searching for: ${query}`);
const search = new URLSearchParams({ q: query });
recency_days > 0 && search.append('recency_days', recency_days.toString());
if (config.CUSTOM_CONFIG_ID) {
search.append('customconfig', config.CUSTOM_CONFIG_ID.toString());
}
const url = `${config.BING_SEARCH_API_URL}/search?${search.toString()}`;
console.log('Full URL:', url); // Log the full URL to check that it is correct
return withTimeout(
config.BROWSER_TIMEOUT,
fetch(`${config.BING_SEARCH_API_URL}/search?${search.toString()}`, {
fetch(url, {
headers: {
'Ocp-Apim-Subscription-Key': config.BING_SEARCH_API_KEY,
}
}).then(
})
.then(
res =>
res.json() as Promise<{
queryContext: {
......@@ -255,11 +262,11 @@ abstract class BaseBrowser {
}
})
.catch(err => {
logger.error(err.message);
logger.error(`Search request failed for: ${query}, error: ${err.message}`);
if (err.code === 'ECONNABORTED') {
throw new Error(`Timeout while executing search for: ${query}`);
}
throw new Error(`Network or server error occurred`);
throw new Error(`Network or server error occurred, please check the URL: ${url}`);
});
},
open_url: (url: string) => {
......
export default {
LOG_LEVEL: 'debug',
BROWSER_TIMEOUT: 10000,
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/',
BING_SEARCH_API_KEY: '',
BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/',
BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY',
CUSTOM_CONFIG_ID: 'YOUR_CUSTOM_CONFIG_ID', // Put your Custom Configuration ID here
HOST: 'localhost',
PORT: 3000,
};
\ No newline at end of file
};
......@@ -20,7 +20,7 @@ app.post('/', async (req: Request, res: Response) => {
} = req.body;
logger.info(`session_id: ${session_id}`);
logger.info(`action: ${action}`);
if (!session_history[session_id]) {
session_history[session_id] = new SimpleBrowser();
}
......
......@@ -53,4 +53,4 @@ export const withTimeout = <T>(
setTimeout(() => reject(new TimeoutError()), millis)
);
return Promise.race([promiseWithTime(promise), timeout]);
};
\ No newline at end of file
};
# Please install the requirements.txt in inference first!
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
langchain>=0.2.12
langchain-community>=0.2.11
matplotlib>=3.9.1
pymupdf>=1.24.9
python-docx>=1.1.2
python-pptx>=0.6.23
pyyaml>=6.0.1
requests>=2.31.0
streamlit>=1.37.1
zhipuai>=2.1.4
......@@ -13,7 +13,6 @@ from enum import Enum, auto
from typing import Protocol
import streamlit as st
from conversation import Conversation, build_system_prompt
from tools.tool_registry import ALL_TOOLS
......@@ -21,6 +20,7 @@ from tools.tool_registry import ALL_TOOLS
class ClientType(Enum):
HF = auto()
VLLM = auto()
API = auto()
class Client(Protocol):
......@@ -34,15 +34,15 @@ class Client(Protocol):
) -> Generator[tuple[str | dict, list[dict]]]: ...
def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]:
chat_history = []
if len(tools) > 0:
chat_history.append(
{"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
)
# if len(tools) > 0:
chat_history.append({"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)})
for conversation in history:
role = str(conversation.role).removeprefix("<|").removesuffix("|>")
if role_name_replace:
role = role_name_replace.get(role, role)
item = {
"role": role,
"content": conversation.content,
......@@ -94,5 +94,9 @@ def get_client(model_path, typ: ClientType) -> Client:
e.msg += "; did you forget to install vLLM?"
raise
return VLLMClient(model_path)
case ClientType.API:
from clients.openai import APIClient
return APIClient(model_path)
raise NotImplementedError(f"Client type {typ} is not supported.")
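# Hedged usage sketch (illustrative, not part of this commit): pick the new
# API-backed client explicitly; the Streamlit demo normally selects it through
# the USE_API environment variable in main.py.
if __name__ == "__main__":
    demo_client = get_client("THUDM/glm-4-9b-chat", ClientType.API)
    print(type(demo_client).__name__)  # expected: "APIClient"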
......@@ -2,25 +2,23 @@
HuggingFace client.
"""
import threading
from collections.abc import Generator
from threading import Thread
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from client import Client, process_input, process_response
from conversation import Conversation
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
class HFClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True,
model_path,
trust_remote_code=True,
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="cuda",
).eval()
......
"""
OpenAI API client.
"""
from collections.abc import Generator
from client import Client, process_input, process_response
from conversation import Conversation
from openai import OpenAI
def format_openai_tool(origin_tools):
openai_tools = []
for tool in origin_tools:
openai_param = {}
for param in tool["params"]:
openai_param[param["name"]] = {}
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool["description"],
"parameters": {
"type": "object",
"properties": {
param["name"]: {"type": param["type"], "description": param["description"]}
for param in tool["params"]
},
"required": [param["name"] for param in tool["params"] if param["required"]],
},
},
}
openai_tools.append(openai_tool)
return openai_tools
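# Minimal sketch of the conversion above (the sample entry is hypothetical but
# shaped like the registry entries that format_openai_tool consumes):
if __name__ == "__main__":
    sample_tools = [
        {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "params": [
                {
                    "name": "city_name",
                    "type": "string",
                    "description": "The name of the city to be queried",
                    "required": True,
                }
            ],
        }
    ]
    # Expected: one {"type": "function", ...} entry whose parameters object has
    # "city_name" as a property and lists it under "required".
    print(format_openai_tool(sample_tools))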
class APIClient(Client):
def __init__(self, model_path: str):
base_url = "http://127.0.0.1:8000/v1/"
self.client = OpenAI(api_key="EMPTY", base_url=base_url)
self.use_stream = False
self.role_name_replace = {"observation": "tool"}
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, "", role_name_replace=self.role_name_replace)
openai_tools = format_openai_tool(tools)
response = self.client.chat.completions.create(
model="glm-4",
messages=chat_history,
tools=openai_tools,
stream=self.use_stream,
max_tokens=parameters["max_new_tokens"],
temperature=parameters["temperature"],
presence_penalty=1.2,
top_p=parameters["top_p"],
tool_choice="auto",
)
output = response.choices[0].message
if output.tool_calls:
glm4_output = output.tool_calls[0].function.name + "\n" + output.tool_calls[0].function.arguments
else:
glm4_output = output.content
yield process_response(glm4_output, chat_history)
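# Hedged usage note (assumptions: an OpenAI-compatible server, for example one
# started via vLLM's OpenAI entrypoint, is reachable at the base_url above and
# serves a model registered as "glm-4"). Since use_stream is False, the
# generator yields a single processed response rather than incremental chunks:
#
#   client = APIClient("THUDM/glm-4-9b-chat")
#   for response, updated_history in client.generate_stream(
#       tools, history, max_new_tokens=1024, temperature=0.8, top_p=0.8
#   ):
#       print(response)
#
# max_new_tokens, temperature and top_p must be supplied; this client reads
# them from `parameters` without defaults.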
......@@ -8,23 +8,19 @@ installation guide before running this client.
import time
from collections.abc import Generator
from transformers import AutoTokenizer
from vllm import SamplingParams, LLMEngine, EngineArgs
from client import Client, process_input, process_response
from conversation import Conversation
from transformers import AutoTokenizer
from vllm import EngineArgs, LLMEngine, SamplingParams
class VLLMClient(Client):
def __init__(self, model_path: str):
self.tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=True
)
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.engine_args = EngineArgs(
model=model_path,
tensor_parallel_size=1,
dtype="bfloat16", # torch.bfloat16 is needed.
trust_remote_code=True,
gpu_memory_utilization=0.6,
enforce_eager=True,
worker_use_ray=False,
......@@ -35,29 +31,20 @@ class VLLMClient(Client):
self, tools: list[dict], history: list[Conversation], **parameters
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, tools)
model_inputs = self.tokenizer.apply_chat_template(
chat_history, add_generation_prompt=True, tokenize=False
)
model_inputs = self.tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, tokenize=False)
parameters["max_tokens"] = parameters.pop("max_new_tokens")
params_dict = {
"n": 1,
"best_of": 1,
"top_p": 1,
"top_k": -1,
"use_beam_search": False,
"length_penalty": 1,
"early_stopping": False,
"stop_token_ids": [151329, 151336, 151338],
"ignore_eos": False,
"logprobs": None,
"prompt_logprobs": None,
}
params_dict.update(parameters)
sampling_params = SamplingParams(**params_dict)
self.engine.add_request(
request_id=str(time.time()), inputs=model_inputs, params=sampling_params
)
self.engine.add_request(request_id=str(time.time()), inputs=model_inputs, params=sampling_params)
while self.engine.has_unfinished_requests():
request_outputs = self.engine.step()
for request_output in request_outputs:
......
......@@ -5,12 +5,11 @@ from datetime import datetime
from enum import Enum, auto
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL.Image import Image
from streamlit.delta_generator import DeltaGenerator
from tools.browser import Quote, quotes
QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
......@@ -30,7 +29,8 @@ def build_system_prompt(
):
value = SELFCOG_PROMPT
value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
value += "\n\n# 可用工具"
if enabled_tools or functions:
value += "\n\n# 可用工具"
contents = []
for tool in enabled_tools:
contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
......@@ -130,23 +130,23 @@ class Conversation:
if self.role != Role.USER:
show_text = text
else:
splitted = text.split('files uploaded.\n')
splitted = text.split("files uploaded.\n")
if len(splitted) == 1:
show_text = text
else:
# Show expander for document content
doc = splitted[0]
show_text = splitted[-1]
expander = message.expander(f'File Content')
expander = message.expander("File Content")
expander.markdown(doc)
message.markdown(show_text)
def postprocess_text(text: str, replace_quote: bool) -> str:
text = text.replace("\(", "$")
text = text.replace("\)", "$")
text = text.replace("\[", "$$")
text = text.replace("\]", "$$")
text = text.replace(r"\(", "$")
text = text.replace(r"\)", "$")
text = text.replace(r"\[", "$$")
text = text.replace(r"\]", "$$")
text = text.replace("<|assistant|>", "")
text = text.replace("<|observation|>", "")
text = text.replace("<|system|>", "")
......@@ -158,8 +158,6 @@ def postprocess_text(text: str, replace_quote: bool) -> str:
for match in QUOTE_REGEX.finditer(text):
quote_id = match.group(1)
quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
text = text.replace(
match.group(0), f" (来源:[{quote.title}]({quote.url})) "
)
text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ")
return text.strip()
......@@ -12,10 +12,6 @@ from io import BytesIO
from uuid import uuid4
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
from PIL import Image
from client import Client, ClientType, get_client
from conversation import (
FILE_TEMPLATE,
......@@ -24,14 +20,17 @@ from conversation import (
postprocess_text,
response_to_str,
)
from PIL import Image
from streamlit.delta_generator import DeltaGenerator
from tools.tool_registry import dispatch_tool, get_tools
from utils import extract_pdf, extract_docx, extract_pptx, extract_text
from utils import extract_docx, extract_pdf, extract_pptx, extract_text
CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
USE_API = os.environ.get("USE_API", "0") == "1"
class Mode(str, Enum):
......@@ -104,6 +103,7 @@ def build_client(mode: Mode) -> Client:
case Mode.ALL_TOOLS:
st.session_state.top_k = 10
typ = ClientType.VLLM if USE_VLLM else ClientType.HF
typ = ClientType.API if USE_API else typ
return get_client(CHAT_MODEL_PATH, typ)
case Mode.LONG_CTX:
st.session_state.top_k = 10
......@@ -169,9 +169,7 @@ if page == Mode.LONG_CTX:
content = extract_pptx(file_path)
else:
content = extract_text(file_path)
uploaded_texts.append(
FILE_TEMPLATE.format(file_name=file_name, file_content=content)
)
uploaded_texts.append(FILE_TEMPLATE.format(file_name=file_name, file_content=content))
os.remove(file_path)
st.session_state.uploaded_texts = "\n\n".join(uploaded_texts)
st.session_state.uploaded_file_nums = len(uploaded_files)
......@@ -230,9 +228,7 @@ def main(prompt_text: str):
# Append uploaded files
uploaded_texts = st.session_state.get("uploaded_texts")
if page == Mode.LONG_CTX and uploaded_texts and first_round:
meta_msg = "{} files uploaded.\n".format(
st.session_state.uploaded_file_nums
)
meta_msg = "{} files uploaded.\n".format(st.session_state.uploaded_file_nums)
prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
# Clear after first use
st.session_state.files_uploaded = True
......@@ -247,16 +243,12 @@ def main(prompt_text: str):
append_conversation(Conversation(role, prompt_text, image=image), history)
placeholder = st.container()
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
def add_new_block():
nonlocal message_placeholder, markdown_placeholder
message_placeholder = placeholder.chat_message(
name="assistant", avatar="assistant"
)
message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
markdown_placeholder = message_placeholder.empty()
def commit_conversation(
......@@ -301,33 +293,18 @@ def main(prompt_text: str):
history_len = len(chat_history)
last_response = response
replace_quote = chat_history[-1]["role"] == "assistant"
markdown_placeholder.markdown(
postprocess_text(
str(response) + "●", replace_quote=replace_quote
)
)
markdown_placeholder.markdown(postprocess_text(str(response) + "●", replace_quote=replace_quote))
else:
metadata = (
page == Mode.ALL_TOOLS
and isinstance(response, dict)
and response.get("name")
or None
)
metadata = page == Mode.ALL_TOOLS and isinstance(response, dict) and response.get("name") or None
role = Role.TOOL if metadata else Role.ASSISTANT
text = (
response.get("content")
if metadata
else response_to_str(response)
)
text = response.get("content") if metadata else response_to_str(response)
commit_conversation(role, text, metadata)
if metadata:
add_new_block()
try:
with markdown_placeholder:
with st.spinner(f"Calling tool {metadata}..."):
observations = dispatch_tool(
metadata, text, str(st.session_state.session_id)
)
observations = dispatch_tool(metadata, text, str(st.session_state.session_id))
except Exception as e:
traceback.print_exc()
st.error(f'Uncaught exception in `"{metadata}"`: {e}')
......@@ -346,7 +323,7 @@ def main(prompt_text: str):
continue
else:
break
except Exception as e:
except Exception:
traceback.print_exc()
st.error(f"Uncaught exception: {traceback.format_exc()}")
else:
......
......@@ -6,22 +6,26 @@ Simple browser tool.
Please start the backend browser server according to the instructions in the README.
"""
from pprint import pprint
import re
from dataclasses import dataclass
from pprint import pprint
import requests
import streamlit as st
from dataclasses import dataclass
from .config import BROWSER_SERVER_URL
from .interface import ToolObservation
QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")
@dataclass
class Quote:
title: str
url: str
# Quotes used for displaying references
if "quotes" not in st.session_state:
st.session_state.quotes = {}
......@@ -31,18 +35,18 @@ quotes: dict[str, Quote] = st.session_state.quotes
def map_response(response: dict) -> ToolObservation:
# Save quotes for reference
print('===BROWSER_RESPONSE===')
print("===BROWSER_RESPONSE===")
pprint(response)
role_metadata = response.get("roleMetadata")
metadata = response.get("metadata")
if role_metadata.split()[0] == 'quote_result' and metadata:
if role_metadata.split()[0] == "quote_result" and metadata:
quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
quote: dict[str, str] = metadata['metadata_list'][0]
quotes[quote_id] = Quote(quote['title'], quote['url'])
elif role_metadata == 'browser_result' and metadata:
for i, quote in enumerate(metadata['metadata_list']):
quotes[str(i)] = Quote(quote['title'], quote['url'])
quote: dict[str, str] = metadata["metadata_list"][0]
quotes[quote_id] = Quote(quote["title"], quote["url"])
elif role_metadata == "browser_result" and metadata:
for i, quote in enumerate(metadata["metadata_list"]):
quotes[str(i)] = Quote(quote["title"], quote["url"])
return ToolObservation(
content_type=response.get("contentType"),
......
......@@ -5,18 +5,21 @@ from zhipuai.types.image import GeneratedImage
from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
from .interface import ToolObservation
@st.cache_resource
def get_zhipu_client():
return ZhipuAI(api_key=ZHIPU_AI_KEY)
def map_response(img: GeneratedImage):
return ToolObservation(
content_type='image',
text='CogView 已经生成并向用户展示了生成的图片。',
content_type="image",
text="CogView 已经生成并向用户展示了生成的图片。",
image_url=img.url,
role_metadata='cogview_result'
role_metadata="cogview_result",
)
def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
client = get_zhipu_client()
response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
......
BROWSER_SERVER_URL = "http://localhost:3000"
IPYKERNEL = "glm-4-demo"
ZHIPU_AI_KEY = ""
COGVIEW_MODEL = "cogview-3"
from dataclasses import dataclass
from typing import Any
@dataclass
class ToolObservation:
content_type: str
......
from pprint import pprint
import queue
import re
from pprint import pprint
from subprocess import PIPE
from typing import Literal
......@@ -10,19 +10,22 @@ import streamlit as st
from .config import IPYKERNEL
from .interface import ToolObservation
ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
CODE = re.compile(r'```([^\n]*)\n(.*?)```')
class CodeKernel:
def __init__(self,
kernel_name='kernel',
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1):
ANSI_ESCAPE = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]")
CODE = re.compile(r"```([^\n]*)\n(.*?)```")
class CodeKernel:
def __init__(
self,
kernel_name="kernel",
kernel_id=None,
kernel_config_path="",
python_path=None,
ipython_path=None,
init_file_path="./startup.py",
verbose=1,
):
self.kernel_name = kernel_name
self.kernel_id = kernel_id
self.kernel_config_path = kernel_config_path
......@@ -37,19 +40,16 @@ class CodeKernel:
env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}
# Initialize the backend kernel
self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
connection_file=self.kernel_config_path,
exec_files=[self.init_file_path],
env=env)
self.kernel_manager = jupyter_client.KernelManager(
kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env
)
if self.kernel_config_path:
self.kernel_manager.load_connection_file()
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_config_path))
print("Backend kernel started with the configuration: {}".format(self.kernel_config_path))
else:
self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
print("Backend kernel started with the configuration: {}".format(
self.kernel_manager.connection_file))
print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file))
if verbose:
pprint(self.kernel_manager.get_connection_info())
......@@ -64,13 +64,13 @@ class CodeKernel:
self.kernel.execute(code)
try:
shell_msg = self.kernel.get_shell_msg(timeout=30)
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
while True:
msg_out = io_msg_content
### Poll the message
try:
io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle":
break
except queue.Empty:
break
......@@ -100,12 +100,12 @@ class CodeKernel:
return shell_msg
def get_error_msg(self, msg, verbose=False) -> str | None:
if msg['content']['status'] == 'error':
if msg["content"]["status"] == "error":
try:
error_msg = msg['content']['traceback']
error_msg = msg["content"]["traceback"]
except:
try:
error_msg = msg['content']['traceback'][-1].strip()
error_msg = msg["content"]["traceback"][-1].strip()
except:
error_msg = "Traceback Error"
if verbose:
......@@ -114,12 +114,12 @@ class CodeKernel:
return None
def check_msg(self, msg, verbose=False):
status = msg['content']['status']
if status == 'ok':
status = msg["content"]["status"]
if status == "ok":
if verbose:
print("Execution succeeded.")
elif status == 'error':
for line in msg['content']['traceback']:
elif status == "error":
for line in msg["content"]["traceback"]:
if verbose:
print(line)
......@@ -144,17 +144,17 @@ class CodeKernel:
def is_alive(self):
return self.kernel.is_alive()
def clean_ansi_codes(input_string):
return ANSI_ESCAPE.sub('', input_string)
return ANSI_ESCAPE.sub("", input_string)
def extract_code(text: str) -> str:
matches = CODE.findall(text)  # flags belong on re.compile; findall's second argument is a start position
return matches[-1][1]
def execute(
code: str,
kernel: CodeKernel
) -> tuple[Literal['text', 'image'] | None, str]:
def execute(code: str, kernel: CodeKernel) -> tuple[Literal["text", "image"] | None, str]:
res = ""
res_type = None
code = code.replace("<|observation|>", "")
......@@ -164,37 +164,38 @@ def execute(
code = code.replace("<|system|>", "")
msg, output = kernel.execute(code)
if msg['metadata']['status'] == "timeout":
return res_type, 'Timed out'
elif msg['metadata']['status'] == 'error':
return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
if msg["metadata"]["status"] == "timeout":
return res_type, "Timed out"
elif msg["metadata"]["status"] == "error":
return res_type, clean_ansi_codes("\n".join(kernel.get_error_msg(msg, verbose=True)))
if 'text' in output:
if "text" in output:
res_type = "text"
res = output['text']
elif 'data' in output:
for key in output['data']:
if 'text/plain' in key:
res = output["text"]
elif "data" in output:
for key in output["data"]:
if "text/plain" in key:
res_type = "text"
res = output['data'][key]
elif 'image/png' in key:
res = output["data"][key]
elif "image/png" in key:
res_type = "image"
res = output['data'][key]
res = output["data"][key]
break
return res_type, res
@st.cache_resource
def get_kernel() -> CodeKernel:
return CodeKernel()
def tool_call(code: str, session_id: str) -> list[ToolObservation]:
kernel = get_kernel()
res_type, res = execute(code, kernel)
# Convert base64 to data uri
text = '[Image]' if res_type == 'image' else res
image = f'data:image/png;base64,{res}' if res_type == 'image' else None
text = "[Image]" if res_type == "image" else res
image = f"data:image/png;base64,{res}" if res_type == "image" else None
return [ToolObservation(res_type, text, image)]
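# Illustrative notes (hedged): with the re.DOTALL fix on the CODE pattern
# above, extract_code returns the body of the last fenced block, e.g.
#   extract_code("```python\nprint(1 + 1)\n```") == "print(1 + 1)\n"
# tool_call runs the given code string in the cached Jupyter kernel from
# get_kernel() and returns a single ToolObservation whose content type is
# "text" or "image" (images are re-encoded as a base64 data URI).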
......@@ -4,22 +4,21 @@ This code provides extended functionality to the model, enabling it to call and
through defined interfaces.
"""
from collections.abc import Callable
import copy
import inspect
import json
from pprint import pformat
import subprocess
import traceback
from collections.abc import Callable
from types import GenericAlias
from typing import get_origin, Annotated
import subprocess
from .interface import ToolObservation
from typing import Annotated, get_origin
from .browser import tool_call as browser
from .cogview import tool_call as cogview
from .interface import ToolObservation
from .python import tool_call as python
ALL_TOOLS = {
"simple_browser": browser,
"python": python,
......@@ -73,8 +72,8 @@ def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObserv
# Dispatch predefined tools
if tool_name in ALL_TOOLS:
return ALL_TOOLS[tool_name](code, session_id)
code = code.strip().rstrip('<|observation|>').strip()
code = code.strip().rstrip("<|observation|>").strip()
# Dispatch custom tools
try:
......@@ -105,8 +104,8 @@ def get_tools() -> list[dict]:
@register_tool
def random_number_generator(
seed: Annotated[int, "The random seed used by the generator", True],
range: Annotated[tuple[int, int], "The range of the generated numbers", True],
seed: Annotated[int, "The random seed used by the generator", True],
range: Annotated[tuple[int, int], "The range of the generated numbers", True],
) -> int:
"""
Generates a random number x, s.t. range[0] <= x < range[1]
......@@ -125,7 +124,7 @@ def random_number_generator(
@register_tool
def get_weather(
city_name: Annotated[str, "The name of the city to be queried", True],
city_name: Annotated[str, "The name of the city to be queried", True],
) -> str:
"""
Get the current weather for `city_name`
......@@ -153,16 +152,14 @@ def get_weather(
except:
import traceback
ret = (
"Error encountered while fetching weather data!\n" + traceback.format_exc()
)
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
return str(ret)
@register_tool
def get_shell(
query: Annotated[str, "The command should run in Linux shell", True],
query: Annotated[str, "The command should run in Linux shell", True],
) -> str:
"""
Use the shell to run a command
......