"""
Vectorize your local project
"""
import argparse
from utils.data import traverse
from utils.vector import vectorize
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--workspace', type=str, help="directory of the workspace to be vectorized", default='.')
parser.add_argument('--chunk_size', type=int, help="chunk size when splitting", default=512)
parser.add_argument('--overlap_size', type=int, help="chunk overlap when splitting", default=32)
parser.add_argument('--batch_size', type=int, help="embedding batch size", default=16)
parser.add_argument('--output_path', type=str, help="path to save the vectors", default='vectors')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
files = traverse(args.workspace)
vectorize(files, args)
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## RAG Functionality
CodeGeeX4 supports retrieval-augmented generation (RAG) and is compatible with the LlamaIndex framework, enabling project-level retrieval Q&A.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd llamaindex_demo
pip install -r requirements.txt
```
Note: This project uses tree-sitter-languages, which has compatibility issues with Python 3.10, so please run this project with Python 3.8 or Python 3.9.
### 2. Configure Embedding API Key
This project uses the Zhipu Open Platform's Embedding API to implement vectorization. Please register and obtain an API Key first.
Then make the API Key available to `models/embedding.py`, which reads it from the `Zhipu_API_KEY` environment variable.
For details, refer to https://open.bigmodel.cn/dev/api#text_embedding
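The key is read from the `Zhipu_API_KEY` environment variable (`models/embedding.py` calls `os.getenv("Zhipu_API_KEY")`). Below is a minimal sketch for checking the key before vectorizing; the test sentence and error message are illustrative, not part of the demo:
```python
import os

from zhipuai import ZhipuAI

# Assumes the key has been exported as Zhipu_API_KEY, the variable read by models/embedding.py.
api_key = os.getenv("Zhipu_API_KEY")
if not api_key:
    raise RuntimeError("Set the Zhipu_API_KEY environment variable first.")

# Request an embedding for a short test sentence to confirm the key works.
client = ZhipuAI(api_key=api_key)
resp = client.embeddings.create(model="embedding-2", input=["hello, codegeex"])
print(len(resp.data[0].embedding))  # the embedding-2 model returns 1024-dimensional vectors
```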
### 3. Generate Vector Data
```bash
python vectorize.py --workspace . --output_path vectors
>>> File vectorization completed, saved to vectors
```
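After vectorization finishes, you can sanity-check the persisted index by loading it back with this demo's helper. A minimal sketch, assuming the default `vectors` output directory:
```python
from llama_index.core import Settings

from models.embedding import GLMEmbeddings
from utils.vector import load_vectors

# The index must be loaded with the same embedding model that produced it.
Settings.embed_model = GLMEmbeddings()
index = load_vectors("vectors")
print(type(index).__name__)  # a VectorStoreIndex restored from the persisted Faiss store
```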
### 4. Run the Q&A Script
```bash
python chat.py --vector_path vectors
>>> Running on local URL: http://127.0.0.1:8080
```
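The Gradio app in `chat.py` is only a thin wrapper; the same components can be driven directly from Python if you prefer to query without the UI. A minimal sketch mirroring `chat.py`, with an `argparse`-style namespace standing in for its command-line arguments and an illustrative question:
```python
from argparse import Namespace

from llama_index.core import Settings

from models.embedding import GLMEmbeddings
from models.synthesizer import CodegeexSynthesizer
from utils.vector import load_vectors

# Mirror chat.py's defaults: GLM embeddings for retrieval, CodeGeeX4 for answer synthesis.
args = Namespace(vector_path="vectors", model_name_or_path="THUDM/codegeex4-all-9b",
                 device="cpu", temperature=0.2)
Settings.embed_model = GLMEmbeddings()
query_engine = load_vectors(args.vector_path).as_query_engine(
    response_synthesizer=CodegeexSynthesizer(args)
)
print(query_engine.query("What does vectorize.py do?").response)
```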
## Demo
![](resources/demo.png)
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## RAG Functionality
CodeGeeX4 supports retrieval-augmented generation (RAG) and is compatible with the LlamaIndex framework, enabling project-level retrieval Q&A.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd llamaindex_demo
pip install -r requirements.txt
```
Note: This project uses tree-sitter-languages, which has compatibility issues with Python 3.10, so please run this project with Python 3.8 or Python 3.9.
### 2. Configure the Embedding API Key
This project uses the Zhipu Open Platform's Embedding API for vectorization. Please register and obtain an API Key first.
Then make the API Key available to `models/embedding.py`, which reads it from the `Zhipu_API_KEY` environment variable.
For details, see https://open.bigmodel.cn/dev/api#text_embedding
### 3. Generate Vector Data
```bash
python vectorize.py --workspace . --output_path vectors
>>> File vectorization completed, saved to vectors
```
### 4. Run the Q&A Script
```bash
python chat.py --vector_path vectors
>>> Running on local URL: http://127.0.0.1:8080
```
## Demo
![](resources/demo_zh.png)
"""
References: https://docs.llamaindex.ai/en/stable/use_cases/q_and_a/
"""
import argparse
import gradio as gr
from llama_index.core import Settings
from models.embedding import GLMEmbeddings
from models.synthesizer import CodegeexSynthesizer
from utils.vector import load_vectors
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--vector_path', type=str, help="path to store the vectors", default='vectors')
parser.add_argument('--model_name_or_path', type=str, default='THUDM/codegeex4-all-9b')
parser.add_argument('--device', type=str, help="cpu or cuda", default="cpu")
parser.add_argument('--temperature', type=float, help="model's temperature", default=0.2)
return parser.parse_args()
def chat(query, history):
resp = query_engine.query(query)
ans = "相关文档".center(150, '-') + '\n'
yield ans
for i, node in enumerate(resp.source_nodes):
file_name = node.metadata['filename']
ext = node.metadata['extension']
text = node.text
ans += f"File{i + 1}: {file_name}\n```{ext}\n{text}\n```\n"
yield ans
ans += "模型回复".center(150, '-') + '\n'
ans += resp.response
yield ans
if __name__ == '__main__':
args = parse_arguments()
Settings.embed_model = GLMEmbeddings()
try:
query_engine = load_vectors(args.vector_path).as_query_engine(
response_synthesizer=CodegeexSynthesizer(args)
)
except Exception as e:
print(f"Fail to load vectors, caused by {e}")
exit()
demo = gr.ChatInterface(chat).queue()
demo.launch(server_name="127.0.0.1", server_port=8080)
from typing import Any

from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.core.llms import LLM
from pydantic import Field
from transformers import AutoTokenizer, AutoModel
from utils.prompts import SYS_PROMPT
class CodegeexChatModel(LLM):
    device: str = Field(description="device to load the model")
    tokenizer: Any = Field(description="model's tokenizer")
    model: Any = Field(description="Codegeex model")
    temperature: float = Field(description="temperature to use for the model.")

    def __init__(self, args):
        # Pass every field through the pydantic constructor so the LLM base class accepts them.
        super().__init__(
            device=args.device,
            tokenizer=AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True),
            model=AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True).to(args.device).eval(),
            temperature=args.temperature,
        )
        print("Model has been initialized.")
@classmethod
def class_name(cls) -> str:
return "codegeex"
@property
def metadata(self) -> LLMMetadata:
return LLMMetadata(
context_window=7168,
num_output=1024,
is_chat_model=True,
model_name="codegeex",
)
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
try:
response, _ = self.model.chat(
self.tokenizer,
query=messages[0].content,
history=[{"role": "system", "content": SYS_PROMPT}],
max_new_tokens=1024,
temperature=self.temperature
)
return ChatResponse(message=ChatMessage(role="assistant", content=response))
except Exception as e:
return ChatResponse(message=ChatMessage(role="assistant", content=e))
def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
try:
for response, _ in self.model.stream_chat(
self.tokenizer,
query=messages[0].content,
history=[{"role": "system", "content": SYS_PROMPT}],
max_new_tokens=1024,
temperature=self.temperature
):
yield ChatResponse(message=ChatMessage(role="assistant", content=response))
except Exception as e:
yield ChatResponse(message=ChatMessage(role="assistant", content=e))
def complete(self, prompt: str, formatted: bool = False, **kwargs) -> CompletionResponse:
try:
response, _ = self.model.chat(
self.tokenizer,
query=prompt,
history=[{"role": "system", "content": "你是一个智能编程助手"}],
max_new_tokens=1024,
temperature=self.temperature
)
return CompletionResponse(text=response)
except Exception as e:
            return CompletionResponse(text=str(e))
def stream_complete(self, prompt: str, formatted: bool = False, **kwargs) -> CompletionResponseGen:
try:
for response, _ in self.model.stream_chat(
self.tokenizer,
query=prompt,
history=[{"role": "system", "content": "你是一个智能编程助手"}],
max_new_tokens=1024,
temperature=self.temperature
):
yield CompletionResponse(text=response)
except Exception as e:
            yield CompletionResponse(text=str(e))
    async def achat(self, messages: list[ChatMessage], **kwargs):
        # chat() is synchronous, so delegate to it directly instead of awaiting it.
        return self.chat(messages, **kwargs)

    async def astream_chat(self, messages: list[ChatMessage], **kwargs):
        # stream_chat() returns a regular generator; iterate it synchronously.
        for resp in self.stream_chat(messages, **kwargs):
            yield resp

    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs):
        return self.complete(prompt, formatted, **kwargs)

    async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs):
        for resp in self.stream_complete(prompt, formatted, **kwargs):
            yield resp
import os
from typing import Any
from llama_index.core.base.embeddings.base import BaseEmbedding
from pydantic import Field
from zhipuai import ZhipuAI
class GLMEmbeddings(BaseEmbedding):
    client: Any = Field(default=None, description="embedding model client")
    embedding_size: int = Field(default=1024, description="embedding size")

    def __init__(self):
        super().__init__(model_name='GLM', embed_batch_size=64)
        # The API key is read from the Zhipu_API_KEY environment variable.
        self.client = ZhipuAI(api_key=os.getenv("Zhipu_API_KEY"))
        self.embedding_size = 1024
def _get_query_embedding(self, query: str) -> list[float]:
return self._get_text_embeddings([query])[0]
def _get_text_embedding(self, text: str) -> list[float]:
return self._get_text_embeddings([text])[0]
def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
return self._get_len_safe_embeddings(texts)
async def _aget_query_embedding(self, query: str) -> list[float]:
return self._get_query_embedding(query)
def _get_len_safe_embeddings(self, texts: list[str]) -> list[list[float]]:
try:
            # Request embeddings for the whole batch from the Zhipu API
response = self.client.embeddings.create(model="embedding-2", input=texts)
data = [item.embedding for item in response.data]
return data
except Exception as e:
print(f"Fail to get embeddings, caused by {e}")
return []
from llama_index.core.response_synthesizers import BaseSynthesizer
from models.codegeex import CodegeexChatModel
from utils.prompts import CUSTOM_PROMPT_TEMPLATE
class CodegeexSynthesizer(BaseSynthesizer):
"""Response builder class."""
def __init__(self, args) -> None:
super().__init__(llm=CodegeexChatModel(args))
self.prompt_template = CUSTOM_PROMPT_TEMPLATE
def get_response(self, query_str: str, text_chunks: list[str], **kwargs) -> str:
context = self.build_context(text_chunks)
return self._llm.predict(self.prompt_template, query=query_str, context=context)
async def aget_response(self, query_str: str, text_chunks: list[str], **kwargs) -> str:
context = self.build_context(text_chunks)
return await self._llm.apredict(self.prompt_template, query=query_str, context=context)
def _get_prompts(self):
"""Get prompts."""
return {"text_qa_template": self.prompt_template}
def _update_prompts(self, prompts) -> None:
"""Update prompts."""
if "text_qa_template" in prompts:
self.prompt_template = prompts["text_qa_template"]
@staticmethod
def build_context(text_chunks):
"""
merge contexts
:param text_chunks: recalled texts
"""
return "\n\n".join(
[f"[[citation:{i + 1}]]\n```markdown\n{chunk}\n```" for i, chunk in enumerate(text_chunks)]
)
accelerate==0.31.0
faiss-cpu==1.8
gradio==4.26.0
llama-index==0.10.43
regex==2024.5.15
tiktoken==0.7.0
torch==2.3.1
tree-sitter<0.22.0
tree-sitter-languages==1.10.2
tqdm==4.66.4
transformers==4.39.0
zhipuai~=2.0
import os
from pathlib import Path
from llama_index.core.node_parser import CodeSplitter
from llama_index.core.schema import BaseNode
from llama_index.readers.file import FlatReader
Languages = {
'c': "c",
'cpp': "cpp",
'go': "go",
'java': "java",
'js': "javascript",
'md': "markdown",
'py': "python",
'ts': "typescript",
}
def traverse(repo_path: str) -> list[str]:
"""
Traverse the directory, fetch all files
- skip hidden directories
- only keep the supported files
:param repo_path: path to this repo
"""
def helper(root):
for entry in os.scandir(root):
if entry.name.startswith('.'):
continue
if entry.is_file():
ext = entry.name.split('.')[-1].lower()
if ext not in Languages.keys():
continue
file_paths.append(entry.path)
elif entry.is_dir():
helper(entry.path)
file_paths = []
helper(repo_path)
return sorted(file_paths)
def split_into_chunks(file_path, lines_per_chunk, lines_overlap, max_chars) -> list[BaseNode]:
"""
Split file into chunks
:param file_path: path to the file
:param lines_per_chunk: lines for each chunk
:param lines_overlap: overlap lines between 2 chunks
:param max_chars: max characters for each chunk
"""
ext = file_path.split('.')[-1].lower()
lang = Languages.get(ext, None)
if not lang:
return []
try:
documents = FlatReader().load_data(Path(file_path))
splitter = CodeSplitter(
language=lang,
chunk_lines=lines_per_chunk,
chunk_lines_overlap=lines_overlap,
max_chars=max_chars,
)
return splitter.get_nodes_from_documents(documents)
except Exception as e:
        print(f'Failed to split `{file_path}`: {e}')
return []
from llama_index.core import PromptTemplate
SYS_PROMPT = """
你将接收到一个用户提出的问题,并请撰写清晰、简洁且准确的答案。
# Note
- 您将获得与问题相关的多个上下文片段,每个上下文都以引用编号开头,例如[[citation:x]],其中x是一个数字。如果适用,请使用上下文并在每个句子的末尾引用上下文。
- 您的答案必须是正确的、准确的,并且以专家的身份使用无偏见和专业的语调来撰写。
- 请你的回答限制在2千字以内,不要提供与问题无关的信息,也不要重复。
- 请以引用编号的格式[[citation:x]]来引用上下文。如果一个句子来自多个上下文,请列出所有适用的引用,例如[[citation:3]][[citation:5]]。
- 若所有上下文均不相关,请以自己的理解回答用户提出的问题,此时回答中可以不带引用编号。
- 除了代码和特定的名称和引用外,您的答案必须使用与问题相同的语言来撰写。
""".lstrip()
template = """
[引用]
{context}
问:{query}
""".lstrip()
CUSTOM_PROMPT_TEMPLATE = PromptTemplate(template, prompt_type='text_qa')
import os
import faiss
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.legacy.vector_stores import FaissVectorStore
from models.embedding import GLMEmbeddings
from tqdm import tqdm
from utils.data import split_into_chunks
embed_model = GLMEmbeddings()
def save_vectors(files: list[str], args):
# split file into chunks
nodes = []
for file in tqdm(files, desc="文件切分"):
nodes.extend(split_into_chunks(file, args.lines_per_chunk, args.lines_overlap, args.max_chars))
# initialize vector store
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(embed_model.embedding_size))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# translate to vectors
index = VectorStoreIndex(nodes=nodes, storage_context=storage_context, embed_model=embed_model)
# save embedded vectors
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)
index.storage_context.persist(persist_dir=output_path)
print(f"文件向量化完成,已保存至{output_path}")
def load_vectors(vector_path: str):
vector_store = FaissVectorStore.from_persist_dir(vector_path)
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=vector_path)
return load_index_from_storage(storage_context=storage_context)
import argparse
from utils.data import traverse
from utils.vector import save_vectors
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--workspace', type=str, help="directory of the workspace to be vectorized", default='.')
parser.add_argument('--lines_per_chunk', type=int, help="chunk lines when splitting", default=40)
parser.add_argument('--lines_overlap', type=int, help="chunk lines overlap when splitting", default=15)
parser.add_argument("--max_chars", type=int, help="maximum number of characters in a chunk", default=1500)
parser.add_argument('--output_path', type=str, help="path to save the vectors", default='vectors')
return parser.parse_args()
if __name__ == '__main__':
args = parse_arguments()
files = traverse(args.workspace)
save_vectors(files, args)
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## Local Mode
The new version of the CodeGeeX plugin **supports offline mode**, allowing locally deployed models to provide code completion and simple chat.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd local_mode
pip install -r requirements.txt
```
### 2. Run the Project
```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```
### 3. Set API Address and Model Name
As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.
![](resources/pic1.png)
### 4. Start Using
Click 'Connect' to test the connection, or click 'Ask CodeGeeX' to start using it.
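The local server started in step 2 exposes an OpenAI-style `/v1/chat/completions` endpoint (see `main.py`), so it can also be called without the plugin. A minimal non-streaming sketch using `requests`; the prompt text is illustrative, and `presence_penalty` is set explicitly because the server forwards it as the repetition penalty:
```python
import requests

payload = {
    "model": "codegeex4",
    "messages": [{"role": "user", "content": "Write a Python function that reverses a string."}],
    "stream": False,          # request a single JSON response instead of server-sent events
    "temperature": 0.2,
    "presence_penalty": 1.0,  # forwarded to the model as repetition_penalty
}
resp = requests.post("http://127.0.0.1:8080/v1/chat/completions", json=payload, timeout=300)
print(resp.json()["choices"][0]["message"]["content"])
```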
## Demo
![](resources/demo.gif)
![](../resources/logo.jpeg)
[English](README.md) | [中文](README_zh.md)
## Local Mode
The new version of the CodeGeeX plugin **supports offline mode**, allowing locally deployed models to provide code completion and simple chat.
## Usage Tutorial
### 1. Install Dependencies
```bash
cd local_mode
pip install -r requirements.txt
```
### 2. Run the Project
```bash
python main.py --model_name_or_path THUDM/codegeex4-all-9b --device cuda --bf16 true
>>> Running on local URL: http://127.0.0.1:8080
```
### 3. Set the API Address and Model Name
As shown in the figure below, open the plugin, switch to local mode, and enter the API address and model name in the settings.
![](resources/pic1.png)
### 4. Start Using
Click 'Connect' to test the connection, or click 'Ask CodeGeeX' to start using it.
## Demo
![](resources/demo_zh.gif)
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe: FastAPI entry point exposing an OpenAI-compatible /v1/chat/completions endpoint.
"""
import argparse
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.responses import StreamingResponse
from protocols.openai_api import ChatCompletionRequest
from services.chat import init_model, chat_with_codegeex, stream_chat_with_codegeex
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name_or_path", type=str, default="THUDM/codegeex4-all-9b")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
parser.add_argument("--bf16", type=bool, default=False)
return parser.parse_args()
@app.post("/v1/chat/completions")
async def chat(request: ChatCompletionRequest):
try:
if request.stream:
return StreamingResponse(stream_chat_with_codegeex(request), media_type="text/event-stream")
else:
return JSONResponse(chat_with_codegeex(request))
except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
if __name__ == "__main__":
args = parse_arguments()
init_model(args)
uvicorn.run(app, host="127.0.0.1", port=8080)
"""
coding : utf-8
@Date : 2024/7/10
@Author : Shaobo
@Describe: Wraps the CodeGeeX4 model and implements non-streaming and streaming chat.
"""
import torch
from protocols.openai_api import ChatCompletionRequest, ChatCompletionStreamResponse, ChatCompletionResponse
from sseclient import Event
from transformers import AutoTokenizer, AutoModel
class CodegeexChatModel:
def __init__(self, args):
self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
if args.bf16:
self.model = AutoModel.from_pretrained(
args.model_name_or_path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
).to(args.device).eval()
else:
self.model = AutoModel.from_pretrained(
args.model_name_or_path,
trust_remote_code=True
).to(args.device).eval()
print("Model is initialized.")
def stream_chat(self, request: ChatCompletionRequest):
try:
inputs = self.tokenizer.apply_chat_template(
conversation=[msg.model_dump() for msg in request.messages],
add_generation_prompt=True,
return_tensors="pt",
return_dict=True
).to(self.model.device)
gen_configs = {
"max_new_tokens": request.max_tokens,
"temperature": request.temperature,
"top_p": request.top_p,
"repetition_penalty": request.presence_penalty,
"do_sample": True if request.temperature else request.temperature,
}
length = 0
for i, outputs in enumerate(self.model.stream_generate(**inputs, **gen_configs)):
response = self.tokenizer.decode(outputs.tolist()[0][len(inputs["input_ids"][0]):-1])
if not response or response[-1] == "�":
continue
resp = ChatCompletionStreamResponse()
resp.choices[0].index = i
resp.choices[0].delta.content = response[length:]
                event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
yield event.dump()
length = len(response)
resp = ChatCompletionStreamResponse()
resp.choices[0].finish_reason = 'stop'
            event = Event(id=resp.id, data=resp.model_dump_json(), event='message')
yield event.dump()
except Exception as e:
resp = ChatCompletionStreamResponse()
resp.choices[0].finish_reason = 'stop'
event = Event(id=resp.id, data=f"请求报错,错误原因:{e}", event='message')
yield event.dump()
def chat(self, request: ChatCompletionRequest):
try:
response, _ = self.model.chat(
self.tokenizer,
query=request.messages[-1].content,
history=[msg.model_dump() for msg in request.messages[:-1]],
max_new_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
repetition_penalty=request.presence_penalty
)
resp = ChatCompletionResponse()
resp.choices[0].message.content = response
resp.choices[0].finish_reason = 'stop'
return resp.model_dump()
except Exception as e:
return f"请求报错,错误原因:{e}"
"""
coding : utf-8
@Date : 2024/7/11
@Author : Shaobo
@Describe:
"""
import time
from typing import Literal, Optional
import shortuuid
from pydantic import BaseModel, Field
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
model: str = "codegeex4"
messages: list[ChatMessage]
temperature: float = 0.2
top_p: float = 1.0
max_tokens: int = 1024
stop: list[str] = ['<|user|>', '<|assistant|>', '<|observation|>', '<|endoftext|>']
stream: bool = True
    presence_penalty: Optional[float] = None
class DeltaMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseStreamChoice(BaseModel):
    index: int = 0
    delta: DeltaMessage = Field(default_factory=lambda: DeltaMessage(role='assistant', content=''))
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    # Use default_factory so each response gets a fresh id, timestamp, and choices list
    # instead of values computed once at import time.
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseStreamChoice] = Field(
        default_factory=lambda: [ChatCompletionResponseStreamChoice()]
    )
class ChatCompletionResponseChoice(BaseModel):
    index: int = 0
    message: ChatMessage = Field(default_factory=lambda: ChatMessage(role="assistant", content=""))
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = "codegeex4"
    choices: list[ChatCompletionResponseChoice] = Field(
        default_factory=lambda: [ChatCompletionResponseChoice()]
    )
# usage: UsageInfo
accelerate==0.31.0
fastapi==0.111.0
openai==1.35.12
pydantic==2.8.2
regex==2024.5.15
requests==2.32.3
shortuuid==1.0.13
sseclient==0.0.27
starlette==0.37.2
tiktoken==0.7.0
torch==2.3.1
transformers==4.39.0
uvicorn==0.30.1