Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
GLM-4_pytorch
Commits
67ca83cf
Commit
67ca83cf
authored
Apr 17, 2025
by
Rayyyyy
Browse files
Support GLM-4-0414
parent
78ba9d16
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
239 additions
and
163 deletions
+239
-163
demo/composite_demo/browser/pnpm-lock.yaml
demo/composite_demo/browser/pnpm-lock.yaml
+0
-0
demo/composite_demo/browser/src/browser.ts
demo/composite_demo/browser/src/browser.ts
+11
-4
demo/composite_demo/browser/src/config.ts
demo/composite_demo/browser/src/config.ts
+9
-0
demo/composite_demo/browser/src/server.ts
demo/composite_demo/browser/src/server.ts
+1
-1
demo/composite_demo/browser/src/types.ts
demo/composite_demo/browser/src/types.ts
+0
-0
demo/composite_demo/browser/src/utils.ts
demo/composite_demo/browser/src/utils.ts
+1
-1
demo/composite_demo/browser/tsconfig.json
demo/composite_demo/browser/tsconfig.json
+0
-0
demo/composite_demo/requirements.txt
demo/composite_demo/requirements.txt
+15
-0
demo/composite_demo/src/client.py
demo/composite_demo/src/client.py
+10
-6
demo/composite_demo/src/clients/hf.py
demo/composite_demo/src/clients/hf.py
+3
-5
demo/composite_demo/src/clients/openai.py
demo/composite_demo/src/clients/openai.py
+69
-0
demo/composite_demo/src/clients/vllm.py
demo/composite_demo/src/clients/vllm.py
+5
-18
demo/composite_demo/src/conversation.py
demo/composite_demo/src/conversation.py
+11
-13
demo/composite_demo/src/main.py
demo/composite_demo/src/main.py
+14
-37
demo/composite_demo/src/tools/browser.py
demo/composite_demo/src/tools/browser.py
+14
-10
demo/composite_demo/src/tools/cogview.py
demo/composite_demo/src/tools/cogview.py
+6
-3
demo/composite_demo/src/tools/config.py
demo/composite_demo/src/tools/config.py
+6
-0
demo/composite_demo/src/tools/interface.py
demo/composite_demo/src/tools/interface.py
+1
-0
demo/composite_demo/src/tools/python.py
demo/composite_demo/src/tools/python.py
+51
-50
demo/composite_demo/src/tools/tool_registry.py
demo/composite_demo/src/tools/tool_registry.py
+12
-15
No files found.
composite_demo/browser/pnpm-lock.yaml
→
demo/
composite_demo/browser/pnpm-lock.yaml
View file @
67ca83cf
File moved
composite_demo/browser/src/browser.ts
→
demo/
composite_demo/browser/src/browser.ts
View file @
67ca83cf
...
@@ -172,13 +172,20 @@ abstract class BaseBrowser {
...
@@ -172,13 +172,20 @@ abstract class BaseBrowser {
logger
.
debug
(
`Searching for:
${
query
}
`
);
logger
.
debug
(
`Searching for:
${
query
}
`
);
const
search
=
new
URLSearchParams
({
q
:
query
});
const
search
=
new
URLSearchParams
({
q
:
query
});
recency_days
>
0
&&
search
.
append
(
'
recency_days
'
,
recency_days
.
toString
());
recency_days
>
0
&&
search
.
append
(
'
recency_days
'
,
recency_days
.
toString
());
if
(
config
.
CUSTOM_CONFIG_ID
)
{
search
.
append
(
'
customconfig
'
,
config
.
CUSTOM_CONFIG_ID
.
toString
());
}
const
url
=
`
${
config
.
BING_SEARCH_API_URL
}
/search?
${
search
.
toString
()}
`
;
console
.
log
(
'
Full URL:
'
,
url
);
// 输出完整的 URL查看是否正确
return
withTimeout
(
return
withTimeout
(
config
.
BROWSER_TIMEOUT
,
config
.
BROWSER_TIMEOUT
,
fetch
(
`
${
config
.
BING_SEARCH_API_URL
}
/search?
${
search
.
toString
()}
`
,
{
fetch
(
url
,
{
headers
:
{
headers
:
{
'
Ocp-Apim-Subscription-Key
'
:
config
.
BING_SEARCH_API_KEY
,
'
Ocp-Apim-Subscription-Key
'
:
config
.
BING_SEARCH_API_KEY
,
}
}
}).
then
(
})
.
then
(
res
=>
res
=>
res
.
json
()
as
Promise
<
{
res
.
json
()
as
Promise
<
{
queryContext
:
{
queryContext
:
{
...
@@ -255,11 +262,11 @@ abstract class BaseBrowser {
...
@@ -255,11 +262,11 @@ abstract class BaseBrowser {
}
}
})
})
.
catch
(
err
=>
{
.
catch
(
err
=>
{
logger
.
error
(
err
.
message
);
logger
.
error
(
`搜索请求失败:
${
query
}
,错误信息:
${
err
.
message
}
`
);
if
(
err
.
code
===
'
ECONNABORTED
'
)
{
if
(
err
.
code
===
'
ECONNABORTED
'
)
{
throw
new
Error
(
`Timeout while executing search for:
${
query
}
`
);
throw
new
Error
(
`Timeout while executing search for:
${
query
}
`
);
}
}
throw
new
Error
(
`
Network or server error occurred
`
);
throw
new
Error
(
`
网络或服务器发生错误,请检查URL:
${
url
}
`
);
});
});
},
},
open_url
:
(
url
:
string
)
=>
{
open_url
:
(
url
:
string
)
=>
{
...
...
composite_demo/browser/src/config.ts
→
demo/
composite_demo/browser/src/config.ts
View file @
67ca83cf
export
default
{
export
default
{
LOG_LEVEL
:
'
debug
'
,
LOG_LEVEL
:
'
debug
'
,
BROWSER_TIMEOUT
:
10000
,
BROWSER_TIMEOUT
:
10000
,
BING_SEARCH_API_URL
:
'
https://api.bing.microsoft.com/
'
,
BING_SEARCH_API_URL
:
'
https://api.bing.microsoft.com/
v7.0/custom/
'
,
BING_SEARCH_API_KEY
:
''
,
BING_SEARCH_API_KEY
:
'
YOUR_BING_SEARCH_API_KEY
'
,
CUSTOM_CONFIG_ID
:
'
YOUR_CUSTOM_CONFIG_ID
'
,
//将您的Custom Configuration ID放在此处
HOST
:
'
localhost
'
,
HOST
:
'
localhost
'
,
PORT
:
3000
,
PORT
:
3000
,
};
};
\ No newline at end of file
composite_demo/browser/src/server.ts
→
demo/
composite_demo/browser/src/server.ts
View file @
67ca83cf
...
@@ -20,7 +20,7 @@ app.post('/', async (req: Request, res: Response) => {
...
@@ -20,7 +20,7 @@ app.post('/', async (req: Request, res: Response) => {
}
=
req
.
body
;
}
=
req
.
body
;
logger
.
info
(
`session_id:
${
session_id
}
`
);
logger
.
info
(
`session_id:
${
session_id
}
`
);
logger
.
info
(
`action:
${
action
}
`
);
logger
.
info
(
`action:
${
action
}
`
);
if
(
!
session_history
[
session_id
])
{
if
(
!
session_history
[
session_id
])
{
session_history
[
session_id
]
=
new
SimpleBrowser
();
session_history
[
session_id
]
=
new
SimpleBrowser
();
}
}
...
...
composite_demo/browser/src/types.ts
→
demo/
composite_demo/browser/src/types.ts
View file @
67ca83cf
File moved
composite_demo/browser/src/utils.ts
→
demo/
composite_demo/browser/src/utils.ts
View file @
67ca83cf
...
@@ -53,4 +53,4 @@ export const withTimeout = <T>(
...
@@ -53,4 +53,4 @@ export const withTimeout = <T>(
setTimeout
(()
=>
reject
(
new
TimeoutError
()),
millis
)
setTimeout
(()
=>
reject
(
new
TimeoutError
()),
millis
)
);
);
return
Promise
.
race
([
promiseWithTime
(
promise
),
timeout
]);
return
Promise
.
race
([
promiseWithTime
(
promise
),
timeout
]);
};
};
\ No newline at end of file
composite_demo/browser/tsconfig.json
→
demo/
composite_demo/browser/tsconfig.json
View file @
67ca83cf
File moved
demo/composite_demo/requirements.txt
0 → 100644
View file @
67ca83cf
# Please install the requirments.txt in inference first!
ipykernel>=6.26.0
ipython>=8.18.1
jupyter_client>=8.6.0
langchain>=0.2.12
langchain-community>=0.2.11
matplotlib>=3.9.1
pymupdf>=1.24.9
python-docx>=1.1.2
python-pptx>=0.6.23
pyyaml>=6.0.1
requests>=2.31.0
streamlit>=1.37.1
zhipuai>=2.1.4
composite_demo/src/client.py
→
demo/
composite_demo/src/client.py
View file @
67ca83cf
...
@@ -13,7 +13,6 @@ from enum import Enum, auto
...
@@ -13,7 +13,6 @@ from enum import Enum, auto
from
typing
import
Protocol
from
typing
import
Protocol
import
streamlit
as
st
import
streamlit
as
st
from
conversation
import
Conversation
,
build_system_prompt
from
conversation
import
Conversation
,
build_system_prompt
from
tools.tool_registry
import
ALL_TOOLS
from
tools.tool_registry
import
ALL_TOOLS
...
@@ -21,6 +20,7 @@ from tools.tool_registry import ALL_TOOLS
...
@@ -21,6 +20,7 @@ from tools.tool_registry import ALL_TOOLS
class
ClientType
(
Enum
):
class
ClientType
(
Enum
):
HF
=
auto
()
HF
=
auto
()
VLLM
=
auto
()
VLLM
=
auto
()
API
=
auto
()
class
Client
(
Protocol
):
class
Client
(
Protocol
):
...
@@ -34,15 +34,15 @@ class Client(Protocol):
...
@@ -34,15 +34,15 @@ class Client(Protocol):
)
->
Generator
[
tuple
[
str
|
dict
,
list
[
dict
]]]:
...
)
->
Generator
[
tuple
[
str
|
dict
,
list
[
dict
]]]:
...
def
process_input
(
history
:
list
[
dict
],
tools
:
list
[
dict
])
->
list
[
dict
]:
def
process_input
(
history
:
list
[
dict
],
tools
:
list
[
dict
]
,
role_name_replace
:
dict
=
None
)
->
list
[
dict
]:
chat_history
=
[]
chat_history
=
[]
if
len
(
tools
)
>
0
:
# if len(tools) > 0:
chat_history
.
append
(
chat_history
.
append
({
"role"
:
"system"
,
"content"
:
build_system_prompt
(
list
(
ALL_TOOLS
),
tools
)})
{
"role"
:
"system"
,
"content"
:
build_system_prompt
(
list
(
ALL_TOOLS
),
tools
)}
)
for
conversation
in
history
:
for
conversation
in
history
:
role
=
str
(
conversation
.
role
).
removeprefix
(
"<|"
).
removesuffix
(
"|>"
)
role
=
str
(
conversation
.
role
).
removeprefix
(
"<|"
).
removesuffix
(
"|>"
)
if
role_name_replace
:
role
=
role_name_replace
.
get
(
role
,
role
)
item
=
{
item
=
{
"role"
:
role
,
"role"
:
role
,
"content"
:
conversation
.
content
,
"content"
:
conversation
.
content
,
...
@@ -94,5 +94,9 @@ def get_client(model_path, typ: ClientType) -> Client:
...
@@ -94,5 +94,9 @@ def get_client(model_path, typ: ClientType) -> Client:
e
.
msg
+=
"; did you forget to install vLLM?"
e
.
msg
+=
"; did you forget to install vLLM?"
raise
raise
return
VLLMClient
(
model_path
)
return
VLLMClient
(
model_path
)
case
ClientType
.
API
:
from
clients.openai
import
APIClient
return
APIClient
(
model_path
)
raise
NotImplementedError
(
f
"Client type
{
typ
}
is not supported."
)
raise
NotImplementedError
(
f
"Client type
{
typ
}
is not supported."
)
composite_demo/src/clients/hf.py
→
demo/
composite_demo/src/clients/hf.py
View file @
67ca83cf
...
@@ -2,25 +2,23 @@
...
@@ -2,25 +2,23 @@
HuggingFace client.
HuggingFace client.
"""
"""
import
threading
from
collections.abc
import
Generator
from
collections.abc
import
Generator
from
threading
import
Thread
from
threading
import
Thread
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
AutoModelForCausalLM
,
TextIteratorStreamer
from
client
import
Client
,
process_input
,
process_response
from
client
import
Client
,
process_input
,
process_response
from
conversation
import
Conversation
from
conversation
import
Conversation
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
TextIteratorStreamer
class
HFClient
(
Client
):
class
HFClient
(
Client
):
def
__init__
(
self
,
model_path
:
str
):
def
__init__
(
self
,
model_path
:
str
):
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
,
model_path
,
trust_remote_code
=
True
,
)
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_path
,
model_path
,
trust_remote_code
=
True
,
torch_dtype
=
torch
.
bfloat16
,
torch_dtype
=
torch
.
bfloat16
,
device_map
=
"cuda"
,
device_map
=
"cuda"
,
).
eval
()
).
eval
()
...
...
demo/composite_demo/src/clients/openai.py
0 → 100644
View file @
67ca83cf
"""
OpenAI API client.
"""
from
collections.abc
import
Generator
from
client
import
Client
,
process_input
,
process_response
from
conversation
import
Conversation
from
openai
import
OpenAI
def
format_openai_tool
(
origin_tools
):
openai_tools
=
[]
for
tool
in
origin_tools
:
openai_param
=
{}
for
param
in
tool
[
"params"
]:
openai_param
[
param
[
"name"
]]
=
{}
openai_tool
=
{
"type"
:
"function"
,
"function"
:
{
"name"
:
tool
[
"name"
],
"description"
:
tool
[
"description"
],
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
param
[
"name"
]:
{
"type"
:
param
[
"type"
],
"description"
:
param
[
"description"
]}
for
param
in
tool
[
"params"
]
},
"required"
:
[
param
[
"name"
]
for
param
in
tool
[
"params"
]
if
param
[
"required"
]],
},
},
}
openai_tools
.
append
(
openai_tool
)
return
openai_tools
class
APIClient
(
Client
):
def
__init__
(
self
,
model_path
:
str
):
base_url
=
"http://127.0.0.1:8000/v1/"
self
.
client
=
OpenAI
(
api_key
=
"EMPTY"
,
base_url
=
base_url
)
self
.
use_stream
=
False
self
.
role_name_replace
=
{
"observation"
:
"tool"
}
def
generate_stream
(
self
,
tools
:
list
[
dict
],
history
:
list
[
Conversation
],
**
parameters
,
)
->
Generator
[
tuple
[
str
|
dict
,
list
[
dict
]]]:
chat_history
=
process_input
(
history
,
""
,
role_name_replace
=
self
.
role_name_replace
)
# messages = process_input(history, '', role_name_replace=self.role_name_replace)
openai_tools
=
format_openai_tool
(
tools
)
response
=
self
.
client
.
chat
.
completions
.
create
(
model
=
"glm-4"
,
messages
=
chat_history
,
tools
=
openai_tools
,
stream
=
self
.
use_stream
,
max_tokens
=
parameters
[
"max_new_tokens"
],
temperature
=
parameters
[
"temperature"
],
presence_penalty
=
1.2
,
top_p
=
parameters
[
"top_p"
],
tool_choice
=
"auto"
,
)
output
=
response
.
choices
[
0
].
message
if
output
.
tool_calls
:
glm4_output
=
output
.
tool_calls
[
0
].
function
.
name
+
"
\n
"
+
output
.
tool_calls
[
0
].
function
.
arguments
else
:
glm4_output
=
output
.
content
yield
process_response
(
glm4_output
,
chat_history
)
composite_demo/src/clients/vllm.py
→
demo/
composite_demo/src/clients/vllm.py
View file @
67ca83cf
...
@@ -8,23 +8,19 @@ installation guide before running this client.
...
@@ -8,23 +8,19 @@ installation guide before running this client.
import
time
import
time
from
collections.abc
import
Generator
from
collections.abc
import
Generator
from
transformers
import
AutoTokenizer
from
vllm
import
SamplingParams
,
LLMEngine
,
EngineArgs
from
client
import
Client
,
process_input
,
process_response
from
client
import
Client
,
process_input
,
process_response
from
conversation
import
Conversation
from
conversation
import
Conversation
from
transformers
import
AutoTokenizer
from
vllm
import
EngineArgs
,
LLMEngine
,
SamplingParams
class
VLLMClient
(
Client
):
class
VLLMClient
(
Client
):
def
__init__
(
self
,
model_path
:
str
):
def
__init__
(
self
,
model_path
:
str
):
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
model_path
,
trust_remote_code
=
True
)
self
.
engine_args
=
EngineArgs
(
self
.
engine_args
=
EngineArgs
(
model
=
model_path
,
model
=
model_path
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
dtype
=
"bfloat16"
,
# torch.bfloat16 is needed.
dtype
=
"bfloat16"
,
# torch.bfloat16 is needed.
trust_remote_code
=
True
,
gpu_memory_utilization
=
0.6
,
gpu_memory_utilization
=
0.6
,
enforce_eager
=
True
,
enforce_eager
=
True
,
worker_use_ray
=
False
,
worker_use_ray
=
False
,
...
@@ -35,29 +31,20 @@ class VLLMClient(Client):
...
@@ -35,29 +31,20 @@ class VLLMClient(Client):
self
,
tools
:
list
[
dict
],
history
:
list
[
Conversation
],
**
parameters
self
,
tools
:
list
[
dict
],
history
:
list
[
Conversation
],
**
parameters
)
->
Generator
[
tuple
[
str
|
dict
,
list
[
dict
]]]:
)
->
Generator
[
tuple
[
str
|
dict
,
list
[
dict
]]]:
chat_history
=
process_input
(
history
,
tools
)
chat_history
=
process_input
(
history
,
tools
)
model_inputs
=
self
.
tokenizer
.
apply_chat_template
(
model_inputs
=
self
.
tokenizer
.
apply_chat_template
(
chat_history
,
add_generation_prompt
=
True
,
tokenize
=
False
)
chat_history
,
add_generation_prompt
=
True
,
tokenize
=
False
)
parameters
[
"max_tokens"
]
=
parameters
.
pop
(
"max_new_tokens"
)
parameters
[
"max_tokens"
]
=
parameters
.
pop
(
"max_new_tokens"
)
params_dict
=
{
params_dict
=
{
"n"
:
1
,
"n"
:
1
,
"best_of"
:
1
,
"best_of"
:
1
,
"top_p"
:
1
,
"top_p"
:
1
,
"top_k"
:
-
1
,
"top_k"
:
-
1
,
"use_beam_search"
:
False
,
"length_penalty"
:
1
,
"length_penalty"
:
1
,
"early_stopping"
:
False
,
"stop_token_ids"
:
[
151329
,
151336
,
151338
],
"stop_token_ids"
:
[
151329
,
151336
,
151338
],
"ignore_eos"
:
False
,
"logprobs"
:
None
,
"prompt_logprobs"
:
None
,
}
}
params_dict
.
update
(
parameters
)
params_dict
.
update
(
parameters
)
sampling_params
=
SamplingParams
(
**
params_dict
)
sampling_params
=
SamplingParams
(
**
params_dict
)
self
.
engine
.
add_request
(
self
.
engine
.
add_request
(
request_id
=
str
(
time
.
time
()),
inputs
=
model_inputs
,
params
=
sampling_params
)
request_id
=
str
(
time
.
time
()),
inputs
=
model_inputs
,
params
=
sampling_params
)
while
self
.
engine
.
has_unfinished_requests
():
while
self
.
engine
.
has_unfinished_requests
():
request_outputs
=
self
.
engine
.
step
()
request_outputs
=
self
.
engine
.
step
()
for
request_output
in
request_outputs
:
for
request_output
in
request_outputs
:
...
...
composite_demo/src/conversation.py
→
demo/
composite_demo/src/conversation.py
View file @
67ca83cf
...
@@ -5,12 +5,11 @@ from datetime import datetime
...
@@ -5,12 +5,11 @@ from datetime import datetime
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
import
streamlit
as
st
import
streamlit
as
st
from
streamlit.delta_generator
import
DeltaGenerator
from
PIL.Image
import
Image
from
PIL.Image
import
Image
from
streamlit.delta_generator
import
DeltaGenerator
from
tools.browser
import
Quote
,
quotes
from
tools.browser
import
Quote
,
quotes
QUOTE_REGEX
=
re
.
compile
(
r
"【(\d+)†(.+?)】"
)
QUOTE_REGEX
=
re
.
compile
(
r
"【(\d+)†(.+?)】"
)
SELFCOG_PROMPT
=
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
SELFCOG_PROMPT
=
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。"
...
@@ -30,7 +29,8 @@ def build_system_prompt(
...
@@ -30,7 +29,8 @@ def build_system_prompt(
):
):
value
=
SELFCOG_PROMPT
value
=
SELFCOG_PROMPT
value
+=
"
\n\n
"
+
datetime
.
now
().
strftime
(
DATE_PROMPT
)
value
+=
"
\n\n
"
+
datetime
.
now
().
strftime
(
DATE_PROMPT
)
value
+=
"
\n\n
# 可用工具"
if
enabled_tools
or
functions
:
value
+=
"
\n\n
# 可用工具"
contents
=
[]
contents
=
[]
for
tool
in
enabled_tools
:
for
tool
in
enabled_tools
:
contents
.
append
(
f
"
\n\n
##
{
tool
}
\n\n
{
TOOL_SYSTEM_PROMPTS
[
tool
]
}
"
)
contents
.
append
(
f
"
\n\n
##
{
tool
}
\n\n
{
TOOL_SYSTEM_PROMPTS
[
tool
]
}
"
)
...
@@ -130,23 +130,23 @@ class Conversation:
...
@@ -130,23 +130,23 @@ class Conversation:
if
self
.
role
!=
Role
.
USER
:
if
self
.
role
!=
Role
.
USER
:
show_text
=
text
show_text
=
text
else
:
else
:
splitted
=
text
.
split
(
'
files uploaded.
\n
'
)
splitted
=
text
.
split
(
"
files uploaded.
\n
"
)
if
len
(
splitted
)
==
1
:
if
len
(
splitted
)
==
1
:
show_text
=
text
show_text
=
text
else
:
else
:
# Show expander for document content
# Show expander for document content
doc
=
splitted
[
0
]
doc
=
splitted
[
0
]
show_text
=
splitted
[
-
1
]
show_text
=
splitted
[
-
1
]
expander
=
message
.
expander
(
f
'
File Content
'
)
expander
=
message
.
expander
(
"
File Content
"
)
expander
.
markdown
(
doc
)
expander
.
markdown
(
doc
)
message
.
markdown
(
show_text
)
message
.
markdown
(
show_text
)
def
postprocess_text
(
text
:
str
,
replace_quote
:
bool
)
->
str
:
def
postprocess_text
(
text
:
str
,
replace_quote
:
bool
)
->
str
:
text
=
text
.
replace
(
"\("
,
"$"
)
text
=
text
.
replace
(
r
"\("
,
"$"
)
text
=
text
.
replace
(
"\)"
,
"$"
)
text
=
text
.
replace
(
r
"\)"
,
"$"
)
text
=
text
.
replace
(
"\["
,
"$$"
)
text
=
text
.
replace
(
r
"\["
,
"$$"
)
text
=
text
.
replace
(
"\]"
,
"$$"
)
text
=
text
.
replace
(
r
"\]"
,
"$$"
)
text
=
text
.
replace
(
"<|assistant|>"
,
""
)
text
=
text
.
replace
(
"<|assistant|>"
,
""
)
text
=
text
.
replace
(
"<|observation|>"
,
""
)
text
=
text
.
replace
(
"<|observation|>"
,
""
)
text
=
text
.
replace
(
"<|system|>"
,
""
)
text
=
text
.
replace
(
"<|system|>"
,
""
)
...
@@ -158,8 +158,6 @@ def postprocess_text(text: str, replace_quote: bool) -> str:
...
@@ -158,8 +158,6 @@ def postprocess_text(text: str, replace_quote: bool) -> str:
for
match
in
QUOTE_REGEX
.
finditer
(
text
):
for
match
in
QUOTE_REGEX
.
finditer
(
text
):
quote_id
=
match
.
group
(
1
)
quote_id
=
match
.
group
(
1
)
quote
=
quotes
.
get
(
quote_id
,
Quote
(
"未找到引用内容"
,
""
))
quote
=
quotes
.
get
(
quote_id
,
Quote
(
"未找到引用内容"
,
""
))
text
=
text
.
replace
(
text
=
text
.
replace
(
match
.
group
(
0
),
f
" (来源:[
{
quote
.
title
}
](
{
quote
.
url
}
)) "
)
match
.
group
(
0
),
f
" (来源:[
{
quote
.
title
}
](
{
quote
.
url
}
)) "
)
return
text
.
strip
()
return
text
.
strip
()
composite_demo/src/main.py
→
demo/
composite_demo/src/main.py
View file @
67ca83cf
...
@@ -12,10 +12,6 @@ from io import BytesIO
...
@@ -12,10 +12,6 @@ from io import BytesIO
from
uuid
import
uuid4
from
uuid
import
uuid4
import
streamlit
as
st
import
streamlit
as
st
from
streamlit.delta_generator
import
DeltaGenerator
from
PIL
import
Image
from
client
import
Client
,
ClientType
,
get_client
from
client
import
Client
,
ClientType
,
get_client
from
conversation
import
(
from
conversation
import
(
FILE_TEMPLATE
,
FILE_TEMPLATE
,
...
@@ -24,14 +20,17 @@ from conversation import (
...
@@ -24,14 +20,17 @@ from conversation import (
postprocess_text
,
postprocess_text
,
response_to_str
,
response_to_str
,
)
)
from
PIL
import
Image
from
streamlit.delta_generator
import
DeltaGenerator
from
tools.tool_registry
import
dispatch_tool
,
get_tools
from
tools.tool_registry
import
dispatch_tool
,
get_tools
from
utils
import
extract_
pdf
,
extract_
docx
,
extract_pptx
,
extract_text
from
utils
import
extract_
docx
,
extract_
pdf
,
extract_pptx
,
extract_text
CHAT_MODEL_PATH
=
os
.
environ
.
get
(
"CHAT_MODEL_PATH"
,
"THUDM/glm-4-9b-chat"
)
CHAT_MODEL_PATH
=
os
.
environ
.
get
(
"CHAT_MODEL_PATH"
,
"THUDM/glm-4-9b-chat"
)
VLM_MODEL_PATH
=
os
.
environ
.
get
(
"VLM_MODEL_PATH"
,
"THUDM/glm-4v-9b"
)
VLM_MODEL_PATH
=
os
.
environ
.
get
(
"VLM_MODEL_PATH"
,
"THUDM/glm-4v-9b"
)
USE_VLLM
=
os
.
environ
.
get
(
"USE_VLLM"
,
"0"
)
==
"1"
USE_VLLM
=
os
.
environ
.
get
(
"USE_VLLM"
,
"0"
)
==
"1"
USE_API
=
os
.
environ
.
get
(
"USE_API"
,
"0"
)
==
"1"
class
Mode
(
str
,
Enum
):
class
Mode
(
str
,
Enum
):
...
@@ -104,6 +103,7 @@ def build_client(mode: Mode) -> Client:
...
@@ -104,6 +103,7 @@ def build_client(mode: Mode) -> Client:
case
Mode
.
ALL_TOOLS
:
case
Mode
.
ALL_TOOLS
:
st
.
session_state
.
top_k
=
10
st
.
session_state
.
top_k
=
10
typ
=
ClientType
.
VLLM
if
USE_VLLM
else
ClientType
.
HF
typ
=
ClientType
.
VLLM
if
USE_VLLM
else
ClientType
.
HF
typ
=
ClientType
.
API
if
USE_API
else
typ
return
get_client
(
CHAT_MODEL_PATH
,
typ
)
return
get_client
(
CHAT_MODEL_PATH
,
typ
)
case
Mode
.
LONG_CTX
:
case
Mode
.
LONG_CTX
:
st
.
session_state
.
top_k
=
10
st
.
session_state
.
top_k
=
10
...
@@ -169,9 +169,7 @@ if page == Mode.LONG_CTX:
...
@@ -169,9 +169,7 @@ if page == Mode.LONG_CTX:
content
=
extract_pptx
(
file_path
)
content
=
extract_pptx
(
file_path
)
else
:
else
:
content
=
extract_text
(
file_path
)
content
=
extract_text
(
file_path
)
uploaded_texts
.
append
(
uploaded_texts
.
append
(
FILE_TEMPLATE
.
format
(
file_name
=
file_name
,
file_content
=
content
))
FILE_TEMPLATE
.
format
(
file_name
=
file_name
,
file_content
=
content
)
)
os
.
remove
(
file_path
)
os
.
remove
(
file_path
)
st
.
session_state
.
uploaded_texts
=
"
\n\n
"
.
join
(
uploaded_texts
)
st
.
session_state
.
uploaded_texts
=
"
\n\n
"
.
join
(
uploaded_texts
)
st
.
session_state
.
uploaded_file_nums
=
len
(
uploaded_files
)
st
.
session_state
.
uploaded_file_nums
=
len
(
uploaded_files
)
...
@@ -230,9 +228,7 @@ def main(prompt_text: str):
...
@@ -230,9 +228,7 @@ def main(prompt_text: str):
# Append uploaded files
# Append uploaded files
uploaded_texts
=
st
.
session_state
.
get
(
"uploaded_texts"
)
uploaded_texts
=
st
.
session_state
.
get
(
"uploaded_texts"
)
if
page
==
Mode
.
LONG_CTX
and
uploaded_texts
and
first_round
:
if
page
==
Mode
.
LONG_CTX
and
uploaded_texts
and
first_round
:
meta_msg
=
"{} files uploaded.
\n
"
.
format
(
meta_msg
=
"{} files uploaded.
\n
"
.
format
(
st
.
session_state
.
uploaded_file_nums
)
st
.
session_state
.
uploaded_file_nums
)
prompt_text
=
uploaded_texts
+
"
\n\n\n
"
+
meta_msg
+
prompt_text
prompt_text
=
uploaded_texts
+
"
\n\n\n
"
+
meta_msg
+
prompt_text
# Clear after first use
# Clear after first use
st
.
session_state
.
files_uploaded
=
True
st
.
session_state
.
files_uploaded
=
True
...
@@ -247,16 +243,12 @@ def main(prompt_text: str):
...
@@ -247,16 +243,12 @@ def main(prompt_text: str):
append_conversation
(
Conversation
(
role
,
prompt_text
,
image
=
image
),
history
)
append_conversation
(
Conversation
(
role
,
prompt_text
,
image
=
image
),
history
)
placeholder
=
st
.
container
()
placeholder
=
st
.
container
()
message_placeholder
=
placeholder
.
chat_message
(
message_placeholder
=
placeholder
.
chat_message
(
name
=
"assistant"
,
avatar
=
"assistant"
)
name
=
"assistant"
,
avatar
=
"assistant"
)
markdown_placeholder
=
message_placeholder
.
empty
()
markdown_placeholder
=
message_placeholder
.
empty
()
def
add_new_block
():
def
add_new_block
():
nonlocal
message_placeholder
,
markdown_placeholder
nonlocal
message_placeholder
,
markdown_placeholder
message_placeholder
=
placeholder
.
chat_message
(
message_placeholder
=
placeholder
.
chat_message
(
name
=
"assistant"
,
avatar
=
"assistant"
)
name
=
"assistant"
,
avatar
=
"assistant"
)
markdown_placeholder
=
message_placeholder
.
empty
()
markdown_placeholder
=
message_placeholder
.
empty
()
def
commit_conversation
(
def
commit_conversation
(
...
@@ -301,33 +293,18 @@ def main(prompt_text: str):
...
@@ -301,33 +293,18 @@ def main(prompt_text: str):
history_len
=
len
(
chat_history
)
history_len
=
len
(
chat_history
)
last_response
=
response
last_response
=
response
replace_quote
=
chat_history
[
-
1
][
"role"
]
==
"assistant"
replace_quote
=
chat_history
[
-
1
][
"role"
]
==
"assistant"
markdown_placeholder
.
markdown
(
markdown_placeholder
.
markdown
(
postprocess_text
(
str
(
response
)
+
"●"
,
replace_quote
=
replace_quote
))
postprocess_text
(
str
(
response
)
+
"●"
,
replace_quote
=
replace_quote
)
)
else
:
else
:
metadata
=
(
metadata
=
page
==
Mode
.
ALL_TOOLS
and
isinstance
(
response
,
dict
)
and
response
.
get
(
"name"
)
or
None
page
==
Mode
.
ALL_TOOLS
and
isinstance
(
response
,
dict
)
and
response
.
get
(
"name"
)
or
None
)
role
=
Role
.
TOOL
if
metadata
else
Role
.
ASSISTANT
role
=
Role
.
TOOL
if
metadata
else
Role
.
ASSISTANT
text
=
(
text
=
response
.
get
(
"content"
)
if
metadata
else
response_to_str
(
response
)
response
.
get
(
"content"
)
if
metadata
else
response_to_str
(
response
)
)
commit_conversation
(
role
,
text
,
metadata
)
commit_conversation
(
role
,
text
,
metadata
)
if
metadata
:
if
metadata
:
add_new_block
()
add_new_block
()
try
:
try
:
with
markdown_placeholder
:
with
markdown_placeholder
:
with
st
.
spinner
(
f
"Calling tool
{
metadata
}
..."
):
with
st
.
spinner
(
f
"Calling tool
{
metadata
}
..."
):
observations
=
dispatch_tool
(
observations
=
dispatch_tool
(
metadata
,
text
,
str
(
st
.
session_state
.
session_id
))
metadata
,
text
,
str
(
st
.
session_state
.
session_id
)
)
except
Exception
as
e
:
except
Exception
as
e
:
traceback
.
print_exc
()
traceback
.
print_exc
()
st
.
error
(
f
'Uncaught exception in `"
{
metadata
}
"`:
{
e
}
'
)
st
.
error
(
f
'Uncaught exception in `"
{
metadata
}
"`:
{
e
}
'
)
...
@@ -346,7 +323,7 @@ def main(prompt_text: str):
...
@@ -346,7 +323,7 @@ def main(prompt_text: str):
continue
continue
else
:
else
:
break
break
except
Exception
as
e
:
except
Exception
:
traceback
.
print_exc
()
traceback
.
print_exc
()
st
.
error
(
f
"Uncaught exception:
{
traceback
.
format_exc
()
}
"
)
st
.
error
(
f
"Uncaught exception:
{
traceback
.
format_exc
()
}
"
)
else
:
else
:
...
...
composite_demo/src/tools/browser.py
→
demo/
composite_demo/src/tools/browser.py
View file @
67ca83cf
...
@@ -6,22 +6,26 @@ Simple browser tool.
...
@@ -6,22 +6,26 @@ Simple browser tool.
Please start the backend browser server according to the instructions in the README.
Please start the backend browser server according to the instructions in the README.
"""
"""
from
pprint
import
pprint
import
re
import
re
from
dataclasses
import
dataclass
from
pprint
import
pprint
import
requests
import
requests
import
streamlit
as
st
import
streamlit
as
st
from
dataclasses
import
dataclass
from
.config
import
BROWSER_SERVER_URL
from
.config
import
BROWSER_SERVER_URL
from
.interface
import
ToolObservation
from
.interface
import
ToolObservation
QUOTE_REGEX
=
re
.
compile
(
r
"\[(\d+)†(.+?)\]"
)
QUOTE_REGEX
=
re
.
compile
(
r
"\[(\d+)†(.+?)\]"
)
@
dataclass
@
dataclass
class
Quote
:
class
Quote
:
title
:
str
title
:
str
url
:
str
url
:
str
# Quotes for displaying reference
# Quotes for displaying reference
if
"quotes"
not
in
st
.
session_state
:
if
"quotes"
not
in
st
.
session_state
:
st
.
session_state
.
quotes
=
{}
st
.
session_state
.
quotes
=
{}
...
@@ -31,18 +35,18 @@ quotes: dict[str, Quote] = st.session_state.quotes
...
@@ -31,18 +35,18 @@ quotes: dict[str, Quote] = st.session_state.quotes
def
map_response
(
response
:
dict
)
->
ToolObservation
:
def
map_response
(
response
:
dict
)
->
ToolObservation
:
# Save quotes for reference
# Save quotes for reference
print
(
'
===BROWSER_RESPONSE===
'
)
print
(
"
===BROWSER_RESPONSE===
"
)
pprint
(
response
)
pprint
(
response
)
role_metadata
=
response
.
get
(
"roleMetadata"
)
role_metadata
=
response
.
get
(
"roleMetadata"
)
metadata
=
response
.
get
(
"metadata"
)
metadata
=
response
.
get
(
"metadata"
)
if
role_metadata
.
split
()[
0
]
==
'
quote_result
'
and
metadata
:
if
role_metadata
.
split
()[
0
]
==
"
quote_result
"
and
metadata
:
quote_id
=
QUOTE_REGEX
.
search
(
role_metadata
.
split
()[
1
]).
group
(
1
)
quote_id
=
QUOTE_REGEX
.
search
(
role_metadata
.
split
()[
1
]).
group
(
1
)
quote
:
dict
[
str
,
str
]
=
metadata
[
'
metadata_list
'
][
0
]
quote
:
dict
[
str
,
str
]
=
metadata
[
"
metadata_list
"
][
0
]
quotes
[
quote_id
]
=
Quote
(
quote
[
'
title
'
],
quote
[
'
url
'
])
quotes
[
quote_id
]
=
Quote
(
quote
[
"
title
"
],
quote
[
"
url
"
])
elif
role_metadata
==
'
browser_result
'
and
metadata
:
elif
role_metadata
==
"
browser_result
"
and
metadata
:
for
i
,
quote
in
enumerate
(
metadata
[
'
metadata_list
'
]):
for
i
,
quote
in
enumerate
(
metadata
[
"
metadata_list
"
]):
quotes
[
str
(
i
)]
=
Quote
(
quote
[
'
title
'
],
quote
[
'
url
'
])
quotes
[
str
(
i
)]
=
Quote
(
quote
[
"
title
"
],
quote
[
"
url
"
])
return
ToolObservation
(
return
ToolObservation
(
content_type
=
response
.
get
(
"contentType"
),
content_type
=
response
.
get
(
"contentType"
),
...
...
composite_demo/src/tools/cogview.py
→
demo/
composite_demo/src/tools/cogview.py
View file @
67ca83cf
...
@@ -5,18 +5,21 @@ from zhipuai.types.image import GeneratedImage
...
@@ -5,18 +5,21 @@ from zhipuai.types.image import GeneratedImage
from
.config
import
COGVIEW_MODEL
,
ZHIPU_AI_KEY
from
.config
import
COGVIEW_MODEL
,
ZHIPU_AI_KEY
from
.interface
import
ToolObservation
from
.interface
import
ToolObservation
@
st
.
cache_resource
@
st
.
cache_resource
def
get_zhipu_client
():
def
get_zhipu_client
():
return
ZhipuAI
(
api_key
=
ZHIPU_AI_KEY
)
return
ZhipuAI
(
api_key
=
ZHIPU_AI_KEY
)
def
map_response
(
img
:
GeneratedImage
):
def
map_response
(
img
:
GeneratedImage
):
return
ToolObservation
(
return
ToolObservation
(
content_type
=
'
image
'
,
content_type
=
"
image
"
,
text
=
'
CogView 已经生成并向用户展示了生成的图片。
'
,
text
=
"
CogView 已经生成并向用户展示了生成的图片。
"
,
image_url
=
img
.
url
,
image_url
=
img
.
url
,
role_metadata
=
'
cogview_result
'
role_metadata
=
"
cogview_result
"
,
)
)
def
tool_call
(
prompt
:
str
,
session_id
:
str
)
->
list
[
ToolObservation
]:
def
tool_call
(
prompt
:
str
,
session_id
:
str
)
->
list
[
ToolObservation
]:
client
=
get_zhipu_client
()
client
=
get_zhipu_client
()
response
=
client
.
images
.
generations
(
model
=
COGVIEW_MODEL
,
prompt
=
prompt
).
data
response
=
client
.
images
.
generations
(
model
=
COGVIEW_MODEL
,
prompt
=
prompt
).
data
...
...
demo/composite_demo/src/tools/config.py
0 → 100644
View file @
67ca83cf
BROWSER_SERVER_URL
=
"http://localhost:3000"
IPYKERNEL
=
"glm-4-demo"
ZHIPU_AI_KEY
=
""
COGVIEW_MODEL
=
"cogview-3"
composite_demo/src/tools/interface.py
→
demo/
composite_demo/src/tools/interface.py
View file @
67ca83cf
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
from
typing
import
Any
@
dataclass
@
dataclass
class
ToolObservation
:
class
ToolObservation
:
content_type
:
str
content_type
:
str
...
...
composite_demo/src/tools/python.py
→
demo/
composite_demo/src/tools/python.py
View file @
67ca83cf
from
pprint
import
pprint
import
queue
import
queue
import
re
import
re
from
pprint
import
pprint
from
subprocess
import
PIPE
from
subprocess
import
PIPE
from
typing
import
Literal
from
typing
import
Literal
...
@@ -10,19 +10,22 @@ import streamlit as st
...
@@ -10,19 +10,22 @@ import streamlit as st
from
.config
import
IPYKERNEL
from
.config
import
IPYKERNEL
from
.interface
import
ToolObservation
from
.interface
import
ToolObservation
ANSI_ESCAPE
=
re
.
compile
(
r
'(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]'
)
CODE
=
re
.
compile
(
r
'```([^\n]*)\n(.*?)```'
)
class
CodeKernel
:
ANSI_ESCAPE
=
re
.
compile
(
r
"(\x9B|\x1B\[|\u001b\[)[0-?]*[ -/]*[@-~]"
)
def
__init__
(
self
,
CODE
=
re
.
compile
(
r
"```([^\n]*)\n(.*?)```"
)
kernel_name
=
'kernel'
,
kernel_id
=
None
,
kernel_config_path
=
""
,
python_path
=
None
,
ipython_path
=
None
,
init_file_path
=
"./startup.py"
,
verbose
=
1
):
class
CodeKernel
:
def
__init__
(
self
,
kernel_name
=
"kernel"
,
kernel_id
=
None
,
kernel_config_path
=
""
,
python_path
=
None
,
ipython_path
=
None
,
init_file_path
=
"./startup.py"
,
verbose
=
1
,
):
self
.
kernel_name
=
kernel_name
self
.
kernel_name
=
kernel_name
self
.
kernel_id
=
kernel_id
self
.
kernel_id
=
kernel_id
self
.
kernel_config_path
=
kernel_config_path
self
.
kernel_config_path
=
kernel_config_path
...
@@ -37,19 +40,16 @@ class CodeKernel:
...
@@ -37,19 +40,16 @@ class CodeKernel:
env
=
{
"PATH"
:
self
.
python_path
+
":$PATH"
,
"PYTHONPATH"
:
self
.
python_path
}
env
=
{
"PATH"
:
self
.
python_path
+
":$PATH"
,
"PYTHONPATH"
:
self
.
python_path
}
# Initialize the backend kernel
# Initialize the backend kernel
self
.
kernel_manager
=
jupyter_client
.
KernelManager
(
kernel_name
=
IPYKERNEL
,
self
.
kernel_manager
=
jupyter_client
.
KernelManager
(
connection_file
=
self
.
kernel_config_path
,
kernel_name
=
IPYKERNEL
,
connection_file
=
self
.
kernel_config_path
,
exec_files
=
[
self
.
init_file_path
],
env
=
env
exec_files
=
[
self
.
init_file_path
],
)
env
=
env
)
if
self
.
kernel_config_path
:
if
self
.
kernel_config_path
:
self
.
kernel_manager
.
load_connection_file
()
self
.
kernel_manager
.
load_connection_file
()
self
.
kernel_manager
.
start_kernel
(
stdout
=
PIPE
,
stderr
=
PIPE
)
self
.
kernel_manager
.
start_kernel
(
stdout
=
PIPE
,
stderr
=
PIPE
)
print
(
"Backend kernel started with the configuration: {}"
.
format
(
print
(
"Backend kernel started with the configuration: {}"
.
format
(
self
.
kernel_config_path
))
self
.
kernel_config_path
))
else
:
else
:
self
.
kernel_manager
.
start_kernel
(
stdout
=
PIPE
,
stderr
=
PIPE
)
self
.
kernel_manager
.
start_kernel
(
stdout
=
PIPE
,
stderr
=
PIPE
)
print
(
"Backend kernel started with the configuration: {}"
.
format
(
print
(
"Backend kernel started with the configuration: {}"
.
format
(
self
.
kernel_manager
.
connection_file
))
self
.
kernel_manager
.
connection_file
))
if
verbose
:
if
verbose
:
pprint
(
self
.
kernel_manager
.
get_connection_info
())
pprint
(
self
.
kernel_manager
.
get_connection_info
())
...
@@ -64,13 +64,13 @@ class CodeKernel:
...
@@ -64,13 +64,13 @@ class CodeKernel:
self
.
kernel
.
execute
(
code
)
self
.
kernel
.
execute
(
code
)
try
:
try
:
shell_msg
=
self
.
kernel
.
get_shell_msg
(
timeout
=
30
)
shell_msg
=
self
.
kernel
.
get_shell_msg
(
timeout
=
30
)
io_msg_content
=
self
.
kernel
.
get_iopub_msg
(
timeout
=
30
)[
'
content
'
]
io_msg_content
=
self
.
kernel
.
get_iopub_msg
(
timeout
=
30
)[
"
content
"
]
while
True
:
while
True
:
msg_out
=
io_msg_content
msg_out
=
io_msg_content
### Poll the message
### Poll the message
try
:
try
:
io_msg_content
=
self
.
kernel
.
get_iopub_msg
(
timeout
=
30
)[
'
content
'
]
io_msg_content
=
self
.
kernel
.
get_iopub_msg
(
timeout
=
30
)[
"
content
"
]
if
'
execution_state
'
in
io_msg_content
and
io_msg_content
[
'
execution_state
'
]
==
'
idle
'
:
if
"
execution_state
"
in
io_msg_content
and
io_msg_content
[
"
execution_state
"
]
==
"
idle
"
:
break
break
except
queue
.
Empty
:
except
queue
.
Empty
:
break
break
...
@@ -100,12 +100,12 @@ class CodeKernel:
...
@@ -100,12 +100,12 @@ class CodeKernel:
return
shell_msg
return
shell_msg
def
get_error_msg
(
self
,
msg
,
verbose
=
False
)
->
str
|
None
:
def
get_error_msg
(
self
,
msg
,
verbose
=
False
)
->
str
|
None
:
if
msg
[
'
content
'
][
'
status
'
]
==
'
error
'
:
if
msg
[
"
content
"
][
"
status
"
]
==
"
error
"
:
try
:
try
:
error_msg
=
msg
[
'
content
'
][
'
traceback
'
]
error_msg
=
msg
[
"
content
"
][
"
traceback
"
]
except
:
except
:
try
:
try
:
error_msg
=
msg
[
'
content
'
][
'
traceback
'
][
-
1
].
strip
()
error_msg
=
msg
[
"
content
"
][
"
traceback
"
][
-
1
].
strip
()
except
:
except
:
error_msg
=
"Traceback Error"
error_msg
=
"Traceback Error"
if
verbose
:
if
verbose
:
...
@@ -114,12 +114,12 @@ class CodeKernel:
...
@@ -114,12 +114,12 @@ class CodeKernel:
return
None
return
None
def
check_msg
(
self
,
msg
,
verbose
=
False
):
def
check_msg
(
self
,
msg
,
verbose
=
False
):
status
=
msg
[
'
content
'
][
'
status
'
]
status
=
msg
[
"
content
"
][
"
status
"
]
if
status
==
'
ok
'
:
if
status
==
"
ok
"
:
if
verbose
:
if
verbose
:
print
(
"Execution succeeded."
)
print
(
"Execution succeeded."
)
elif
status
==
'
error
'
:
elif
status
==
"
error
"
:
for
line
in
msg
[
'
content
'
][
'
traceback
'
]:
for
line
in
msg
[
"
content
"
][
"
traceback
"
]:
if
verbose
:
if
verbose
:
print
(
line
)
print
(
line
)
...
@@ -144,17 +144,17 @@ class CodeKernel:
...
@@ -144,17 +144,17 @@ class CodeKernel:
def
is_alive
(
self
):
def
is_alive
(
self
):
return
self
.
kernel
.
is_alive
()
return
self
.
kernel
.
is_alive
()
def
clean_ansi_codes
(
input_string
):
def
clean_ansi_codes
(
input_string
):
return
ANSI_ESCAPE
.
sub
(
''
,
input_string
)
return
ANSI_ESCAPE
.
sub
(
""
,
input_string
)
def
extract_code
(
text
:
str
)
->
str
:
def
extract_code
(
text
:
str
)
->
str
:
matches
=
CODE
.
findall
(
text
,
re
.
DOTALL
)
matches
=
CODE
.
findall
(
text
,
re
.
DOTALL
)
return
matches
[
-
1
][
1
]
return
matches
[
-
1
][
1
]
def
execute
(
code
:
str
,
def
execute
(
code
:
str
,
kernel
:
CodeKernel
)
->
tuple
[
Literal
[
"text"
,
"image"
]
|
None
,
str
]:
kernel
:
CodeKernel
)
->
tuple
[
Literal
[
'text'
,
'image'
]
|
None
,
str
]:
res
=
""
res
=
""
res_type
=
None
res_type
=
None
code
=
code
.
replace
(
"<|observation|>"
,
""
)
code
=
code
.
replace
(
"<|observation|>"
,
""
)
...
@@ -164,37 +164,38 @@ def execute(
...
@@ -164,37 +164,38 @@ def execute(
code
=
code
.
replace
(
"<|system|>"
,
""
)
code
=
code
.
replace
(
"<|system|>"
,
""
)
msg
,
output
=
kernel
.
execute
(
code
)
msg
,
output
=
kernel
.
execute
(
code
)
if
msg
[
'
metadata
'
][
'
status
'
]
==
"timeout"
:
if
msg
[
"
metadata
"
][
"
status
"
]
==
"timeout"
:
return
res_type
,
'
Timed out
'
return
res_type
,
"
Timed out
"
elif
msg
[
'
metadata
'
][
'
status
'
]
==
'
error
'
:
elif
msg
[
"
metadata
"
][
"
status
"
]
==
"
error
"
:
return
res_type
,
clean_ansi_codes
(
'
\n
'
.
join
(
kernel
.
get_error_msg
(
msg
,
verbose
=
True
)))
return
res_type
,
clean_ansi_codes
(
"
\n
"
.
join
(
kernel
.
get_error_msg
(
msg
,
verbose
=
True
)))
if
'
text
'
in
output
:
if
"
text
"
in
output
:
res_type
=
"text"
res_type
=
"text"
res
=
output
[
'
text
'
]
res
=
output
[
"
text
"
]
elif
'
data
'
in
output
:
elif
"
data
"
in
output
:
for
key
in
output
[
'
data
'
]:
for
key
in
output
[
"
data
"
]:
if
'
text/plain
'
in
key
:
if
"
text/plain
"
in
key
:
res_type
=
"text"
res_type
=
"text"
res
=
output
[
'
data
'
][
key
]
res
=
output
[
"
data
"
][
key
]
elif
'
image/png
'
in
key
:
elif
"
image/png
"
in
key
:
res_type
=
"image"
res_type
=
"image"
res
=
output
[
'
data
'
][
key
]
res
=
output
[
"
data
"
][
key
]
break
break
return
res_type
,
res
return
res_type
,
res
@
st
.
cache_resource
@
st
.
cache_resource
def
get_kernel
()
->
CodeKernel
:
def
get_kernel
()
->
CodeKernel
:
return
CodeKernel
()
return
CodeKernel
()
def
tool_call
(
code
:
str
,
session_id
:
str
)
->
list
[
ToolObservation
]:
def
tool_call
(
code
:
str
,
session_id
:
str
)
->
list
[
ToolObservation
]:
kernel
=
get_kernel
()
kernel
=
get_kernel
()
res_type
,
res
=
execute
(
code
,
kernel
)
res_type
,
res
=
execute
(
code
,
kernel
)
# Convert base64 to data uri
# Convert base64 to data uri
text
=
'
[Image]
'
if
res_type
==
'
image
'
else
res
text
=
"
[Image]
"
if
res_type
==
"
image
"
else
res
image
=
f
'
data:image/png;base64,
{
res
}
'
if
res_type
==
'
image
'
else
None
image
=
f
"
data:image/png;base64,
{
res
}
"
if
res_type
==
"
image
"
else
None
return
[
ToolObservation
(
res_type
,
text
,
image
)]
return
[
ToolObservation
(
res_type
,
text
,
image
)]
composite_demo/src/tools/tool_registry.py
→
demo/
composite_demo/src/tools/tool_registry.py
View file @
67ca83cf
...
@@ -4,22 +4,21 @@ This code provides extended functionality to the model, enabling it to call and
...
@@ -4,22 +4,21 @@ This code provides extended functionality to the model, enabling it to call and
through defined interfaces.
through defined interfaces.
"""
"""
from
collections.abc
import
Callable
import
copy
import
copy
import
inspect
import
inspect
import
json
import
json
from
pprint
import
pformat
import
subprocess
import
traceback
import
traceback
from
collections.abc
import
Callable
from
types
import
GenericAlias
from
types
import
GenericAlias
from
typing
import
get_origin
,
Annotated
from
typing
import
Annotated
,
get_origin
import
subprocess
from
.interface
import
ToolObservation
from
.browser
import
tool_call
as
browser
from
.browser
import
tool_call
as
browser
from
.cogview
import
tool_call
as
cogview
from
.cogview
import
tool_call
as
cogview
from
.interface
import
ToolObservation
from
.python
import
tool_call
as
python
from
.python
import
tool_call
as
python
ALL_TOOLS
=
{
ALL_TOOLS
=
{
"simple_browser"
:
browser
,
"simple_browser"
:
browser
,
"python"
:
python
,
"python"
:
python
,
...
@@ -73,8 +72,8 @@ def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObserv
...
@@ -73,8 +72,8 @@ def dispatch_tool(tool_name: str, code: str, session_id: str) -> list[ToolObserv
# Dispatch predefined tools
# Dispatch predefined tools
if
tool_name
in
ALL_TOOLS
:
if
tool_name
in
ALL_TOOLS
:
return
ALL_TOOLS
[
tool_name
](
code
,
session_id
)
return
ALL_TOOLS
[
tool_name
](
code
,
session_id
)
code
=
code
.
strip
().
rstrip
(
'
<|observation|>
'
).
strip
()
code
=
code
.
strip
().
rstrip
(
"
<|observation|>
"
).
strip
()
# Dispatch custom tools
# Dispatch custom tools
try
:
try
:
...
@@ -105,8 +104,8 @@ def get_tools() -> list[dict]:
...
@@ -105,8 +104,8 @@ def get_tools() -> list[dict]:
@
register_tool
@
register_tool
def
random_number_generator
(
def
random_number_generator
(
seed
:
Annotated
[
int
,
"The random seed used by the generator"
,
True
],
seed
:
Annotated
[
int
,
"The random seed used by the generator"
,
True
],
range
:
Annotated
[
tuple
[
int
,
int
],
"The range of the generated numbers"
,
True
],
range
:
Annotated
[
tuple
[
int
,
int
],
"The range of the generated numbers"
,
True
],
)
->
int
:
)
->
int
:
"""
"""
Generates a random number x, s.t. range[0] <= x < range[1]
Generates a random number x, s.t. range[0] <= x < range[1]
...
@@ -125,7 +124,7 @@ def random_number_generator(
...
@@ -125,7 +124,7 @@ def random_number_generator(
@
register_tool
@
register_tool
def
get_weather
(
def
get_weather
(
city_name
:
Annotated
[
str
,
"The name of the city to be queried"
,
True
],
city_name
:
Annotated
[
str
,
"The name of the city to be queried"
,
True
],
)
->
str
:
)
->
str
:
"""
"""
Get the current weather for `city_name`
Get the current weather for `city_name`
...
@@ -153,16 +152,14 @@ def get_weather(
...
@@ -153,16 +152,14 @@ def get_weather(
except
:
except
:
import
traceback
import
traceback
ret
=
(
ret
=
"Error encountered while fetching weather data!
\n
"
+
traceback
.
format_exc
()
"Error encountered while fetching weather data!
\n
"
+
traceback
.
format_exc
()
)
return
str
(
ret
)
return
str
(
ret
)
@
register_tool
@
register_tool
def
get_shell
(
def
get_shell
(
query
:
Annotated
[
str
,
"The command should run in Linux shell"
,
True
],
query
:
Annotated
[
str
,
"The command should run in Linux shell"
,
True
],
)
->
str
:
)
->
str
:
"""
"""
Use shell to run command
Use shell to run command
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment