Commit 67ca83cf authored by Rayyyyy

Support GLM-4-0414

parent 78ba9d16
@@ -172,13 +172,20 @@ abstract class BaseBrowser {
     logger.debug(`Searching for: ${query}`);
     const search = new URLSearchParams({ q: query });
     recency_days > 0 && search.append('recency_days', recency_days.toString());
+    if (config.CUSTOM_CONFIG_ID) {
+      search.append('customconfig', config.CUSTOM_CONFIG_ID.toString());
+    }
+    const url = `${config.BING_SEARCH_API_URL}/search?${search.toString()}`;
+    console.log('Full URL:', url); // Log the full URL to check that it is correct
     return withTimeout(
       config.BROWSER_TIMEOUT,
-      fetch(`${config.BING_SEARCH_API_URL}/search?${search.toString()}`, {
+      fetch(url, {
         headers: {
           'Ocp-Apim-Subscription-Key': config.BING_SEARCH_API_KEY,
         }
-      }).then(
+      })
+        .then(
           res =>
             res.json() as Promise<{
               queryContext: {
@@ -255,11 +262,11 @@ abstract class BaseBrowser {
             }
           })
           .catch(err => {
-            logger.error(err.message);
+            logger.error(`Search request failed: ${query}, error: ${err.message}`);
             if (err.code === 'ECONNABORTED') {
               throw new Error(`Timeout while executing search for: ${query}`);
             }
-            throw new Error(`Network or server error occurred`);
+            throw new Error(`Network or server error occurred; check the URL: ${url}`);
           });
       },
       open_url: (url: string) => {
...
 export default {
   LOG_LEVEL: 'debug',
   BROWSER_TIMEOUT: 10000,
-  BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/',
-  BING_SEARCH_API_KEY: '',
+  BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/',
+  BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY',
+  CUSTOM_CONFIG_ID: 'YOUR_CUSTOM_CONFIG_ID', // Put your Custom Configuration ID here
   HOST: 'localhost',
   PORT: 3000,
 };
\ No newline at end of file
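
For reference, a minimal sketch of the Bing Custom Search request the browser server now issues (Python here purely for illustration; the endpoint, the customconfig query parameter, and the header name are taken from the diff above, while the key and config ID are placeholders):

import requests

# Placeholders: substitute your real subscription key and Custom Configuration ID.
params = {"q": "GLM-4", "customconfig": "YOUR_CUSTOM_CONFIG_ID"}
resp = requests.get(
    "https://api.bing.microsoft.com/v7.0/custom/search",
    params=params,
    headers={"Ocp-Apim-Subscription-Key": "YOUR_BING_SEARCH_API_KEY"},
    timeout=10,
)
data = resp.json()  # the response carries queryContext and, on success, webPages.value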
@@ -20,7 +20,7 @@ app.post('/', async (req: Request, res: Response) => {
   } = req.body;
   logger.info(`session_id: ${session_id}`);
   logger.info(`action: ${action}`);
   if (!session_history[session_id]) {
     session_history[session_id] = new SimpleBrowser();
   }
...
@@ -53,4 +53,4 @@ export const withTimeout = <T>(
     setTimeout(() => reject(new TimeoutError()), millis)
   );
   return Promise.race([promiseWithTime(promise), timeout]);
 };
\ No newline at end of file
# Please install the requirements.txt in inference first!
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
langchain>=0.2.12
langchain-community>=0.2.11
matplotlib>=3.9.1
pymupdf>=1.24.9
python-docx>=1.1.2
python-pptx>=0.6.23
pyyaml>=6.0.1
requests>=2.31.0
streamlit>=1.37.1
zhipuai>=2.1.4
@@ -13,7 +13,6 @@ from enum import Enum, auto
 from typing import Protocol

 import streamlit as st
 from conversation import Conversation, build_system_prompt
 from tools.tool_registry import ALL_TOOLS
@@ -21,6 +20,7 @@ from tools.tool_registry import ALL_TOOLS
 class ClientType(Enum):
     HF = auto()
     VLLM = auto()
+    API = auto()


 class Client(Protocol):
@@ -34,15 +34,15 @@ class Client(Protocol):
     ) -> Generator[tuple[str | dict, list[dict]]]: ...


-def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
+def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]:
     chat_history = []
-    if len(tools) > 0:
-        chat_history.append(
-            {"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
-        )
+    # if len(tools) > 0:
+    chat_history.append({"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)})
     for conversation in history:
         role = str(conversation.role).removeprefix("<|").removesuffix("|>")
+        if role_name_replace:
+            role = role_name_replace.get(role, role)
         item = {
             "role": role,
             "content": conversation.content,
@@ -94,5 +94,9 @@ def get_client(model_path, typ: ClientType) -> Client:
                 e.msg += "; did you forget to install vLLM?"
                 raise
             return VLLMClient(model_path)
+        case ClientType.API:
+            from clients.openai import APIClient
+
+            return APIClient(model_path)
     raise NotImplementedError(f"Client type {typ} is not supported.")
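
With this change the demo can route chat traffic to an OpenAI-compatible endpoint. A minimal usage sketch (assuming the module layout above; main.py below derives the same client choice from the USE_API environment variable, and the generation parameters follow the other clients):

from client import ClientType, get_client

# Talks to the local OpenAI-compatible server configured in clients/openai.py.
client = get_client("THUDM/glm-4-9b-chat", ClientType.API)
for response, chat_history in client.generate_stream(
    tools=[], history=[], max_new_tokens=512, temperature=0.8, top_p=0.8
):
    print(response)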
@@ -2,25 +2,23 @@
 HuggingFace client.
 """

-import threading
 from collections.abc import Generator
 from threading import Thread

 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from client import Client, process_input, process_response
 from conversation import Conversation
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


 class HFClient(Client):
     def __init__(self, model_path: str):
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True,
+            model_path,
+            trust_remote_code=True,
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path,
-            trust_remote_code=True,
             torch_dtype=torch.bfloat16,
             device_map="cuda",
         ).eval()
...
"""
OpenAI API client.
"""
from collections.abc import Generator
from client import Client, process_input, process_response
from conversation import Conversation
from openai import OpenAI
def format_openai_tool(origin_tools):
openai_tools = []
for tool in origin_tools:
openai_param = {}
for param in tool["params"]:
openai_param[param["name"]] = {}
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool["description"],
"parameters": {
"type": "object",
"properties": {
param["name"]: {"type": param["type"], "description": param["description"]}
for param in tool["params"]
},
"required": [param["name"] for param in tool["params"] if param["required"]],
},
},
}
openai_tools.append(openai_tool)
return openai_tools
class APIClient(Client):
def __init__(self, model_path: str):
base_url = "http://127.0.0.1:8000/v1/"
self.client = OpenAI(api_key="EMPTY", base_url=base_url)
self.use_stream = False
self.role_name_replace = {"observation": "tool"}
def generate_stream(
self,
tools: list[dict],
history: list[Conversation],
**parameters,
) -> Generator[tuple[str | dict, list[dict]]]:
chat_history = process_input(history, "", role_name_replace=self.role_name_replace)
# messages = process_input(history, '', role_name_replace=self.role_name_replace)
openai_tools = format_openai_tool(tools)
response = self.client.chat.completions.create(
model="glm-4",
messages=chat_history,
tools=openai_tools,
stream=self.use_stream,
max_tokens=parameters["max_new_tokens"],
temperature=parameters["temperature"],
presence_penalty=1.2,
top_p=parameters["top_p"],
tool_choice="auto",
)
output = response.choices[0].message
if output.tool_calls:
glm4_output = output.tool_calls[0].function.name + "\n" + output.tool_calls[0].function.arguments
else:
glm4_output = output.content
yield process_response(glm4_output, chat_history)
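
To illustrate the conversion, feeding format_openai_tool the registry entry for the get_weather tool defined further below would produce the usual OpenAI schema (a sketch; the exact "type" string depends on how the registry renders annotations):

example_tool = {
    "name": "get_weather",
    "description": "Get the current weather for `city_name`",
    "params": [
        {"name": "city_name", "type": "string", "description": "The name of the city to be queried", "required": True},
    ],
}
# format_openai_tool([example_tool])[0]["function"]["parameters"] ->
# {"type": "object",
#  "properties": {"city_name": {"type": "string", "description": "The name of the city to be queried"}},
#  "required": ["city_name"]}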
@@ -8,23 +8,19 @@ installation guide before running this client.

 import time
 from collections.abc import Generator

-from transformers import AutoTokenizer
-from vllm import SamplingParams, LLMEngine, EngineArgs
 from client import Client, process_input, process_response
 from conversation import Conversation
+from transformers import AutoTokenizer
+from vllm import EngineArgs, LLMEngine, SamplingParams


 class VLLMClient(Client):
     def __init__(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True
-        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         self.engine_args = EngineArgs(
             model=model_path,
             tensor_parallel_size=1,
             dtype="bfloat16",  # torch.bfloat16 is needed.
-            trust_remote_code=True,
             gpu_memory_utilization=0.6,
             enforce_eager=True,
             worker_use_ray=False,
@@ -35,29 +31,20 @@ class VLLMClient(Client):
         self, tools: list[dict], history: list[Conversation], **parameters
     ) -> Generator[tuple[str | dict, list[dict]]]:
         chat_history = process_input(history, tools)
-        model_inputs = self.tokenizer.apply_chat_template(
-            chat_history, add_generation_prompt=True, tokenize=False
-        )
+        model_inputs = self.tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, tokenize=False)
         parameters["max_tokens"] = parameters.pop("max_new_tokens")
         params_dict = {
             "n": 1,
             "best_of": 1,
             "top_p": 1,
             "top_k": -1,
-            "use_beam_search": False,
             "length_penalty": 1,
-            "early_stopping": False,
             "stop_token_ids": [151329, 151336, 151338],
-            "ignore_eos": False,
-            "logprobs": None,
-            "prompt_logprobs": None,
         }
         params_dict.update(parameters)
         sampling_params = SamplingParams(**params_dict)
-        self.engine.add_request(
-            request_id=str(time.time()), inputs=model_inputs, params=sampling_params
-        )
+        self.engine.add_request(request_id=str(time.time()), inputs=model_inputs, params=sampling_params)
         while self.engine.has_unfinished_requests():
             request_outputs = self.engine.step()
             for request_output in request_outputs:
...
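
The hard-coded stop_token_ids correspond to GLM-4 special tokens. A hedged sketch of deriving them from the tokenizer instead of pinning the ids (the token names are my assumption; verify against your checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in ("<|endoftext|>", "<|user|>", "<|observation|>")]
print(stop_token_ids)  # expected to print [151329, 151336, 151338]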
@@ -5,12 +5,11 @@ from datetime import datetime
 from enum import Enum, auto

 import streamlit as st
-from streamlit.delta_generator import DeltaGenerator
 from PIL.Image import Image
+from streamlit.delta_generator import DeltaGenerator
 from tools.browser import Quote, quotes

 QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
 SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
@@ -30,7 +29,8 @@ def build_system_prompt(
 ):
     value = SELFCOG_PROMPT
     value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
-    value += "\n\n# 可用工具"
+    if enabled_tools or functions:
+        value += "\n\n# 可用工具"
     contents = []
     for tool in enabled_tools:
         contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
@@ -130,23 +130,23 @@ class Conversation:
         if self.role != Role.USER:
             show_text = text
         else:
-            splitted = text.split('files uploaded.\n')
+            splitted = text.split("files uploaded.\n")
             if len(splitted) == 1:
                 show_text = text
             else:
                 # Show expander for document content
                 doc = splitted[0]
                 show_text = splitted[-1]
-                expander = message.expander(f'File Content')
+                expander = message.expander("File Content")
                 expander.markdown(doc)
         message.markdown(show_text)


 def postprocess_text(text: str, replace_quote: bool) -> str:
-    text = text.replace("\(", "$")
-    text = text.replace("\)", "$")
-    text = text.replace("\[", "$$")
-    text = text.replace("\]", "$$")
+    text = text.replace(r"\(", "$")
+    text = text.replace(r"\)", "$")
+    text = text.replace(r"\[", "$$")
+    text = text.replace(r"\]", "$$")
     text = text.replace("<|assistant|>", "")
     text = text.replace("<|observation|>", "")
     text = text.replace("<|system|>", "")
@@ -158,8 +158,6 @@ def postprocess_text(text: str, replace_quote: bool) -> str:
     for match in QUOTE_REGEX.finditer(text):
         quote_id = match.group(1)
         quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
-        text = text.replace(
-            match.group(0), f" (来源:[{quote.title}]({quote.url})) "
-        )
+        text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ")
     return text.strip()
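
A quick illustration of the citation-rewriting path above, on a hypothetical model output containing a quote marker:

import re

QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")
text = "GLM-4 supports tool calling【0†source】"
match = QUOTE_REGEX.search(text)
print(match.group(1), match.group(2))  # -> "0 source"; the whole marker is replaced by a markdown link to the saved Quote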
@@ -12,10 +12,6 @@ from io import BytesIO
 from uuid import uuid4

 import streamlit as st
-from streamlit.delta_generator import DeltaGenerator
-from PIL import Image
 from client import Client, ClientType, get_client
 from conversation import (
     FILE_TEMPLATE,
@@ -24,14 +20,17 @@ from conversation import (
     postprocess_text,
     response_to_str,
 )
+from PIL import Image
+from streamlit.delta_generator import DeltaGenerator
 from tools.tool_registry import dispatch_tool, get_tools
-from utils import extract_pdf, extract_docx, extract_pptx, extract_text
+from utils import extract_docx, extract_pdf, extract_pptx, extract_text

 CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
 VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")
 USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
+USE_API = os.environ.get("USE_API", "0") == "1"


 class Mode(str, Enum):
@@ -104,6 +103,7 @@ def build_client(mode: Mode) -> Client:
         case Mode.ALL_TOOLS:
             st.session_state.top_k = 10
             typ = ClientType.VLLM if USE_VLLM else ClientType.HF
+            typ = ClientType.API if USE_API else typ
             return get_client(CHAT_MODEL_PATH, typ)
         case Mode.LONG_CTX:
             st.session_state.top_k = 10
@@ -169,9 +169,7 @@ if page == Mode.LONG_CTX:
                 content = extract_pptx(file_path)
             else:
                 content = extract_text(file_path)
-            uploaded_texts.append(
-                FILE_TEMPLATE.format(file_name=file_name, file_content=content)
-            )
+            uploaded_texts.append(FILE_TEMPLATE.format(file_name=file_name, file_content=content))
             os.remove(file_path)
         st.session_state.uploaded_texts = "\n\n".join(uploaded_texts)
         st.session_state.uploaded_file_nums = len(uploaded_files)
@@ -230,9 +228,7 @@ def main(prompt_text: str):
         # Append uploaded files
         uploaded_texts = st.session_state.get("uploaded_texts")
         if page == Mode.LONG_CTX and uploaded_texts and first_round:
-            meta_msg = "{} files uploaded.\n".format(
-                st.session_state.uploaded_file_nums
-            )
+            meta_msg = "{} files uploaded.\n".format(st.session_state.uploaded_file_nums)
             prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
             # Clear after first use
             st.session_state.files_uploaded = True
@@ -247,16 +243,12 @@ def main(prompt_text: str):
     append_conversation(Conversation(role, prompt_text, image=image), history)

     placeholder = st.container()
-    message_placeholder = placeholder.chat_message(
-        name="assistant", avatar="assistant"
-    )
+    message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
     markdown_placeholder = message_placeholder.empty()

     def add_new_block():
         nonlocal message_placeholder, markdown_placeholder
-        message_placeholder = placeholder.chat_message(
-            name="assistant", avatar="assistant"
-        )
+        message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
         markdown_placeholder = message_placeholder.empty()

     def commit_conversation(
@@ -301,33 +293,18 @@ def main(prompt_text: str):
                     history_len = len(chat_history)
                     last_response = response
                     replace_quote = chat_history[-1]["role"] == "assistant"
-                    markdown_placeholder.markdown(
-                        postprocess_text(
-                            str(response) + "●", replace_quote=replace_quote
-                        )
-                    )
+                    markdown_placeholder.markdown(postprocess_text(str(response) + "●", replace_quote=replace_quote))
                 else:
-                    metadata = (
-                        page == Mode.ALL_TOOLS
-                        and isinstance(response, dict)
-                        and response.get("name")
-                        or None
-                    )
+                    metadata = page == Mode.ALL_TOOLS and isinstance(response, dict) and response.get("name") or None
                     role = Role.TOOL if metadata else Role.ASSISTANT
-                    text = (
-                        response.get("content")
-                        if metadata
-                        else response_to_str(response)
-                    )
+                    text = response.get("content") if metadata else response_to_str(response)
                     commit_conversation(role, text, metadata)
                     if metadata:
                         add_new_block()
                         try:
                             with markdown_placeholder:
                                 with st.spinner(f"Calling tool {metadata}..."):
-                                    observations = dispatch_tool(
-                                        metadata, text, str(st.session_state.session_id)
-                                    )
+                                    observations = dispatch_tool(metadata, text, str(st.session_state.session_id))
                         except Exception as e:
                             traceback.print_exc()
                             st.error(f'Uncaught exception in `"{metadata}"`: {e}')
@@ -346,7 +323,7 @@ def main(prompt_text: str):
                     continue
                 else:
                     break
-        except Exception as e:
+        except Exception:
             traceback.print_exc()
             st.error(f"Uncaught exception: {traceback.format_exc()}")
     else:
...
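
One subtlety in the condensed metadata expression above: Python's and/or chaining makes it evaluate to the tool name when the response is a dict carrying a "name" key (i.e. a tool call), and to None otherwise. For example:

response = {"name": "python", "content": "print(1)"}
metadata = True and isinstance(response, dict) and response.get("name") or None
print(metadata)  # -> "python"; a plain string response would yield None instead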
@@ -6,22 +6,26 @@ Simple browser tool.
 Please start the backend browser server according to the instructions in the README.
 """

-from pprint import pprint
 import re
+from dataclasses import dataclass
+from pprint import pprint

 import requests
 import streamlit as st
-from dataclasses import dataclass

 from .config import BROWSER_SERVER_URL
 from .interface import ToolObservation

 QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")


 @dataclass
 class Quote:
     title: str
     url: str


 # Quotes for displaying reference
 if "quotes" not in st.session_state:
     st.session_state.quotes = {}
@@ -31,18 +35,18 @@ quotes: dict[str, Quote] = st.session_state.quotes
 def map_response(response: dict) -> ToolObservation:
     # Save quotes for reference
-    print('===BROWSER_RESPONSE===')
+    print("===BROWSER_RESPONSE===")
     pprint(response)
     role_metadata = response.get("roleMetadata")
     metadata = response.get("metadata")

-    if role_metadata.split()[0] == 'quote_result' and metadata:
+    if role_metadata.split()[0] == "quote_result" and metadata:
         quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
-        quote: dict[str, str] = metadata['metadata_list'][0]
-        quotes[quote_id] = Quote(quote['title'], quote['url'])
-    elif role_metadata == 'browser_result' and metadata:
-        for i, quote in enumerate(metadata['metadata_list']):
-            quotes[str(i)] = Quote(quote['title'], quote['url'])
+        quote: dict[str, str] = metadata["metadata_list"][0]
+        quotes[quote_id] = Quote(quote["title"], quote["url"])
+    elif role_metadata == "browser_result" and metadata:
+        for i, quote in enumerate(metadata["metadata_list"]):
+            quotes[str(i)] = Quote(quote["title"], quote["url"])

     return ToolObservation(
         content_type=response.get("contentType"),
...
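
For orientation, a hedged example of the browser-server payload map_response consumes; the shape is reconstructed from the field names used above, not from a server spec:

response = {
    "roleMetadata": "quote_result [1†Example Page]",
    "metadata": {"metadata_list": [{"title": "Example Page", "url": "https://example.com"}]},
    "contentType": "text",
}
# map_response(response) would record quotes["1"] = Quote("Example Page", "https://example.com")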
@@ -5,18 +5,21 @@ from zhipuai.types.image import GeneratedImage
 from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
 from .interface import ToolObservation


 @st.cache_resource
 def get_zhipu_client():
     return ZhipuAI(api_key=ZHIPU_AI_KEY)


 def map_response(img: GeneratedImage):
     return ToolObservation(
-        content_type='image',
-        text='CogView 已经生成并向用户展示了生成的图片。',
+        content_type="image",
+        text="CogView 已经生成并向用户展示了生成的图片。",
         image_url=img.url,
-        role_metadata='cogview_result'
+        role_metadata="cogview_result",
     )


 def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
     client = get_zhipu_client()
     response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
...
BROWSER_SERVER_URL = "http://localhost:3000"
IPYKERNEL = "glm-4-demo"
ZHIPU_AI_KEY = ""
COGVIEW_MODEL = "cogview-3"
from dataclasses import dataclass
from typing import Any

@dataclass
class ToolObservation:
    content_type: str
...
-from pprint import pprint
 import queue
 import re
+from pprint import pprint
 from subprocess import PIPE
 from typing import Literal
@@ -10,19 +10,22 @@ import streamlit as st
 from .config import IPYKERNEL
 from .interface import ToolObservation

-ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
-CODE = re.compile(r'```([^\n]*)\n(.*?)```')
+ANSI_ESCAPE = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]")
+CODE = re.compile(r"```([^\n]*)\n(.*?)```", re.DOTALL)  # fixed: the flag belongs in re.compile (see extract_code below)


 class CodeKernel:
-    def __init__(self,
-                 kernel_name='kernel',
-                 kernel_id=None,
-                 kernel_config_path="",
-                 python_path=None,
-                 ipython_path=None,
-                 init_file_path="./startup.py",
-                 verbose=1):
+    def __init__(
+        self,
+        kernel_name="kernel",
+        kernel_id=None,
+        kernel_config_path="",
+        python_path=None,
+        ipython_path=None,
+        init_file_path="./startup.py",
+        verbose=1,
+    ):
         self.kernel_name = kernel_name
         self.kernel_id = kernel_id
         self.kernel_config_path = kernel_config_path
@@ -37,19 +40,16 @@ class CodeKernel:
             env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}

         # Initialize the backend kernel
-        self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
-                                                           connection_file=self.kernel_config_path,
-                                                           exec_files=[self.init_file_path],
-                                                           env=env)
+        self.kernel_manager = jupyter_client.KernelManager(
+            kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env
+        )
         if self.kernel_config_path:
             self.kernel_manager.load_connection_file()
             self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
-            print("Backend kernel started with the configuration: {}".format(
-                self.kernel_config_path))
+            print("Backend kernel started with the configuration: {}".format(self.kernel_config_path))
         else:
             self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
-            print("Backend kernel started with the configuration: {}".format(
-                self.kernel_manager.connection_file))
+            print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file))

         if verbose:
             pprint(self.kernel_manager.get_connection_info())
@@ -64,13 +64,13 @@ class CodeKernel:
         self.kernel.execute(code)
         try:
             shell_msg = self.kernel.get_shell_msg(timeout=30)
-            io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
+            io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
             while True:
                 msg_out = io_msg_content
                 ### Poll the message
                 try:
-                    io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
-                    if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
+                    io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
+                    if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle":
                         break
                 except queue.Empty:
                     break
@@ -100,12 +100,12 @@ class CodeKernel:
         return shell_msg

     def get_error_msg(self, msg, verbose=False) -> str | None:
-        if msg['content']['status'] == 'error':
+        if msg["content"]["status"] == "error":
             try:
-                error_msg = msg['content']['traceback']
+                error_msg = msg["content"]["traceback"]
             except:
                 try:
-                    error_msg = msg['content']['traceback'][-1].strip()
+                    error_msg = msg["content"]["traceback"][-1].strip()
                 except:
                     error_msg = "Traceback Error"
             if verbose:
@@ -114,12 +114,12 @@ class CodeKernel:
         return None

     def check_msg(self, msg, verbose=False):
-        status = msg['content']['status']
-        if status == 'ok':
+        status = msg["content"]["status"]
+        if status == "ok":
             if verbose:
                 print("Execution succeeded.")
-        elif status == 'error':
-            for line in msg['content']['traceback']:
+        elif status == "error":
+            for line in msg["content"]["traceback"]:
                 if verbose:
                     print(line)
@@ -144,17 +144,17 @@ class CodeKernel:
     def is_alive(self):
         return self.kernel.is_alive()


 def clean_ansi_codes(input_string):
-    return ANSI_ESCAPE.sub('', input_string)
+    return ANSI_ESCAPE.sub("", input_string)


 def extract_code(text: str) -> str:
     matches = CODE.findall(text)  # fixed here: findall's second positional argument is a start offset, not flags; re.DOTALL moved into re.compile above
     return matches[-1][1]


-def execute(
-    code: str,
-    kernel: CodeKernel
-) -> tuple[Literal['text', 'image'] | None, str]:
+def execute(code: str, kernel: CodeKernel) -> tuple[Literal["text", "image"] | None, str]:
     res = ""
     res_type = None
     code = code.replace("<|observation|>", "")
@@ -164,37 +164,38 @@ def execute(
     code = code.replace("<|system|>", "")
     msg, output = kernel.execute(code)

-    if msg['metadata']['status'] == "timeout":
-        return res_type, 'Timed out'
-    elif msg['metadata']['status'] == 'error':
-        return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
+    if msg["metadata"]["status"] == "timeout":
+        return res_type, "Timed out"
+    elif msg["metadata"]["status"] == "error":
+        return res_type, clean_ansi_codes("\n".join(kernel.get_error_msg(msg, verbose=True)))

-    if 'text' in output:
+    if "text" in output:
         res_type = "text"
-        res = output['text']
-    elif 'data' in output:
-        for key in output['data']:
-            if 'text/plain' in key:
+        res = output["text"]
+    elif "data" in output:
+        for key in output["data"]:
+            if "text/plain" in key:
                 res_type = "text"
-                res = output['data'][key]
-            elif 'image/png' in key:
+                res = output["data"][key]
+            elif "image/png" in key:
                 res_type = "image"
-                res = output['data'][key]
+                res = output["data"][key]
                 break

     return res_type, res


 @st.cache_resource
 def get_kernel() -> CodeKernel:
     return CodeKernel()


 def tool_call(code: str, session_id: str) -> list[ToolObservation]:
     kernel = get_kernel()
     res_type, res = execute(code, kernel)

     # Convert base64 to data uri
-    text = '[Image]' if res_type == 'image' else res
-    image = f'data:image/png;base64,{res}' if res_type == 'image' else None
+    text = "[Image]" if res_type == "image" else res
+    image = f"data:image/png;base64,{res}" if res_type == "image" else None
     return [ToolObservation(res_type, text, image)]
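
A quick sanity check of extract_code on a typical model reply (this relies on the re.DOTALL fix noted above; without it the non-greedy body would not span newlines):

reply = "Here is the code:\n```python\nprint('hi')\n```"
print(extract_code(reply))  # -> "print('hi')\n"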
@@ -4,22 +4,21 @@ This code provides extended functionality to the model, enabling it to call and
 through defined interfaces.
 """

-from collections.abc import Callable
 import copy
 import inspect
 import json
-from pprint import pformat
+import subprocess
 import traceback
+from collections.abc import Callable
 from types import GenericAlias
-from typing import get_origin, Annotated
-import subprocess
+from typing import Annotated, get_origin

-from .interface import ToolObservation
 from .browser import tool_call as browser
 from .cogview import tool_call as cogview
+from .interface import ToolObservation
 from .python import tool_call as python

 ALL_TOOLS = {
     "simple_browser": browser,
     "python": python,
@@ -73,8 +72,8 @@ def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]:
     # Dispatch predefined tools
     if tool_name in ALL_TOOLS:
         return ALL_TOOLS[tool_name](code, session_id)

-    code = code.strip().rstrip('<|observation|>').strip()
+    code = code.strip().removesuffix("<|observation|>").strip()  # fixed: removesuffix, not rstrip; rstrip treats its argument as a character set
     # Dispatch custom tools
     try:
@@ -105,8 +104,8 @@ def get_tools() -> list[dict]:
 @register_tool
 def random_number_generator(
-        seed: Annotated[int, "The random seed used by the generator", True],
-        range: Annotated[tuple[int, int], "The range of the generated numbers", True],
+    seed: Annotated[int, "The random seed used by the generator", True],
+    range: Annotated[tuple[int, int], "The range of the generated numbers", True],
 ) -> int:
     """
     Generates a random number x, s.t. range[0] <= x < range[1]
@@ -125,7 +124,7 @@ def random_number_generator(
 @register_tool
 def get_weather(
-        city_name: Annotated[str, "The name of the city to be queried", True],
+    city_name: Annotated[str, "The name of the city to be queried", True],
 ) -> str:
     """
     Get the current weather for `city_name`
@@ -153,16 +152,14 @@ def get_weather(
     except:
         import traceback

-        ret = (
-            "Error encountered while fetching weather data!\n" + traceback.format_exc()
-        )
+        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()

     return str(ret)


 @register_tool
 def get_shell(
-        query: Annotated[str, "The command should run in Linux shell", True],
+    query: Annotated[str, "The command should run in Linux shell", True],
 ) -> str:
     """
     Use shell to run command
...
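
Adding a tool follows the same pattern as the registered tools above. A hedged sketch of registering a hypothetical custom tool with the decorator (the Annotated triple is (type, description, required), matching the existing signatures):

@register_tool
def get_time(
    timezone: Annotated[str, "IANA timezone name, e.g. Asia/Shanghai", True],
) -> str:
    """
    Get the current time in `timezone`
    """
    from datetime import datetime
    from zoneinfo import ZoneInfo

    return datetime.now(ZoneInfo(timezone)).isoformat()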