ModelZoo / GLM-4_pytorch / Commits / 67ca83cf
"docs/source/en/api/diffusion_pipeline.md" did not exist on "a6e2c1fe5c02cae8a9f077f5d4e11b73d5791723"
Commit 67ca83cf, authored Apr 17, 2025 by Rayyyyy

Support GLM-4-0414

parent 78ba9d16

Changes (86 in total): this page shows 20 changed files with 239 additions and 163 deletions (+239 −163).
demo/composite_demo/browser/pnpm-lock.yaml          +0   −0
demo/composite_demo/browser/src/browser.ts          +11  −4
demo/composite_demo/browser/src/config.ts           +9   −0
demo/composite_demo/browser/src/server.ts           +1   −1
demo/composite_demo/browser/src/types.ts            +0   −0
demo/composite_demo/browser/src/utils.ts            +1   −1
demo/composite_demo/browser/tsconfig.json           +0   −0
demo/composite_demo/requirements.txt                +15  −0
demo/composite_demo/src/client.py                   +10  −6
demo/composite_demo/src/clients/hf.py               +3   −5
demo/composite_demo/src/clients/openai.py           +69  −0
demo/composite_demo/src/clients/vllm.py             +5   −18
demo/composite_demo/src/conversation.py             +11  −13
demo/composite_demo/src/main.py                     +14  −37
demo/composite_demo/src/tools/browser.py            +14  −10
demo/composite_demo/src/tools/cogview.py            +6   −3
demo/composite_demo/src/tools/config.py             +6   −0
demo/composite_demo/src/tools/interface.py          +1   −0
demo/composite_demo/src/tools/python.py             +51  −50
demo/composite_demo/src/tools/tool_registry.py      +12  −15
composite_demo/browser/pnpm-lock.yaml → demo/composite_demo/browser/pnpm-lock.yaml
File moved
composite_demo/browser/src/browser.ts → demo/composite_demo/browser/src/browser.ts
...
@@ -172,13 +172,20 @@ abstract class BaseBrowser {
           logger.debug(`Searching for: ${query}`);
           const search = new URLSearchParams({ q: query });
           recency_days > 0 && search.append('recency_days', recency_days.toString());
+          if (config.CUSTOM_CONFIG_ID) {
+            search.append('customconfig', config.CUSTOM_CONFIG_ID.toString());
+          }
+          const url = `${config.BING_SEARCH_API_URL}/search?${search.toString()}`;
+          console.log('Full URL:', url); // log the full URL to check that it is correct
           return withTimeout(
             config.BROWSER_TIMEOUT,
-            fetch(`${config.BING_SEARCH_API_URL}/search?${search.toString()}`, {
+            fetch(url, {
               headers: {
                 'Ocp-Apim-Subscription-Key': config.BING_SEARCH_API_KEY,
               }
-            }).then(
+            })
+              .then(
                 res =>
                   res.json() as Promise<{
                     queryContext: {
...
@@ -255,11 +262,11 @@ abstract class BaseBrowser {
               }
             })
             .catch(err => {
-              logger.error(err.message);
+              logger.error(`搜索请求失败: ${query},错误信息: ${err.message}`);
               if (err.code === 'ECONNABORTED') {
                 throw new Error(`Timeout while executing search for: ${query}`);
               }
-              throw new Error(`Network or server error occurred`);
+              throw new Error(`网络或服务器发生错误,请检查URL: ${url}`);
             });
         },
       open_url: (url: string) => {
...
composite_demo/browser/src/config.ts → demo/composite_demo/browser/src/config.ts

 export default {
   LOG_LEVEL: 'debug',
   BROWSER_TIMEOUT: 10000,
-  BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/',
-  BING_SEARCH_API_KEY: '',
+  BING_SEARCH_API_URL: 'https://api.bing.microsoft.com/v7.0/custom/',
+  BING_SEARCH_API_KEY: 'YOUR_BING_SEARCH_API_KEY',
+  CUSTOM_CONFIG_ID: 'YOUR_CUSTOM_CONFIG_ID', // put your Custom Configuration ID here
   HOST: 'localhost',
   PORT: 3000,
-};
\ No newline at end of file
+};
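The browser backend now targets the Bing Custom Search endpoint, so requests need both the subscription key and a Custom Configuration ID. A minimal Python sketch of the request the TypeScript search() path builds (the key and configuration ID are placeholders, and the requests package is assumed):

import requests

BING_SEARCH_API_URL = "https://api.bing.microsoft.com/v7.0/custom/"
BING_SEARCH_API_KEY = "YOUR_BING_SEARCH_API_KEY"  # placeholder
CUSTOM_CONFIG_ID = "YOUR_CUSTOM_CONFIG_ID"        # placeholder

def bing_custom_search(query: str, recency_days: int = 0) -> dict:
    # Same URL parameters the browser backend appends: q, optional recency_days, customconfig.
    params = {"q": query, "customconfig": CUSTOM_CONFIG_ID}
    if recency_days > 0:
        params["recency_days"] = recency_days
    resp = requests.get(
        BING_SEARCH_API_URL + "search",
        params=params,
        headers={"Ocp-Apim-Subscription-Key": BING_SEARCH_API_KEY},
        timeout=10,
    )
    resp.raise_for_status()
    return resp.json()  # JSON body with queryContext/webPages, as the TypeScript client expects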
composite_demo/browser/src/server.ts → demo/composite_demo/browser/src/server.ts
...
@@ -20,7 +20,7 @@ app.post('/', async (req: Request, res: Response) => {
   } = req.body;
   logger.info(`session_id: ${session_id}`);
   logger.info(`action: ${action}`);

   if (!session_history[session_id]) {
     session_history[session_id] = new SimpleBrowser();
   }
...
composite_demo/browser/src/types.ts → demo/composite_demo/browser/src/types.ts
File moved
composite_demo/browser/src/utils.ts → demo/composite_demo/browser/src/utils.ts
...
@@ -53,4 +53,4 @@ export const withTimeout = <T>(
     setTimeout(() => reject(new TimeoutError()), millis)
   );
   return Promise.race([promiseWithTime(promise), timeout]);
-};
\ No newline at end of file
+};
composite_demo/browser/tsconfig.json → demo/composite_demo/browser/tsconfig.json
File moved
demo/composite_demo/requirements.txt (new file)

# Please install the requirements.txt in inference first!
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
langchain>=0.2.12
langchain-community>=0.2.11
matplotlib>=3.9.1
pymupdf>=1.24.9
python-docx>=1.1.2
python-pptx>=0.6.23
pyyaml>=6.0.1
requests>=2.31.0
streamlit>=1.37.1
zhipuai>=2.1.4
composite_demo/src/client.py → demo/composite_demo/src/client.py
...
@@ -13,7 +13,6 @@ from enum import Enum, auto
 from typing import Protocol

 import streamlit as st
 from conversation import Conversation, build_system_prompt
 from tools.tool_registry import ALL_TOOLS
...
@@ -21,6 +20,7 @@ from tools.tool_registry import ALL_TOOLS
 class ClientType(Enum):
     HF = auto()
     VLLM = auto()
+    API = auto()


 class Client(Protocol):
...
@@ -34,15 +34,15 @@ class Client(Protocol):
     ) -> Generator[tuple[str | dict, list[dict]]]: ...


-def process_input(history: list[dict], tools: list[dict]) -> list[dict]:
+def process_input(history: list[dict], tools: list[dict], role_name_replace: dict = None) -> list[dict]:
     chat_history = []
-    if len(tools) > 0:
-        chat_history.append(
-            {"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)}
-        )
+    # if len(tools) > 0:
+    chat_history.append({"role": "system", "content": build_system_prompt(list(ALL_TOOLS), tools)})
     for conversation in history:
         role = str(conversation.role).removeprefix("<|").removesuffix("|>")
+        if role_name_replace:
+            role = role_name_replace.get(role, role)
         item = {
             "role": role,
             "content": conversation.content,
...
@@ -94,5 +94,9 @@ def get_client(model_path, typ: ClientType) -> Client:
                 e.msg += "; did you forget to install vLLM?"
                 raise
             return VLLMClient(model_path)
+        case ClientType.API:
+            from clients.openai import APIClient
+
+            return APIClient(model_path)

     raise NotImplementedError(f"Client type {typ} is not supported.")
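The new role_name_replace hook lets a backend rename conversation roles before they are sent to the model; the API client uses it to turn GLM-4's observation role into the tool role expected by OpenAI-compatible endpoints. A rough usage sketch (the Conversation objects below are simplified stand-ins for the demo's own dataclass):

from dataclasses import dataclass

@dataclass
class FakeConversation:       # stand-in for conversation.Conversation
    role: str
    content: str

role_name_replace = {"observation": "tool"}

history = [
    FakeConversation("<|user|>", "What is the weather in Beijing?"),
    FakeConversation("<|assistant|>", 'get_weather\n{"city_name": "Beijing"}'),
    FakeConversation("<|observation|>", '{"temperature": 22}'),
]

for conv in history:
    role = conv.role.removeprefix("<|").removesuffix("|>")
    role = role_name_replace.get(role, role)   # "observation" becomes "tool"
    print({"role": role, "content": conv.content})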
composite_demo/src/clients/hf.py → demo/composite_demo/src/clients/hf.py
...
@@ -2,25 +2,23 @@
 HuggingFace client.
 """

-import threading
 from collections.abc import Generator
+from threading import Thread

 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-
 from client import Client, process_input, process_response
 from conversation import Conversation
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


 class HFClient(Client):
     def __init__(self, model_path: str):
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True,
+            model_path,
+            trust_remote_code=True,
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path,
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
             device_map="cuda",
         ).eval()
...
demo/composite_demo/src/clients/openai.py (new file)

"""
OpenAI API client.
"""

from collections.abc import Generator

from client import Client, process_input, process_response
from conversation import Conversation
from openai import OpenAI


def format_openai_tool(origin_tools):
    openai_tools = []
    for tool in origin_tools:
        openai_param = {}
        for param in tool["params"]:
            openai_param[param["name"]] = {}
        openai_tool = {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool["description"],
                "parameters": {
                    "type": "object",
                    "properties": {
                        param["name"]: {"type": param["type"], "description": param["description"]}
                        for param in tool["params"]
                    },
                    "required": [param["name"] for param in tool["params"] if param["required"]],
                },
            },
        }
        openai_tools.append(openai_tool)
    return openai_tools


class APIClient(Client):
    def __init__(self, model_path: str):
        base_url = "http://127.0.0.1:8000/v1/"
        self.client = OpenAI(api_key="EMPTY", base_url=base_url)
        self.use_stream = False
        self.role_name_replace = {"observation": "tool"}

    def generate_stream(
        self,
        tools: list[dict],
        history: list[Conversation],
        **parameters,
    ) -> Generator[tuple[str | dict, list[dict]]]:
        chat_history = process_input(history, "", role_name_replace=self.role_name_replace)
        # messages = process_input(history, '', role_name_replace=self.role_name_replace)
        openai_tools = format_openai_tool(tools)
        response = self.client.chat.completions.create(
            model="glm-4",
            messages=chat_history,
            tools=openai_tools,
            stream=self.use_stream,
            max_tokens=parameters["max_new_tokens"],
            temperature=parameters["temperature"],
            presence_penalty=1.2,
            top_p=parameters["top_p"],
            tool_choice="auto",
        )
        output = response.choices[0].message
        if output.tool_calls:
            glm4_output = output.tool_calls[0].function.name + "\n" + output.tool_calls[0].function.arguments
        else:
            glm4_output = output.content
        yield process_response(glm4_output, chat_history)
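format_openai_tool converts the demo's own tool-registry schema (name/description/params entries) into the OpenAI function-calling format. A small hedged illustration of that mapping, using a made-up registry entry in the same style:

# Illustrative input in the registry's schema (values invented for this example).
registry_tool = {
    "name": "get_weather",
    "description": "Get the current weather for `city_name`",
    "params": [
        {"name": "city_name", "type": "str", "description": "The name of the city to be queried", "required": True},
    ],
}

# format_openai_tool([registry_tool]) would produce roughly:
openai_style = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for `city_name`",
        "parameters": {
            "type": "object",
            "properties": {"city_name": {"type": "str", "description": "The name of the city to be queried"}},
            "required": ["city_name"],
        },
    },
}]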
composite_demo/src/clients/vllm.py → demo/composite_demo/src/clients/vllm.py
...
@@ -8,23 +8,19 @@ installation guide before running this client.
 import time
 from collections.abc import Generator

-from transformers import AutoTokenizer
-from vllm import SamplingParams, LLMEngine, EngineArgs
-
 from client import Client, process_input, process_response
 from conversation import Conversation
+from transformers import AutoTokenizer
+from vllm import EngineArgs, LLMEngine, SamplingParams


 class VLLMClient(Client):
     def __init__(self, model_path: str):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True
-        )
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
         self.engine_args = EngineArgs(
             model=model_path,
             tensor_parallel_size=1,
             dtype="bfloat16",  # torch.bfloat16 is needed.
             trust_remote_code=True,
             gpu_memory_utilization=0.6,
             enforce_eager=True,
             worker_use_ray=False,
...
@@ -35,29 +31,20 @@ class VLLMClient(Client):
         self,
         tools: list[dict],
         history: list[Conversation],
         **parameters
     ) -> Generator[tuple[str | dict, list[dict]]]:
         chat_history = process_input(history, tools)
-        model_inputs = self.tokenizer.apply_chat_template(
-            chat_history, add_generation_prompt=True, tokenize=False
-        )
+        model_inputs = self.tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, tokenize=False)
         parameters["max_tokens"] = parameters.pop("max_new_tokens")
         params_dict = {
             "n": 1,
             "best_of": 1,
             "top_p": 1,
             "top_k": -1,
             "use_beam_search": False,
             "length_penalty": 1,
             "early_stopping": False,
             "stop_token_ids": [151329, 151336, 151338],
             "ignore_eos": False,
             "logprobs": None,
             "prompt_logprobs": None,
         }
         params_dict.update(parameters)
         sampling_params = SamplingParams(**params_dict)
-        self.engine.add_request(
-            request_id=str(time.time()), inputs=model_inputs, params=sampling_params
-        )
+        self.engine.add_request(request_id=str(time.time()), inputs=model_inputs, params=sampling_params)
         while self.engine.has_unfinished_requests():
             request_outputs = self.engine.step()
             for request_output in request_outputs:
...
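The hard-coded stop_token_ids [151329, 151336, 151338] are GLM-4's end-of-text, <|user|>, and <|observation|> special tokens. A hedged cross-check that derives them from the tokenizer instead of hard-coding them (assuming the THUDM/glm-4-9b-chat tokenizer is available and exposes these tokens):

from transformers import AutoTokenizer

# Assumption: the GLM-4 chat tokenizer maps these special tokens to the IDs the demo hard-codes;
# this is only a sanity check, not an officially documented API.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
stop_token_ids = [
    tokenizer.convert_tokens_to_ids(tok)
    for tok in ("<|endoftext|>", "<|user|>", "<|observation|>")
]
print(stop_token_ids)  # expected to match [151329, 151336, 151338]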
composite_demo/src/conversation.py → demo/composite_demo/src/conversation.py
...
@@ -5,12 +5,11 @@ from datetime import datetime
 from enum import Enum, auto

 import streamlit as st
-from streamlit.delta_generator import DeltaGenerator
 from PIL.Image import Image
-
+from streamlit.delta_generator import DeltaGenerator
 from tools.browser import Quote, quotes

 QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")

 SELFCOG_PROMPT = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
...
@@ -30,7 +29,8 @@ def build_system_prompt(
 ):
     value = SELFCOG_PROMPT
     value += "\n\n" + datetime.now().strftime(DATE_PROMPT)
-    value += "\n\n# 可用工具"
+    if enabled_tools or functions:
+        value += "\n\n# 可用工具"
     contents = []
     for tool in enabled_tools:
         contents.append(f"\n\n## {tool}\n\n{TOOL_SYSTEM_PROMPTS[tool]}")
...
@@ -130,23 +130,23 @@ class Conversation:
             if self.role != Role.USER:
                 show_text = text
             else:
-                splitted = text.split('files uploaded.\n')
+                splitted = text.split("files uploaded.\n")
                 if len(splitted) == 1:
                     show_text = text
                 else:
                     # Show expander for document content
                     doc = splitted[0]
                     show_text = splitted[-1]
-                    expander = message.expander(f'File Content')
+                    expander = message.expander("File Content")
                     expander.markdown(doc)
             message.markdown(show_text)


 def postprocess_text(text: str, replace_quote: bool) -> str:
-    text = text.replace("\(", "$")
-    text = text.replace("\)", "$")
-    text = text.replace("\[", "$$")
-    text = text.replace("\]", "$$")
+    text = text.replace(r"\(", "$")
+    text = text.replace(r"\)", "$")
+    text = text.replace(r"\[", "$$")
+    text = text.replace(r"\]", "$$")
     text = text.replace("<|assistant|>", "")
     text = text.replace("<|observation|>", "")
     text = text.replace("<|system|>", "")
...
@@ -158,8 +158,6 @@ def postprocess_text(text: str, replace_quote: bool) -> str:
         for match in QUOTE_REGEX.finditer(text):
             quote_id = match.group(1)
             quote = quotes.get(quote_id, Quote("未找到引用内容", ""))
-            text = text.replace(
-                match.group(0), f" (来源:[{quote.title}]({quote.url})) "
-            )
+            text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ")
     return text.strip()
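postprocess_text rewrites the model's 【id†title】 citation markers into markdown source links using the quotes collected by the browser tool, and converts LaTeX-style \( \) / \[ \] delimiters into $ / $$ for Streamlit rendering. A small hedged sketch of the citation rewrite in isolation (the Quote class and the sample quotes dict are simplified stand-ins with invented data):

import re
from dataclasses import dataclass

QUOTE_REGEX = re.compile(r"【(\d+)†(.+?)】")

@dataclass
class Quote:            # simplified stand-in for tools.browser.Quote
    title: str
    url: str

quotes = {"0": Quote("Example Page", "https://example.com")}   # made-up sample data

text = "The GLM-4-0414 models support function calling【0†Example Page】."
for match in QUOTE_REGEX.finditer(text):
    quote = quotes.get(match.group(1), Quote("未找到引用内容", ""))
    text = text.replace(match.group(0), f" (来源:[{quote.title}]({quote.url})) ")
print(text)
# The GLM-4-0414 models support function calling (来源:[Example Page](https://example.com)) .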
composite_demo/src/main.py → demo/composite_demo/src/main.py
...
@@ -12,10 +12,6 @@ from io import BytesIO
 from uuid import uuid4

 import streamlit as st
-from streamlit.delta_generator import DeltaGenerator
-from PIL import Image
-
 from client import Client, ClientType, get_client
 from conversation import (
     FILE_TEMPLATE,
...
@@ -24,14 +20,17 @@ from conversation import (
     postprocess_text,
     response_to_str,
 )
+from PIL import Image
+from streamlit.delta_generator import DeltaGenerator
 from tools.tool_registry import dispatch_tool, get_tools
-from utils import extract_pdf, extract_docx, extract_pptx, extract_text
+from utils import extract_docx, extract_pdf, extract_pptx, extract_text

 CHAT_MODEL_PATH = os.environ.get("CHAT_MODEL_PATH", "THUDM/glm-4-9b-chat")
 VLM_MODEL_PATH = os.environ.get("VLM_MODEL_PATH", "THUDM/glm-4v-9b")

 USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
+USE_API = os.environ.get("USE_API", "0") == "1"


 class Mode(str, Enum):
...
@@ -104,6 +103,7 @@ def build_client(mode: Mode) -> Client:
         case Mode.ALL_TOOLS:
             st.session_state.top_k = 10
             typ = ClientType.VLLM if USE_VLLM else ClientType.HF
+            typ = ClientType.API if USE_API else typ
             return get_client(CHAT_MODEL_PATH, typ)
         case Mode.LONG_CTX:
             st.session_state.top_k = 10
...
@@ -169,9 +169,7 @@ if page == Mode.LONG_CTX:
                 content = extract_pptx(file_path)
             else:
                 content = extract_text(file_path)
-            uploaded_texts.append(
-                FILE_TEMPLATE.format(file_name=file_name, file_content=content)
-            )
+            uploaded_texts.append(FILE_TEMPLATE.format(file_name=file_name, file_content=content))
             os.remove(file_path)
         st.session_state.uploaded_texts = "\n\n".join(uploaded_texts)
         st.session_state.uploaded_file_nums = len(uploaded_files)
...
@@ -230,9 +228,7 @@ def main(prompt_text: str):
         # Append uploaded files
         uploaded_texts = st.session_state.get("uploaded_texts")
         if page == Mode.LONG_CTX and uploaded_texts and first_round:
-            meta_msg = "{} files uploaded.\n".format(
-                st.session_state.uploaded_file_nums
-            )
+            meta_msg = "{} files uploaded.\n".format(st.session_state.uploaded_file_nums)
             prompt_text = uploaded_texts + "\n\n\n" + meta_msg + prompt_text
             # Clear after first use
             st.session_state.files_uploaded = True
...
@@ -247,16 +243,12 @@ def main(prompt_text: str):
     append_conversation(Conversation(role, prompt_text, image=image), history)
     placeholder = st.container()
-    message_placeholder = placeholder.chat_message(
-        name="assistant", avatar="assistant"
-    )
+    message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
     markdown_placeholder = message_placeholder.empty()

     def add_new_block():
         nonlocal message_placeholder, markdown_placeholder
-        message_placeholder = placeholder.chat_message(
-            name="assistant", avatar="assistant"
-        )
+        message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
         markdown_placeholder = message_placeholder.empty()

     def commit_conversation(
...
@@ -301,33 +293,18 @@ def main(prompt_text: str):
                     history_len = len(chat_history)
                     last_response = response
                     replace_quote = chat_history[-1]["role"] == "assistant"
-                    markdown_placeholder.markdown(
-                        postprocess_text(
-                            str(response) + "●", replace_quote=replace_quote
-                        )
-                    )
+                    markdown_placeholder.markdown(postprocess_text(str(response) + "●", replace_quote=replace_quote))
                 else:
-                    metadata = (
-                        page == Mode.ALL_TOOLS
-                        and isinstance(response, dict)
-                        and response.get("name")
-                        or None
-                    )
+                    metadata = page == Mode.ALL_TOOLS and isinstance(response, dict) and response.get("name") or None
                     role = Role.TOOL if metadata else Role.ASSISTANT
-                    text = (
-                        response.get("content")
-                        if metadata
-                        else response_to_str(response)
-                    )
+                    text = response.get("content") if metadata else response_to_str(response)
                     commit_conversation(role, text, metadata)
                     if metadata:
                         add_new_block()
                         try:
                             with markdown_placeholder:
                                 with st.spinner(f"Calling tool {metadata}..."):
-                                    observations = dispatch_tool(
-                                        metadata, text, str(st.session_state.session_id)
-                                    )
+                                    observations = dispatch_tool(metadata, text, str(st.session_state.session_id))
                         except Exception as e:
                             traceback.print_exc()
                             st.error(f'Uncaught exception in `"{metadata}"`: {e}')
...
@@ -346,7 +323,7 @@ def main(prompt_text: str):
                         continue
                     else:
                         break
-        except Exception as e:
+        except Exception:
            traceback.print_exc()
            st.error(f"Uncaught exception: {traceback.format_exc()}")
    else:
...
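main.py now selects among three backends from environment variables: USE_VLLM picks the vLLM engine, USE_API the new OpenAI-compatible API client, and otherwise the HuggingFace client is used. A condensed hedged sketch of that selection logic (the string values stand in for client.ClientType members):

import os

# Mirrors the selection inside build_client() for the All Tools mode.
USE_VLLM = os.environ.get("USE_VLLM", "0") == "1"
USE_API = os.environ.get("USE_API", "0") == "1"

def pick_client_type() -> str:
    typ = "VLLM" if USE_VLLM else "HF"
    typ = "API" if USE_API else typ   # USE_API takes precedence when both are set
    return typ

print(pick_client_type())  # "HF" unless USE_VLLM=1 or USE_API=1 is exported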
composite_demo/src/tools/browser.py → demo/composite_demo/src/tools/browser.py
...
@@ -6,22 +6,26 @@ Simple browser tool.
 Please start the backend browser server according to the instructions in the README.
 """

-from pprint import pprint
 import re
+from dataclasses import dataclass
+from pprint import pprint

 import requests
 import streamlit as st
-from dataclasses import dataclass

 from .config import BROWSER_SERVER_URL
 from .interface import ToolObservation

 QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")


 @dataclass
 class Quote:
     title: str
     url: str


 # Quotes for displaying reference
 if "quotes" not in st.session_state:
     st.session_state.quotes = {}
...
@@ -31,18 +35,18 @@ quotes: dict[str, Quote] = st.session_state.quotes

 def map_response(response: dict) -> ToolObservation:
     # Save quotes for reference
-    print('===BROWSER_RESPONSE===')
+    print("===BROWSER_RESPONSE===")
     pprint(response)
     role_metadata = response.get("roleMetadata")
     metadata = response.get("metadata")

-    if role_metadata.split()[0] == 'quote_result' and metadata:
+    if role_metadata.split()[0] == "quote_result" and metadata:
         quote_id = QUOTE_REGEX.search(role_metadata.split()[1]).group(1)
-        quote: dict[str, str] = metadata['metadata_list'][0]
-        quotes[quote_id] = Quote(quote['title'], quote['url'])
-    elif role_metadata == 'browser_result' and metadata:
-        for i, quote in enumerate(metadata['metadata_list']):
-            quotes[str(i)] = Quote(quote['title'], quote['url'])
+        quote: dict[str, str] = metadata["metadata_list"][0]
+        quotes[quote_id] = Quote(quote["title"], quote["url"])
+    elif role_metadata == "browser_result" and metadata:
+        for i, quote in enumerate(metadata["metadata_list"]):
+            quotes[str(i)] = Quote(quote["title"], quote["url"])

     return ToolObservation(
         content_type=response.get("contentType"),
...
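map_response turns the browser server's JSON into a ToolObservation and collects quotes keyed by citation id. A hedged example of the kind of quote_result payload it consumes (all field values are invented, shaped only to match what the code reads):

# Invented sample payload shaped like what map_response() expects.
sample_response = {
    "contentType": "text/plain",
    "roleMetadata": "quote_result [3†Example]",
    "metadata": {"metadata_list": [{"title": "Example", "url": "https://example.com"}]},
}

# After map_response(sample_response), quotes["3"] would hold
# Quote(title="Example", url="https://example.com"), and the returned
# ToolObservation carries content_type "text/plain".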
composite_demo/src/tools/cogview.py → demo/composite_demo/src/tools/cogview.py
...
@@ -5,18 +5,21 @@ from zhipuai.types.image import GeneratedImage
 from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
 from .interface import ToolObservation


 @st.cache_resource
 def get_zhipu_client():
     return ZhipuAI(api_key=ZHIPU_AI_KEY)


 def map_response(img: GeneratedImage):
     return ToolObservation(
-        content_type='image',
-        text='CogView 已经生成并向用户展示了生成的图片。',
+        content_type="image",
+        text="CogView 已经生成并向用户展示了生成的图片。",
         image_url=img.url,
-        role_metadata='cogview_result'
+        role_metadata="cogview_result",
     )


 def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
     client = get_zhipu_client()
     response = client.images.generations(model=COGVIEW_MODEL, prompt=prompt).data
...
demo/composite_demo/src/tools/config.py (new file)

BROWSER_SERVER_URL = "http://localhost:3000"

IPYKERNEL = "glm-4-demo"

ZHIPU_AI_KEY = ""
COGVIEW_MODEL = "cogview-3"
composite_demo/src/tools/interface.py → demo/composite_demo/src/tools/interface.py

 from dataclasses import dataclass
 from typing import Any


 @dataclass
 class ToolObservation:
     content_type: str
...
composite_demo/src/tools/python.py → demo/composite_demo/src/tools/python.py

-from pprint import pprint
 import queue
 import re
+from pprint import pprint
 from subprocess import PIPE
 from typing import Literal
...
@@ -10,19 +10,22 @@ import streamlit as st
 from .config import IPYKERNEL
 from .interface import ToolObservation

-ANSI_ESCAPE = re.compile(r'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]')
-CODE = re.compile(r'```([^\n]*)\n(.*?)```')
+ANSI_ESCAPE = re.compile(r"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]")
+CODE = re.compile(r"```([^\n]*)\n(.*?)```")


 class CodeKernel:
-    def __init__(self, kernel_name='kernel', kernel_id=None, kernel_config_path="", python_path=None, ipython_path=None, init_file_path="./startup.py", verbose=1):
+    def __init__(
+        self,
+        kernel_name="kernel",
+        kernel_id=None,
+        kernel_config_path="",
+        python_path=None,
+        ipython_path=None,
+        init_file_path="./startup.py",
+        verbose=1,
+    ):
         self.kernel_name = kernel_name
         self.kernel_id = kernel_id
         self.kernel_config_path = kernel_config_path
...
@@ -37,19 +40,16 @@ class CodeKernel:
             env = {"PATH": self.python_path + ":$PATH", "PYTHONPATH": self.python_path}

         # Initialize the backend kernel
-        self.kernel_manager = jupyter_client.KernelManager(kernel_name=IPYKERNEL,
-                                                           connection_file=self.kernel_config_path,
-                                                           exec_files=[self.init_file_path],
-                                                           env=env)
+        self.kernel_manager = jupyter_client.KernelManager(
+            kernel_name=IPYKERNEL, connection_file=self.kernel_config_path, exec_files=[self.init_file_path], env=env
+        )
         if self.kernel_config_path:
             self.kernel_manager.load_connection_file()
             self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
-            print("Backend kernel started with the configuration: {}".format(
-                self.kernel_config_path))
+            print("Backend kernel started with the configuration: {}".format(self.kernel_config_path))
         else:
             self.kernel_manager.start_kernel(stdout=PIPE, stderr=PIPE)
-            print("Backend kernel started with the configuration: {}".format(
-                self.kernel_manager.connection_file))
+            print("Backend kernel started with the configuration: {}".format(self.kernel_manager.connection_file))

         if verbose:
             pprint(self.kernel_manager.get_connection_info())
...
@@ -64,13 +64,13 @@ class CodeKernel:
         self.kernel.execute(code)
         try:
             shell_msg = self.kernel.get_shell_msg(timeout=30)
-            io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
+            io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
             while True:
                 msg_out = io_msg_content
                 ### Poll the message
                 try:
-                    io_msg_content = self.kernel.get_iopub_msg(timeout=30)['content']
-                    if 'execution_state' in io_msg_content and io_msg_content['execution_state'] == 'idle':
+                    io_msg_content = self.kernel.get_iopub_msg(timeout=30)["content"]
+                    if "execution_state" in io_msg_content and io_msg_content["execution_state"] == "idle":
                         break
                 except queue.Empty:
                     break
...
@@ -100,12 +100,12 @@ class CodeKernel:
         return shell_msg

     def get_error_msg(self, msg, verbose=False) -> str | None:
-        if msg['content']['status'] == 'error':
+        if msg["content"]["status"] == "error":
             try:
-                error_msg = msg['content']['traceback']
+                error_msg = msg["content"]["traceback"]
             except:
                 try:
-                    error_msg = msg['content']['traceback'][-1].strip()
+                    error_msg = msg["content"]["traceback"][-1].strip()
                 except:
                     error_msg = "Traceback Error"
             if verbose:
...
@@ -114,12 +114,12 @@ class CodeKernel:
         return None

     def check_msg(self, msg, verbose=False):
-        status = msg['content']['status']
-        if status == 'ok':
+        status = msg["content"]["status"]
+        if status == "ok":
             if verbose:
                 print("Execution succeeded.")
-        elif status == 'error':
-            for line in msg['content']['traceback']:
+        elif status == "error":
+            for line in msg["content"]["traceback"]:
                 if verbose:
                     print(line)
...
@@ -144,17 +144,17 @@ class CodeKernel:
     def is_alive(self):
         return self.kernel.is_alive()


 def clean_ansi_codes(input_string):
-    return ANSI_ESCAPE.sub('', input_string)
+    return ANSI_ESCAPE.sub("", input_string)


 def extract_code(text: str) -> str:
     matches = CODE.findall(text, re.DOTALL)
     return matches[-1][1]


-def execute(code: str, kernel: CodeKernel) -> tuple[Literal['text', 'image'] | None, str]:
+def execute(code: str, kernel: CodeKernel) -> tuple[Literal["text", "image"] | None, str]:
     res = ""
     res_type = None
     code = code.replace("<|observation|>", "")
...
@@ -164,37 +164,38 @@ def execute(
     code = code.replace("<|system|>", "")
     msg, output = kernel.execute(code)

-    if msg['metadata']['status'] == "timeout":
-        return res_type, 'Timed out'
-    elif msg['metadata']['status'] == 'error':
-        return res_type, clean_ansi_codes('\n'.join(kernel.get_error_msg(msg, verbose=True)))
+    if msg["metadata"]["status"] == "timeout":
+        return res_type, "Timed out"
+    elif msg["metadata"]["status"] == "error":
+        return res_type, clean_ansi_codes("\n".join(kernel.get_error_msg(msg, verbose=True)))

-    if 'text' in output:
+    if "text" in output:
         res_type = "text"
-        res = output['text']
-    elif 'data' in output:
-        for key in output['data']:
-            if 'text/plain' in key:
+        res = output["text"]
+    elif "data" in output:
+        for key in output["data"]:
+            if "text/plain" in key:
                 res_type = "text"
-                res = output['data'][key]
-            elif 'image/png' in key:
+                res = output["data"][key]
+            elif "image/png" in key:
                 res_type = "image"
-                res = output['data'][key]
+                res = output["data"][key]
                 break

     return res_type, res


 @st.cache_resource
 def get_kernel() -> CodeKernel:
     return CodeKernel()


 def tool_call(code: str, session_id: str) -> list[ToolObservation]:
     kernel = get_kernel()
     res_type, res = execute(code, kernel)

     # Convert base64 to data uri
-    text = '[Image]' if res_type == 'image' else res
-    image = f'data:image/png;base64,{res}' if res_type == 'image' else None
+    text = "[Image]" if res_type == "image" else res
+    image = f"data:image/png;base64,{res}" if res_type == "image" else None
     return [ToolObservation(res_type, text, image)]
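CodeKernel drives a Jupyter kernel through jupyter_client: it starts the kernel named by IPYKERNEL ("glm-4-demo"), sends code with execute(), and polls IOPub messages until the kernel reports it is idle. A stripped-down hedged sketch of that flow using jupyter_client's blocking client (it assumes a kernelspec named glm-4-demo has already been installed):

from jupyter_client import KernelManager

# Assumes the kernelspec exists, e.g. installed with:
#   python -m ipykernel install --user --name glm-4-demo
km = KernelManager(kernel_name="glm-4-demo")
km.start_kernel()
kc = km.client()
kc.start_channels()

kc.execute("print(1 + 1)")
# Drain IOPub messages until the kernel goes idle again, echoing any stdout.
while True:
    msg = kc.get_iopub_msg(timeout=30)
    content = msg["content"]
    if content.get("name") == "stdout":
        print("kernel output:", content["text"])
    if content.get("execution_state") == "idle":
        break

kc.stop_channels()
km.shutdown_kernel()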
composite_demo/src/tools/tool_registry.py → demo/composite_demo/src/tools/tool_registry.py
...
@@ -4,22 +4,21 @@ This code provides extended functionality to the model, enabling it to call and
 through defined interfaces.
 """

-from collections.abc import Callable
 import copy
 import inspect
 import json
 from pprint import pformat
+import subprocess
 import traceback
+from collections.abc import Callable
 from types import GenericAlias
-from typing import get_origin, Annotated
-
-import subprocess
-
-from .interface import ToolObservation
+from typing import Annotated, get_origin

 from .browser import tool_call as browser
 from .cogview import tool_call as cogview
+from .interface import ToolObservation
 from .python import tool_call as python

 ALL_TOOLS = {
     "simple_browser": browser,
     "python": python,
...
@@ -73,8 +72,8 @@ def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObservation]:
     # Dispatch predefined tools
     if tool_name in ALL_TOOLS:
         return ALL_TOOLS[tool_name](code, session_id)

-    code = code.strip().rstrip('<|observation|>').strip()
+    code = code.strip().rstrip("<|observation|>").strip()

     # Dispatch custom tools
     try:
...
@@ -105,8 +104,8 @@ def get_tools() -> list[dict]:

 @register_tool
 def random_number_generator(
-        seed: Annotated[int, "The random seed used by the generator", True],
-        range: Annotated[tuple[int, int], "The range of the generated numbers", True],
+    seed: Annotated[int, "The random seed used by the generator", True],
+    range: Annotated[tuple[int, int], "The range of the generated numbers", True],
 ) -> int:
     """
     Generates a random number x, s.t. range[0] <= x < range[1]
...
@@ -125,7 +124,7 @@ def random_number_generator(

 @register_tool
 def get_weather(
-        city_name: Annotated[str, "The name of the city to be queried", True],
+    city_name: Annotated[str, "The name of the city to be queried", True],
 ) -> str:
     """
     Get the current weather for `city_name`
...
@@ -153,16 +152,14 @@ def get_weather(
     except:
         import traceback

-        ret = (
-            "Error encountered while fetching weather data!\n" + traceback.format_exc()
-        )
+        ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()

     return str(ret)


 @register_tool
 def get_shell(
-        query: Annotated[str, "The command should run in Linux shell", True],
+    query: Annotated[str, "The command should run in Linux shell", True],
 ) -> str:
     """
     Use shell to run command
...
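Custom tools are plain functions whose parameters are described with Annotated[type, description, required] and exposed through the register_tool decorator, which derives the model-facing schema from the signature and docstring. A hedged sketch of adding one more tool in the same style, as it would sit inside tool_registry.py where register_tool is defined (the function name and behaviour are invented for illustration):

# Hypothetical extra tool, written in the registry's Annotated style.
@register_tool
def get_exchange_rate(
    base: Annotated[str, "The three-letter code of the base currency", True],
    target: Annotated[str, "The three-letter code of the target currency", True],
) -> str:
    """
    Look up the exchange rate from `base` to `target`.
    """
    # A real implementation would call an FX API; this stub keeps the example self-contained.
    rates = {("USD", "CNY"): 7.2}
    return str(rates.get((base.upper(), target.upper()), "unknown"))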