"vscode:/vscode.git/clone" did not exist on "3047dc9b500266d8197139fad5ef3a8a4a459985"
Commit 1768a324 authored by dengjb's avatar dengjb
Browse files

update codes

parent 18493eef
Pipeline #1372 failed with stages
in 0 seconds
FROM python:3.11-slim-bookworm
WORKDIR /mnt/data
RUN apt-get update && apt-get install -y \
gcc \
libffi-dev \
zlib1g-dev \
fonts-arphic-ukai \
fonts-arphic-uming \
fonts-ipafont-mincho \
fonts-ipafont-gothic \
fonts-unfonts-core \
libgdal-dev \
g++ \
&& rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir \
pydantic \
tornado \
jupyter_client \
ipython \
ipykernel \
numpy \
pandas \
scipy \
matplotlib \
scikit-learn \
notebook \
beautifulsoup4 \
seaborn \
pytest \
ipywidgets \
sympy \
statsmodels \
joblib \
cython \
lxml \
xlrd \
qrcode \
nltk \
opencv-python \
Pillow \
geopandas
ENV HOME=/mnt/data
# Strip setuid/setgid bits from all binaries to reduce the privilege-escalation surface.
RUN find / -perm /6000 -type f -exec chmod a-s {} \; || true
RUN echo "set -o history -o vi" >> /etc/profile
RUN useradd -u 999 -ms /bin/bash appuser
RUN chown -R appuser:appuser /mnt/data
USER appuser
ENV JUPYTER_RUNTIME_DIR=/mnt/data/.local/share/jupyter/runtime
ENV JUPYTER_DATA_DIR=/mnt/data/.local/share/jupyter
ENV JUPYTER_CONFIG_DIR=/mnt/data/.jupyter
COPY sandbox.py /sandbox.py
VOLUME [ "/mnt/data" ]
CMD ["python", "/sandbox.py"]
# CodeGeeX4 Interpreter Gradio
A fully local Gradio demo of the CodeGeeX4 interpreter.
## Usage
### Install Dependencies
```bash
pip install gradio requests
```
### Build & Launch Sandbox
```bash
docker build -t sandbox -f Dockerfile.sandbox .
docker run --name sandbox --publish 8080:8080 sandbox
```
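Once the container is running, you can verify the sandbox is alive (assuming the port mapping above) by pinging its root endpoint:
```bash
curl http://127.0.0.1:8080/
# => {"last_activity":"2006-01-02T15:04:05Z07:00"}
```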
### Launch Demo
```bash
python app.py --tgi-addr <tgi-addr>
```
## Docs
Check out the [documentation](./SANDBOX.md) for the sandbox API.
# CodeGeeX4 Code Interpreter Demo
A fully local demo of the CodeGeeX4 code interpreter.
## Usage
### Install Dependencies
```bash
pip install gradio requests
```
### Build & Launch the Local Sandbox
```bash
docker build -t sandbox -f Dockerfile.sandbox .
docker run --name sandbox --publish 8080:8080 sandbox
```
### Launch the Demo
```bash
python app.py --tgi-addr <tgi-addr>
```
## Docs
See the [sandbox API documentation](./SANDBOX.md).
# Sandbox API
### Ping
**Path:** GET `/`
Check whether a sandbox is alive and return information about it.
#### Request
\-
#### Response
**Status:**
- `200` if alive
**Example:**
```json
{
"last_activity": "2006-01-02T15:04:05Z07:00", // RFC 3339
}
```
### Execute
**Path:** POST `/execute`
#### Request
**Content-Type:** `application/json`
**JSON Schema:**
| Name | Type | Description |
| -------------- | ----------------- | ------------------------------------------------------------------------------------------------------ |
| `code` | string | The code to be executed. |
| `timeout_secs` | number (Optional) | Abort execution after timeout. Does not include environment and runtime creation time. Defaults to 60. |
#### Response
**Status:**
- `200` if successful
**Content-Type:** `application/json`
**Example:**
```json
{
"status": "ok", // Possible values: "ok", "timeout"
"events": [
{
"type": "stream",
"timestamp": "2006-01-02T15:04:05Z07:00", // RFC 3339
"data": {
"name": "stdout", // Possible values: "stdout", "stderr"
"text": "Hello World!",
}
},
{
"type": "display_data",
"timestamp": "2006-01-02T15:04:05Z07:00", // RFC 3339
"data": {
"variants": {
"text/plain": "<IPython.core.display.Image object>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGP4//8/AAX+Av4N70a4AAAAAElFTkSuQmCC" // Base64 encoded PNG image
}
}
},
{
"type": "file", // The program has written a file to disk.
"timestamp": "2006-01-02T15:04:05Z07:00", // RFC 3339
"data": {
"path": "README.md",
"size": 128, // Size is expressed in bytes
}
},
{
"type": "error",
"timestamp": "2006-01-02T15:04:05Z07:00", // RFC 3339
"data": {
"ename": "ZeroDivisionError",
"evalue": "division by zero",
"traceback": [
"\\u001b[0;31m---------------------------------------------------------------------------\\u001b[0m",
"\\u001b[0;31mZeroDivisionError\\u001b[0m Traceback (most recent call last)",
"Cell \\u001b[0;32mIn[1], line 2\\u001b[0m\\n\\u001b[1;32m 1\\u001b[0m \\u001b[38;5;66;03m# \\u8ba1\\u7b97\\u8868\\u8fbe\\u5f0f\\u7684\\u7ed3\\u679c\\u001b[39;00m\\n\\u001b[0;32m----> 2\\u001b[0m result \\u001b[38;5;241m=\\u001b[39m \\u001b[38;5;241;43m361234\\u001b[39;49m\\u001b[43m \\u001b[49m\\u001b[38;5;241;43m/\\u001b[39;49m\\u001b[43m \\u001b[49m\\u001b[38;5;241;43m0\\u001b[39;49m \\u001b[38;5;241m+\\u001b[39m \\u001b[38;5;241m4514\\u001b[39m \\u001b[38;5;241m*\\u001b[39m \\u001b[38;5;241m1234\\u001b[39m \\u001b[38;5;241m-\\u001b[39m \\u001b[38;5;241m27152346\\u001b[39m \\u001b[38;5;241m/\\u001b[39m \\u001b[38;5;241m2023\\u001b[39m\\n\\u001b[1;32m 3\\u001b[0m result\\n",
"\\u001b[0;31mZeroDivisionError\\u001b[0m: division by zero"
]
}
}
]
}
```
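For reference, a minimal Python client call against a locally running sandbox (the address and values are illustrative, not part of the API) could look like this:
```python
import requests

# Execute a short snippet and print the streamed stdout text.
resp = requests.post(
    "http://127.0.0.1:8080/execute",
    json={"code": "print('hello')", "timeout_secs": 10},
)
result = resp.json()
print(result["status"])  # "ok" or "timeout"
for event in result["events"]:
    if event["type"] == "stream":
        print(event["data"]["text"], end="")
```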
### File upload
**Path:** POST `/files/upload/-/*path`
Upload a file to the sandbox under `*path`.
#### Request
**Content-Length:** The length of the file in bytes.
**Body:** The raw contents of the file as bytes.
#### Response
**Status:**
- `201` if upload was successful
- `409` if file already exists
### File download
**Path:** GET `/files/download/-/*path`
Download a file from the sandbox from `*path`.
#### Request
\-
#### Response
**Content-Type:** Automatically detected, depending on the file.
**Content-Disposition:** `attachment; filename*=UTF-8''<filename>`
**Body:** The raw contents of the file.
**Status:**
- `200` if file exists
- `404` if file is not found
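A minimal upload/download round trip with `requests` (path and address are illustrative) could look like:
```python
import requests

BASE = "http://127.0.0.1:8080"

# Upload raw bytes to /mnt/data/hello.txt inside the sandbox.
r = requests.post(f"{BASE}/files/upload/-/mnt/data/hello.txt", data=b"hello")
assert r.status_code == 201  # 409 if the file already exists

# Download it back; the response body is the raw file contents.
r = requests.get(f"{BASE}/files/download/-/mnt/data/hello.txt")
assert r.status_code == 200
print(r.content)  # b"hello"
```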
import argparse
import json
import os
import re
from typing import Any, Dict, List, Tuple
import gradio as gr
import requests
SYSTEM_PROMPT = {
"zh": "你是一位智能编程助手,你叫CodeGeeX,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。",
"en": "You are an intelligent programming assistant named CodeGeeX, connected to a computer, but please note that you cannot access the internet. When solving tasks using Python, you can run code and obtain results. If there are any errors in the results, you need to improve the code as much as possible. You can also handle files uploaded to the computer, with the default storage path being /mnt/data/.",
}
CODEGEEX_SPECIAL_TOKENS = {
"user": "<|user|>",
"assistant": "<|assistant|>",
"system": "<|system|>",
"observation": "<|observation|>",
"eos": "<|endoftext|>",
}
parser = argparse.ArgumentParser(description="CodeGeeX4 Interpreter")
parser.add_argument("--tgi-addr", type=str, required=True)
parser.add_argument("--sandbox-addr", type=str, default="http://127.0.0.1:8080")
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top-p", type=float, default=0.95)
args = parser.parse_args()
code_block_regex = re.compile(r"```(.*?)\n(.*?)```", re.DOTALL)
def execute_code_block(lang, code) -> Tuple[List[Dict[str, Any]], str]:
assert lang in ["python"]
response = requests.post(
f"{args.sandbox_addr}/execute",
json={"code": code, "timeout_secs": 60},
)
response = response.json()
print(f"[RESPONSE] {response}")
return response["events"], response["status"]
def upload_file(filepath: str, contents: str):
print(f"[REQUEST] Upload {filepath} ({len(contents)} bytes)")
response = requests.post(
f"{args.sandbox_addr}/files/upload/-/{filepath.lstrip('/')}",
data=bytes(contents, encoding="utf-8"),
)
print(f"[RESPONSE] {response.text}")
assert response.status_code == 201
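# Streams a chat completion from the TGI server. Runs up to max_rounds generation
# rounds; whenever the model emits the <|observation|> token, the fenced code block
# in the current completion is executed in the sandbox and its output is fed back
# into the prompt as a fenced "result" block before generation continues.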
def stream_chat_completion(message, history):
should_stop = False
round = 0
max_rounds = 5
file_info = ""
for filepath in message.get("files", []):
with open(filepath, "r") as f:
contents = f.read()
filename = os.path.basename(filepath)
upload_file(f"/mnt/data/{filename}", contents)
file_info += f"# File: /mnt/data/{filename}\n"
file_info += f"# Size: {len(contents)}\n"
file_info += "# File uploaded\n"
prompt = f"{CODEGEEX_SPECIAL_TOKENS['system']}\n{SYSTEM_PROMPT['en']}\n"
for [user_message, bot_message] in history:
if isinstance(user_message, tuple):
# It's a file
pass
else:
# Remove any '![image](data:image/png;base64,...)' from the bot message.
bot_message = re.sub(
r"!\[image\]\(data:image/png;base64,[^\)]+\)", "", bot_message
)
prompt += f"{CODEGEEX_SPECIAL_TOKENS['user']}\n{user_message}\n"
prompt += f"{CODEGEEX_SPECIAL_TOKENS['assistant']}\n{bot_message}\n"
prompt += f"{CODEGEEX_SPECIAL_TOKENS['user']}\n{file_info}{message['text']}\n"
prompt += f"{CODEGEEX_SPECIAL_TOKENS['assistant']}\n"
stop_sequences = [
CODEGEEX_SPECIAL_TOKENS["eos"],
CODEGEEX_SPECIAL_TOKENS["user"],
CODEGEEX_SPECIAL_TOKENS["observation"],
]
while not should_stop and round < max_rounds:
round += 1
request_json_body = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 2048,
"do_sample": True,
"top_p": args.top_p,
"temperature": args.temperature,
"stop": stop_sequences,
"details": True,
"stream": False,
},
}
print(f"[REQUEST] {request_json_body}")
response = requests.post(
f"{args.tgi_addr}/generate_stream",
json=request_json_body,
stream=True,
)
completion = ""
for line in response.iter_lines():
if line:
event = line.decode("utf-8")
if event.startswith("data:"):
event = event[5:].strip()
event = json.loads(event)
token = event["token"]["text"]
completion += token
prompt += token
# Only display the token if it's not "special".
if event["token"]["text"] not in CODEGEEX_SPECIAL_TOKENS.values():
yield token
# If the model asks for the code to be executed, do it.
if event["token"]["text"] == CODEGEEX_SPECIAL_TOKENS["observation"]:
match = code_block_regex.search(completion)
if match is None:
# Hm, it seems the model didn't write any code.
# Let's gently warn it.
prompt += f"\n```result\nError: no code to execute.\n```\n{CODEGEEX_SPECIAL_TOKENS['assistant']}\n"
yield "```\nError: no code to execute.\n```\n"
break
lang, code = match.groups()
events, status = execute_code_block(lang, code)
buffer = []
for exec_event in events:
if exec_event["type"] == "stream":
buffer.append(exec_event["text"])
if exec_event["type"] == "display_data":
if "text/plain" in exec_event["data"]["variants"]:
buffer.append(
exec_event["data"]["variants"]["text/plain"]
)
if status == "timeout":
buffer.append("Execution timed out.")
if status == "error":
buffer.append("Execution failed.")
prompt += f"\n```result\n{''.join(buffer)}\n```\n{CODEGEEX_SPECIAL_TOKENS['assistant']}\n"
yield f"```\n{''.join(buffer)}\n```\n"
for exec_event in events:
if exec_event["type"] == "display_data":
if "image/png" in exec_event["data"]["variants"]:
yield f"![image](data:image/png;base64,{exec_event['data']['variants']['image/png']})"
elif "text/html" in exec_event["data"]["variants"]:
yield exec_event["data"]["variants"]["text/html"]
break
# If the model otherwise ends the generation, stop here.
if event["details"] is not None:
should_stop = True
break
print(f"[RESPONSE] {completion}")
def predict(message: Dict[str, Any], history: List[List[str | None | tuple]]):
completion = ""
for delta in stream_chat_completion(message, history):
completion += delta
# Replace (sandbox:/ by (<sandbox-address>/
completion = completion.replace(
"sandbox:/", f"{args.sandbox_addr}/files/download/-/"
)
yield completion
demo = gr.ChatInterface(
fn=predict,
title="CodeGeeX4 Interpreter",
description="",
examples=[
{"text": "Compute factorial of 21 using code", "files": []},
{
"text": "Plot the class distribution of this dataset",
"files": ["./data.csv"],
},
{
"text": 'Reverse the following string and save it to a file: "9738426487936"',
"files": [],
},
],
multimodal=True,
)
demo.launch()
"sepal.length","sepal.width","petal.length","petal.width","variety"
5.1,3.5,1.4,.2,"Setosa"
4.9,3,1.4,.2,"Setosa"
4.7,3.2,1.3,.2,"Setosa"
4.6,3.1,1.5,.2,"Setosa"
5,3.6,1.4,.2,"Setosa"
5.4,3.9,1.7,.4,"Setosa"
4.6,3.4,1.4,.3,"Setosa"
5,3.4,1.5,.2,"Setosa"
4.4,2.9,1.4,.2,"Setosa"
4.9,3.1,1.5,.1,"Setosa"
5.4,3.7,1.5,.2,"Setosa"
4.8,3.4,1.6,.2,"Setosa"
4.8,3,1.4,.1,"Setosa"
4.3,3,1.1,.1,"Setosa"
5.8,4,1.2,.2,"Setosa"
5.7,4.4,1.5,.4,"Setosa"
5.4,3.9,1.3,.4,"Setosa"
5.1,3.5,1.4,.3,"Setosa"
5.7,3.8,1.7,.3,"Setosa"
5.1,3.8,1.5,.3,"Setosa"
5.4,3.4,1.7,.2,"Setosa"
5.1,3.7,1.5,.4,"Setosa"
4.6,3.6,1,.2,"Setosa"
5.1,3.3,1.7,.5,"Setosa"
4.8,3.4,1.9,.2,"Setosa"
5,3,1.6,.2,"Setosa"
5,3.4,1.6,.4,"Setosa"
5.2,3.5,1.5,.2,"Setosa"
5.2,3.4,1.4,.2,"Setosa"
4.7,3.2,1.6,.2,"Setosa"
4.8,3.1,1.6,.2,"Setosa"
5.4,3.4,1.5,.4,"Setosa"
5.2,4.1,1.5,.1,"Setosa"
5.5,4.2,1.4,.2,"Setosa"
4.9,3.1,1.5,.2,"Setosa"
5,3.2,1.2,.2,"Setosa"
5.5,3.5,1.3,.2,"Setosa"
4.9,3.6,1.4,.1,"Setosa"
4.4,3,1.3,.2,"Setosa"
5.1,3.4,1.5,.2,"Setosa"
5,3.5,1.3,.3,"Setosa"
4.5,2.3,1.3,.3,"Setosa"
4.4,3.2,1.3,.2,"Setosa"
5,3.5,1.6,.6,"Setosa"
5.1,3.8,1.9,.4,"Setosa"
4.8,3,1.4,.3,"Setosa"
5.1,3.8,1.6,.2,"Setosa"
4.6,3.2,1.4,.2,"Setosa"
5.3,3.7,1.5,.2,"Setosa"
5,3.3,1.4,.2,"Setosa"
7,3.2,4.7,1.4,"Versicolor"
6.4,3.2,4.5,1.5,"Versicolor"
6.9,3.1,4.9,1.5,"Versicolor"
5.5,2.3,4,1.3,"Versicolor"
6.5,2.8,4.6,1.5,"Versicolor"
5.7,2.8,4.5,1.3,"Versicolor"
6.3,3.3,4.7,1.6,"Versicolor"
4.9,2.4,3.3,1,"Versicolor"
6.6,2.9,4.6,1.3,"Versicolor"
5.2,2.7,3.9,1.4,"Versicolor"
5,2,3.5,1,"Versicolor"
5.9,3,4.2,1.5,"Versicolor"
6,2.2,4,1,"Versicolor"
6.1,2.9,4.7,1.4,"Versicolor"
5.6,2.9,3.6,1.3,"Versicolor"
6.7,3.1,4.4,1.4,"Versicolor"
5.6,3,4.5,1.5,"Versicolor"
5.8,2.7,4.1,1,"Versicolor"
6.2,2.2,4.5,1.5,"Versicolor"
5.6,2.5,3.9,1.1,"Versicolor"
5.9,3.2,4.8,1.8,"Versicolor"
6.1,2.8,4,1.3,"Versicolor"
6.3,2.5,4.9,1.5,"Versicolor"
6.1,2.8,4.7,1.2,"Versicolor"
6.4,2.9,4.3,1.3,"Versicolor"
6.6,3,4.4,1.4,"Versicolor"
6.8,2.8,4.8,1.4,"Versicolor"
6.7,3,5,1.7,"Versicolor"
6,2.9,4.5,1.5,"Versicolor"
5.7,2.6,3.5,1,"Versicolor"
5.5,2.4,3.8,1.1,"Versicolor"
5.5,2.4,3.7,1,"Versicolor"
5.8,2.7,3.9,1.2,"Versicolor"
6,2.7,5.1,1.6,"Versicolor"
5.4,3,4.5,1.5,"Versicolor"
6,3.4,4.5,1.6,"Versicolor"
6.7,3.1,4.7,1.5,"Versicolor"
6.3,2.3,4.4,1.3,"Versicolor"
5.6,3,4.1,1.3,"Versicolor"
5.5,2.5,4,1.3,"Versicolor"
5.5,2.6,4.4,1.2,"Versicolor"
6.1,3,4.6,1.4,"Versicolor"
5.8,2.6,4,1.2,"Versicolor"
5,2.3,3.3,1,"Versicolor"
5.6,2.7,4.2,1.3,"Versicolor"
5.7,3,4.2,1.2,"Versicolor"
5.7,2.9,4.2,1.3,"Versicolor"
6.2,2.9,4.3,1.3,"Versicolor"
5.1,2.5,3,1.1,"Versicolor"
5.7,2.8,4.1,1.3,"Versicolor"
6.3,3.3,6,2.5,"Virginica"
5.8,2.7,5.1,1.9,"Virginica"
7.1,3,5.9,2.1,"Virginica"
6.3,2.9,5.6,1.8,"Virginica"
6.5,3,5.8,2.2,"Virginica"
7.6,3,6.6,2.1,"Virginica"
4.9,2.5,4.5,1.7,"Virginica"
7.3,2.9,6.3,1.8,"Virginica"
6.7,2.5,5.8,1.8,"Virginica"
7.2,3.6,6.1,2.5,"Virginica"
6.5,3.2,5.1,2,"Virginica"
6.4,2.7,5.3,1.9,"Virginica"
6.8,3,5.5,2.1,"Virginica"
5.7,2.5,5,2,"Virginica"
5.8,2.8,5.1,2.4,"Virginica"
6.4,3.2,5.3,2.3,"Virginica"
6.5,3,5.5,1.8,"Virginica"
7.7,3.8,6.7,2.2,"Virginica"
7.7,2.6,6.9,2.3,"Virginica"
6,2.2,5,1.5,"Virginica"
6.9,3.2,5.7,2.3,"Virginica"
5.6,2.8,4.9,2,"Virginica"
7.7,2.8,6.7,2,"Virginica"
6.3,2.7,4.9,1.8,"Virginica"
6.7,3.3,5.7,2.1,"Virginica"
7.2,3.2,6,1.8,"Virginica"
6.2,2.8,4.8,1.8,"Virginica"
6.1,3,4.9,1.8,"Virginica"
6.4,2.8,5.6,2.1,"Virginica"
7.2,3,5.8,1.6,"Virginica"
7.4,2.8,6.1,1.9,"Virginica"
7.9,3.8,6.4,2,"Virginica"
6.4,2.8,5.6,2.2,"Virginica"
6.3,2.8,5.1,1.5,"Virginica"
6.1,2.6,5.6,1.4,"Virginica"
7.7,3,6.1,2.3,"Virginica"
6.3,3.4,5.6,2.4,"Virginica"
6.4,3.1,5.5,1.8,"Virginica"
6,3,4.8,1.8,"Virginica"
6.9,3.1,5.4,2.1,"Virginica"
6.7,3.1,5.6,2.4,"Virginica"
6.9,3.1,5.1,2.3,"Virginica"
5.8,2.7,5.1,1.9,"Virginica"
6.8,3.2,5.9,2.3,"Virginica"
6.7,3.3,5.7,2.5,"Virginica"
6.7,3,5.2,2.3,"Virginica"
6.3,2.5,5,1.9,"Virginica"
6.5,3,5.2,2,"Virginica"
6.2,3.4,5.4,2.3,"Virginica"
5.9,3,5.1,1.8,"Virginica"
import argparse
import asyncio
import json
import logging
import os
import signal
import sys
from asyncio import Queue
from datetime import datetime, timezone
from typing import Annotated, List, Union
import tornado.escape
import tornado.ioloop
import tornado.web
from annotated_types import Gt
from jupyter_client.asynchronous.client import AsyncKernelClient
from jupyter_client.manager import AsyncKernelManager
from pydantic import BaseModel
# Shell Jupyter message types
JupyterMessageTypeExecuteRequest = "execute_request"
JupyterMessageTypeExecuteReply = "execute_reply"
# IOPub Jupyter message types
JupyterMessageTypeStream = "stream"
JupyterMessageTypeDisplayData = "display_data"
JupyterMessageTypeExecuteResult = "execute_result"
JupyterMessageTypeError = "error"
JupyterMessageTypeStatus = "status"
# Supported Jupyter message types (IOPub only)
JupyterSupportedMessageTypes = [
JupyterMessageTypeStream,
JupyterMessageTypeDisplayData,
JupyterMessageTypeExecuteResult,
JupyterMessageTypeError,
JupyterMessageTypeStatus,
]
# Kernel execution states
JupyterExecutionStateBusy = "busy"
JupyterExecutionStateIdle = "idle"
JupyterExecutionStateStarting = "starting"
# Saturn execution event types
ExecutionEventTypeStream = "stream"
ExecutionEventTypeDisplayData = "display_data"
ExecutionEventTypeError = "error"
# Saturn execution statuses
ExecutionStatusOK = "ok"
ExecutionStatusTimeout = "timeout"
class ExecutionEventStream(BaseModel):
stream: str
text: str
class ExecutionEventDisplayData(BaseModel):
variants: dict
class ExecutionEventError(BaseModel):
ename: str
evalue: str
traceback: list[str]
class ExecutionEvent(BaseModel):
type: str
timestamp: str # RFC3339
data: Union[
ExecutionEventStream,
ExecutionEventDisplayData,
ExecutionEventError,
]
class ExecuteRequest(BaseModel):
code: str
timeout_secs: Annotated[int, Gt(0)]
class ExecuteResponse(BaseModel):
status: str
events: List[ExecutionEvent]
class PingResponse(BaseModel):
last_activity: str # RFC3339
class Error(BaseModel):
error: str
def datetime_to_rfc3339(dt: datetime) -> str:
"""Convert a datetime to an RFC3339 formatted string."""
return dt.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
def rfc3339_to_datetime(date_string: str) -> datetime:
"""Convert an RFC3339 formatted string to a datetime."""
return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ").replace(
tzinfo=timezone.utc
)
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
async def async_create_kernel(kernel_name: str):
logging.info(f"Starting kernel for spec '{kernel_name}'")
km = AsyncKernelManager(kernel_name=kernel_name)
await km.start_kernel()
client: AsyncKernelClient = km.client()
client.start_channels()
await client.wait_for_ready()
logging.info("Kernel started")
return km, client
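# Maps the msg_id of each execute request to the per-request queue of the HTTP
# handler that issued it; async_msg_producer routes incoming IOPub messages to the
# matching queue by their parent msg_id.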
msg_id_to_queue: dict[str, Queue] = {}
async def async_msg_producer(km: AsyncKernelManager, kc: AsyncKernelClient):
try:
while True:
logging.info("Waiting for message...")
msg = await kc.get_iopub_msg()
log_jupyter_kernel_message(msg)
parent_msg_id = msg["parent_header"].get("msg_id")
if parent_msg_id in msg_id_to_queue:
await msg_id_to_queue[parent_msg_id].put(msg)
except Exception as e:
logging.error(f"Error in message producer: {e}")
await async_shutdown(km)
async def async_shutdown(km: AsyncKernelManager):
logging.info("Shutting down kernel...")
await km.shutdown_kernel()
logging.info("Kernel shut down")
sys.exit(0)
class State:
def __init__(self, kernel_client: AsyncKernelClient):
self.last_activity = datetime.now()
self.kernel_client = kernel_client
def reset_last_activity(self):
self.last_activity = datetime.now()
class MainHandler(tornado.web.RequestHandler):
def initialize(self, state: State):
self.state = state
async def get(self):
try:
is_alive = await self.state.kernel_client.is_alive()
if not is_alive:
raise Exception("kernel is not alive")
self.write(
PingResponse(
last_activity=datetime_to_rfc3339(self.state.last_activity)
).model_dump_json()
)
except Exception as e:
self.set_status(500)
self.write(Error(error=str(e)).model_dump_json())
return
def serializer(o):
if isinstance(o, datetime):
return o.isoformat()
raise TypeError("Type not serializable")
def log_jupyter_kernel_message(msg):
m = json.dumps(msg, default=serializer)
logging.info(f"Jupyter: {m}")
class ExecuteHandler(tornado.web.RequestHandler):
def initialize(self, state: State):
self.state = state
async def post(self):
parent_msg_id = None
res: ExecuteResponse = ExecuteResponse(status=ExecutionStatusOK, events=[])
try:
logging.info(f"Execute request: {self.request.body}")
self.state.reset_last_activity()
req = ExecuteRequest.model_validate_json(self.request.body)
local_queue = Queue()
parent_msg_id = self.state.kernel_client.execute(req.code)
msg_id_to_queue[parent_msg_id] = local_queue
# Use the timeout logic on message processing
try:
await asyncio.wait_for(
self.process_messages(parent_msg_id, local_queue, res),
timeout=req.timeout_secs,
)
except asyncio.TimeoutError:
logging.info(f"Timeout after {req.timeout_secs}s")
res.status = ExecutionStatusTimeout
return self.write(res.model_dump_json())
self.state.reset_last_activity()
self.write(res.model_dump_json())
except Exception as e:
self.set_status(500)
self.write(Error(error=str(e)).model_dump_json())
finally:
# Cleanup after processing all messages
if parent_msg_id is not None and parent_msg_id in msg_id_to_queue:
del msg_id_to_queue[parent_msg_id]
logging.info(f"Execute response: {res.model_dump_json()}")
async def process_messages(self, parent_msg_id, queue, res):
while True:
msg = await queue.get()
if msg["msg_type"] not in JupyterSupportedMessageTypes:
continue
elif msg["msg_type"] == JupyterMessageTypeStatus:
if msg["content"]["execution_state"] == JupyterExecutionStateIdle:
break
elif msg["msg_type"] == JupyterMessageTypeStream:
res.events.append(
ExecutionEvent(
type=ExecutionEventTypeStream,
timestamp=datetime_to_rfc3339(datetime.now()),
data=ExecutionEventStream(
stream=msg["content"]["name"],
text=msg["content"]["text"],
),
)
)
elif msg["msg_type"] == JupyterMessageTypeDisplayData:
res.events.append(
ExecutionEvent(
type=ExecutionEventTypeDisplayData,
timestamp=datetime_to_rfc3339(datetime.now()),
data=ExecutionEventDisplayData(variants=msg["content"]["data"]),
)
)
elif msg["msg_type"] == JupyterMessageTypeError:
res.events.append(
ExecutionEvent(
type=ExecutionEventTypeError,
timestamp=datetime_to_rfc3339(datetime.now()),
data=ExecutionEventError(
ename=msg["content"]["ename"],
evalue=msg["content"]["evalue"],
traceback=msg["content"]["traceback"],
),
)
)
elif msg["msg_type"] == JupyterMessageTypeExecuteResult:
res.events.append(
ExecutionEvent(
type=ExecutionEventTypeDisplayData,
timestamp=datetime_to_rfc3339(datetime.now()),
data=ExecutionEventDisplayData(variants=msg["content"]["data"]),
)
)
@tornado.web.stream_request_body
class FileUploadHandler(tornado.web.RequestHandler):
def initialize(self, state: State):
self.state = state
self.file_obj = None
async def prepare(self):
if self.request.method != "POST":
self.set_status(404)
self.finish()
return
path = self.path_args[0]
full_path = os.path.join("/", path)
# Reject uploads to an existing path with 409, as documented in SANDBOX.md.
if os.path.exists(full_path):
    self.set_status(409)
    self.finish(Error(error="file already exists").model_dump_json())
    return
os.makedirs(os.path.dirname(full_path), exist_ok=True)
self.file_obj = open(full_path, "wb")
content_length = int(self.request.headers.get("Content-Length", 0))
logging.info(f"File upload: '{path}' (Content-Length: {content_length})")
def data_received(self, chunk):
if self.file_obj:
self.file_obj.write(chunk)
async def post(self, path):
self.state.reset_last_activity()
if self.file_obj:
self.file_obj.close()
self.set_status(201)
class FileDownloadHandler(tornado.web.RequestHandler):
def initialize(self, state: State):
self.state = state
async def get(self, path):
self.state.reset_last_activity()
full_path = os.path.join("/", path)
if not os.path.exists(full_path):
self.set_status(404)
self.write(Error(error="file not found").model_dump_json())
return
content_length = os.path.getsize(full_path)
logging.info(f"File download: '{path}' (Content-Length: {content_length})")
# Set appropriate headers for file download
self.set_header("Content-Length", content_length)
self.set_header("Content-Type", "application/octet-stream")
self.set_header(
"Content-Disposition",
f"attachment; filename*=UTF-8''{tornado.escape.url_escape(os.path.basename(full_path))}",
)
# Stream the file to the client
with open(full_path, "rb") as f:
while True:
chunk = f.read(64 * 1024)
if not chunk:
break
try:
self.write(chunk)
await self.flush()
except tornado.iostream.StreamClosedError:
return
def shutdown(ioloop: tornado.ioloop.IOLoop, km):
logging.info("Shutting down server...")
ioloop.add_callback_from_signal(lambda: async_shutdown(km))
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--port", type=int, default=80)
p.add_argument("--kernel-name", type=str, default="python3")
args = p.parse_args()
km, client = asyncio.run(async_create_kernel(args.kernel_name))
state = State(client)
application = tornado.web.Application(
[
(r"/", MainHandler, {"state": state}),
(r"/execute", ExecuteHandler, {"state": state}),
(r"/files/upload/-/(.*)", FileUploadHandler, {"state": state}),
(r"/files/download/-/(.*)", FileDownloadHandler, {"state": state}),
]
)
application.listen(args.port)
logging.info(f"Server started at http://localhost:{args.port}")
ioloop = tornado.ioloop.IOLoop.current()
signal.signal(signal.SIGINT, lambda sig, frame: shutdown(ioloop, km))
signal.signal(signal.SIGTERM, lambda sig, frame: shutdown(ioloop, km))
ioloop.add_callback(async_msg_producer, km, client)
tornado.ioloop.IOLoop.current().start()
import os
import shutil
import tempfile
import unittest
import requests
from sandbox import (
Error,
ExecuteResponse,
ExecutionEventTypeDisplayData,
ExecutionEventTypeError,
ExecutionEventTypeStream,
ExecutionStatusOK,
ExecutionStatusTimeout,
)
# We'll create a temporary directory for the tests to avoid any side effects.
temp_dir = tempfile.mkdtemp()
BASE_URL = "http://localhost:8888/"
def url(path: str) -> str:
return BASE_URL + path
class TestExecuteHandler(unittest.TestCase):
def must_bind_with_execute_response(self, r: requests.Response) -> ExecuteResponse:
self.assertEqual(r.status_code, 200)
return ExecuteResponse.model_validate_json(r.content)
def must_bind_with_error(self, r: requests.Response) -> Error:
return Error.model_validate_json(r.content)
def test_execute_hello(self):
r = requests.post(
url("execute"), json={"code": "print('hello')", "timeout_secs": 10}
)
res = self.must_bind_with_execute_response(r)
self.assertEqual(len(res.events), 1)
self.assertEqual(res.events[0].type, ExecutionEventTypeStream)
self.assertEqual(res.events[0].data.stream, "stdout") # type: ignore
self.assertEqual(res.events[0].data.text, "hello\n") # type: ignore
def test_execute_timeout(self):
r = requests.post(
url("execute"),
json={"code": "import time\ntime.sleep(5)", "timeout_secs": 1},
)
res = self.must_bind_with_execute_response(r)
self.assertEqual(len(res.events), 0)
self.assertEqual(res.status, ExecutionStatusTimeout)
def test_execute_syntax_error(self):
r = requests.post(
url("execute"), json={"code": "print('hello'", "timeout_secs": 10}
)
err = self.must_bind_with_execute_response(r)
self.assertEqual(err.status, ExecutionStatusOK)
self.assertEqual(len(err.events), 1)
self.assertEqual(err.events[0].type, ExecutionEventTypeError)
self.assertEqual(err.events[0].data.ename, "SyntaxError") # type: ignore
self.assertIsNotNone(err.events[0].data.evalue) # type: ignore
self.assertGreater(len(err.events[0].data.traceback), 0) # type: ignore
def test_execute_invalid_timeout(self):
r = requests.post(
url("execute"),
json={"code": "print('hello')", "timeout_secs": -1},
)
self.must_bind_with_error(r)
def test_execute_display_data(self):
code = """import matplotlib.pyplot as plt
plt.plot([1, 2, 3, 4])
plt.ylabel('some numbers')
plt.show()"""
r = requests.post(url("execute"), json={"code": code, "timeout_secs": 10})
res = self.must_bind_with_execute_response(r)
self.assertEqual(res.status, ExecutionStatusOK)
self.assertEqual(len(res.events), 1)
self.assertEqual(res.events[0].type, ExecutionEventTypeDisplayData)
self.assertIsNotNone(res.events[0].data.variants["image/png"]) # type: ignore
self.assertIsNotNone(res.events[0].data.variants["text/plain"]) # type: ignore
def test_execute_pil_image(self):
code = """from PIL import Image
img = Image.new('RGB', (60, 30), color = 'red')
# Override the show method of the Image class
def new_show(self, *args, **kwargs):
display(self)
Image.Image.show = new_show
img.show()"""
r = requests.post(url("execute"), json={"code": code, "timeout_secs": 10})
res = self.must_bind_with_execute_response(r)
self.assertEqual(res.status, ExecutionStatusOK)
self.assertEqual(len(res.events), 1)
self.assertEqual(res.events[0].type, ExecutionEventTypeDisplayData)
self.assertIsNotNone(res.events[0].data.variants["image/png"]) # type: ignore
self.assertIsNotNone(res.events[0].data.variants["text/plain"]) # type: ignore
class FileUploadHandlerTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.temp_dir = tempfile.mkdtemp()
cls.BASE_URL = f"http://localhost:8888/files/upload/-{cls.temp_dir}/"
def test_upload_file(self):
file_path = os.path.join(self.temp_dir, "test.txt")
large_binary_file = os.urandom(1024 * 1024 * 10) # 10 MB
r = requests.post(self.BASE_URL + "test.txt", data=large_binary_file)
self.assertEqual(r.status_code, 201)
self.assertTrue(os.path.exists(file_path))
with open(file_path, "rb") as f:
self.assertEqual(f.read(), large_binary_file)
def test_upload_existing_file(self):
file_path = os.path.join(self.temp_dir, "existing.txt")
with open(file_path, "wb") as f:
f.write(b"exists")
with open(file_path, "rb") as f:
r = requests.post(self.BASE_URL + "existing.txt", data=f.read())
self.assertEqual(r.status_code, 409)
error = Error.model_validate_json(r.content)
self.assertEqual(error.error, "file already exists")
def test_directory_creation(self):
file_path = os.path.join(self.temp_dir, "newdir", "test.txt")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
r = requests.post(self.BASE_URL + "newdir/test.txt", data=b"test content")
self.assertEqual(r.status_code, 201)
self.assertTrue(os.path.exists(file_path))
with open(file_path, "rb") as f:
self.assertEqual(f.read(), b"test content")
@classmethod
def tearDownClass(cls):
# Clean up the temp_dir after all tests
if os.path.exists(cls.temp_dir):
shutil.rmtree(cls.temp_dir)
if __name__ == "__main__":
unittest.main()
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## RAG Functionality
CodeGeeX4 supports retrieval-augmented generation (RAG) and is compatible with the Langchain framework, enabling project-level retrieval Q&A.
## Tutorial
### 1. Install Dependencies
Navigate to the `langchain_demo` directory and install the required packages.
```bash
cd langchain_demo
pip install -r requirements.txt
```
### 2. Configure Embedding API Key
This project uses the Embedding API from the Zhipu Open Platform for vectorization. Please register and obtain an API Key first.
Then, configure the API Key in `models/embedding.py`.
For more details, refer to https://open.bigmodel.cn/dev/api#text_embedding.
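As written, `models/embedding.py` in this repo reads the key from the `Zhipu_API_KEY` environment variable, so one way to provide it is:
```bash
export Zhipu_API_KEY=<your-api-key>
```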
### 3. Generate Vector Data
```bash
python vectorize.py --workspace . --output_path vectors
>>> File vectorization completed, saved to vectors
```
### 4. Run the Q&A Script
```bash
python chat.py --vector_path vectors
>>> Running on local URL: http://127.0.0.1:8080
```
## Demo
![](resources/demo.png)
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## RAG Functionality
CodeGeeX4 supports retrieval-augmented generation (RAG) and is compatible with the Langchain framework, enabling project-level retrieval Q&A.
## Tutorial
### 1. Install Dependencies
```bash
cd langchain_demo
pip install -r requirements.txt
```
### 2. Configure the Embedding API Key
This project uses the Embedding API from the Zhipu Open Platform for vectorization. Please register and obtain an API Key first.
Then configure the API Key in `models/embedding.py`.
For details, see https://open.bigmodel.cn/dev/api#text_embedding.
### 3. Generate Vector Data
```bash
python vectorize.py --workspace . --output_path vectors
>>> File vectorization completed, saved to vectors
```
### 4. Run the Q&A Script
```bash
python chat.py --vector_path vectors
>>> Running on local URL: http://127.0.0.1:8080
```
## Demo
![](resources/demo_zh.png)
"""
References: https://python.langchain.com/v0.2/docs/tutorials/rag/
"""
import argparse
import gradio as gr
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from models.codegeex import CodegeexChatModel
from utils.prompts import CUSTOM_RAG_PROMPT
from utils.vector import load_vector_store
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--vector_path', type=str, help="path to load the vectors", default='vectors')
parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
parser.add_argument('--temperature', type=float, help="model's temperature", default=0.2)
return parser.parse_args()
def format_docs(docs):
return "\n\n".join(
[f"[[citation:{i + 1}]]\n```markdown\n{doc.page_content}\n```" for i, doc in enumerate(docs)]
)
def chat(query, history):
retrieve_chain = ({"context": retriever | format_docs, "question": RunnablePassthrough()} | CUSTOM_RAG_PROMPT)
retrieve_output = retrieve_chain.invoke(query)
ans = retrieve_output.text
yield ans
ans += "模型回复".center(150, '-') + '\n'
yield ans
parse_chain = (llm | StrOutputParser())
ans += parse_chain.invoke(retrieve_output)
yield ans
if __name__ == '__main__':
args = parse_arguments()
llm = CodegeexChatModel(args)
try:
retriever = load_vector_store(args.vector_path).as_retriever()
except Exception as e:
print(f"Fail to load vectors,caused by {e}")
exit()
demo = gr.ChatInterface(chat).queue()
demo.launch(server_name="127.0.0.1", server_port=8080)
from typing import Any, Iterator
import torch
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk, ChatResult, ChatGeneration
from pydantic import Field
from transformers import AutoModel, AutoTokenizer
from utils.prompts import SYS_PROMPT
class CodegeexChatModel(BaseChatModel):
device: str = Field(default="cpu", description="device to load the model")
tokenizer: Any = Field(default=None, description="model's tokenizer")
model: Any = Field(default=None, description="CodeGeeX model")
temperature: float = Field(default=0.2, description="temperature to use for the model")
def __init__(self, args):
super().__init__()
self.device = args.device
self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(
args.model_name_or_path,
trust_remote_code=True
).to(args.device).eval()
self.temperature = args.temperature
print("Model has been initialized.")
@property
def _llm_type(self) -> str:
return "codegeex"
@torch.inference_mode()
def _generate(self, messages, **kwargs):
try:
response, _ = self.model.chat(
self.tokenizer,
query=messages[0].content,
history=[{"role": "system", "content": SYS_PROMPT}],
max_new_tokens=1024,
temperature=self.temperature
)
return ChatResult(generations=[ChatGeneration(message=BaseMessage(content=response, type='ai'))])
except Exception as e:
return ChatResult(generations=[ChatGeneration(message=BaseMessage(content=repr(e), type='ai'))])
def _stream(self, messages: list[BaseMessage], **kwargs) -> Iterator[ChatGenerationChunk]:
try:
for response, _ in self.model.stream_chat(
self.tokenizer,
query=messages[0].content,
history=[{"role": "system", "content": SYS_PROMPT}],
max_new_tokens=1024,
temperature=self.temperature
):
yield ChatGenerationChunk(message=AIMessageChunk(content=response))
except Exception as e:
yield ChatGenerationChunk(message=AIMessageChunk(content=f"Fail to generate, cause by {e}"))
import os
from langchain.schema.embeddings import Embeddings
from zhipuai import ZhipuAI
class GLMEmbeddings(Embeddings):
def __init__(self):
self.client = ZhipuAI(api_key=os.getenv("Zhipu_API_KEY"))
self.embedding_size = 1024
def embed_query(self, text: str) -> list[float]:
return self.embed_documents([text])[0]
def embed_documents(self, texts: list[str]) -> list[list[float]]:
return self._get_len_safe_embeddings(texts)
def _get_len_safe_embeddings(self, texts: list[str]) -> list[list[float]]:
try:
# Get the embedding response
response = self.client.embeddings.create(model="embedding-2", input=texts)
data = [item.embedding for item in response.data]
return data
except Exception as e:
print(f"Fail to get embeddings, caused by {e}")
return []
accelerate==0.31.0
faiss-cpu==1.8
gradio==4.26.0
langchain==0.2.3
langchain-community==0.2.4
regex==2024.5.15
requests==2.31.0
tiktoken==0.7.0
torch==2.3.1
tqdm==4.66.4
transformers==4.39.0
zhipuai~=2.0
import os
from langchain.text_splitter import (
Language,
RecursiveCharacterTextSplitter as TextSplitter,
)
from langchain_community.document_loaders import TextLoader
Languages = {
'c': Language.CPP,
'cpp': Language.CPP,
'go': Language.GO,
'java': Language.JAVA,
'js': Language.JS,
'md': Language.MARKDOWN,
'py': Language.PYTHON,
'ts': Language.TS,
}
def traverse(repo_path: str) -> list[str]:
"""
Traverse the directory, fetch all files
- skip hidden directories
- only keep the supported files
:param repo_path: path to this repo
"""
def helper(root):
for entry in os.scandir(root):
if entry.name.startswith('.'):
continue
if entry.is_file():
ext = entry.name.split('.')[-1].lower()
if ext not in Languages.keys():
continue
file_paths.append(entry.path)
elif entry.is_dir():
helper(entry.path)
file_paths = []
helper(repo_path)
return sorted(file_paths)
def split_into_chunks(file_path, chunk_size, overlap_size) -> list[str]:
"""
Split file into chunks
:param file_path: path to the file
:param chunk_size: size for each chunk
:param overlap_size: overlap size between two chunks
"""
ext = file_path.split('.')[-1].lower()
lang = Languages.get(ext, None)
if not lang:
return []
try:
loader = TextLoader(file_path, encoding='utf-8', autodetect_encoding=True)
splitter = TextSplitter.from_language(lang, chunk_size=chunk_size, chunk_overlap=overlap_size)
return loader.load_and_split(splitter)
except Exception as e:
print(f'Failed to split `{file_path}`: {e}')
return []
from langchain_core.prompts import PromptTemplate
SYS_PROMPT = """
你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
# Note
- 您将获得与问题相关的多个上下文片段,每个上下文都以引用编号开头,例如[[citation:x]],其中x是一个数字。如果适用,请使用上下文并在每个句子的末尾引用上下文。
- 您的答案必须是正确的、准确的,并且以专家的身份使用无偏见和专业的语调来撰写。
- 请你的回答限制在2千字以内,不要提供与问题无关的信息,也不要重复。
- 请以引用编号的格式[[citation:x]]来引用上下文。如果一个句子来自多个上下文,请列出所有适用的引用,例如[[citation:3]][[citation:5]]。
- 若所有上下文均不相关,请以自己的理解回答用户提出的问题,此时回答中可以不带引用编号。
- 除了代码和特定的名称和引用外,您的答案必须使用与问题相同的语言来撰写。
""".lstrip()
template = """
[引用]
{context}
问:{question}
""".lstrip()
CUSTOM_RAG_PROMPT = PromptTemplate.from_template(template)
import os
from langchain_community.docstore import InMemoryDocstore
from langchain_community.vectorstores.faiss import FAISS, dependable_faiss_import
from models.embedding import GLMEmbeddings
from tqdm import tqdm
from utils.data import split_into_chunks
embed_model = GLMEmbeddings()
def vectorize(files: list[str], args):
# split file into chunks
chunks = []
for file in tqdm(files, desc="Splitting files"):
chunks.extend(split_into_chunks(file, args.chunk_size, args.overlap_size))
# initialize the vector store
vector_store = FAISS(
embedding_function=embed_model,
index=dependable_faiss_import().IndexFlatL2(embed_model.embedding_size),
docstore=InMemoryDocstore(),
index_to_docstore_id={},
)
# translate to vectors
batch_size = args.batch_size
for i in tqdm(range(0, len(chunks), batch_size), desc="Vectorizing"):
try:
vector_store.add_documents(chunks[i:i + batch_size])
except Exception as e:
print(f"文件向量化失败,{e}")
# save embedded vectors
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)
vector_store.save_local(output_path)
print(f"文件向量化完成,已保存至{output_path}")
def load_vector_store(vector_path: str):
return FAISS.load_local(vector_path, embed_model, allow_dangerous_deserialization=True)