"container/vscode:/vscode.git/clone" did not exist on "cbd20c30067cff4b72b2c37fd715f2a47f79d09c"
Unverified commit e53e729d authored by YeAnbang, committed by GitHub

[Feature] Add document retrieval QA (#5020)



* add langchain

* add langchain

* Add files via upload

* add langchain

* fix style

* fix style: remove extra space

* add pytest; modified retriever

* add pytest; modified retriever

* add tests to build_on_pr.yml

* fix build_on_pr.yml

* fix build on pr; fix environ vars

* separate unit tests for colossalqa from build on pr

* fix container setting; fix environ vars

* commented dev code

* add incremental update

* remove stale code

* fix style

* change to sha3 224

* fix retriever; fix style; add unit test for document loader

* fix ci workflow config

* fix ci workflow config

* add set cuda visible device script in ci

* fix doc string

* fix style; update readme; refactored

* add force log info

* change build on pr, ignore colossalqa

* fix docstring, capitalize all initial letters

* fix indexing; fix text-splitter

* remove debug code, update reference

* reset previous commit

* update LICENSE, update README, add key-value mode, fix bugs

* add files back

* revert force push

* remove junk file

* add test files

* fix retriever bug, add intent classification

* change conversation chain design

* rewrite prompt and conversation chain

* add ui v1

* ui v1

* fix avatar

* add header

* Refactor the RAG Code and support Pangu

* Refactor the ColossalQA chain to Object-Oriented Programming and the UI demo.

* resolved conversation. tested scripts under examples. web demo still buggy

* fix ci tests

* Some modifications to add ChatGPT api

* modify llm.py and remove unnecessary files

* Delete applications/ColossalQA/examples/ui/test_frontend_input.json

* Remove OpenAI api key

* add colossalqa

* move files

* move files

* move files

* move files

* fix style

* Add Readme and fix some bugs.

* Add something to readme and modify some code

* modify a directory name for clarity

* remove redundant directory

* Correct a typo in llm.py

* fix AI prefix

* fix test_memory.py

* fix conversation

* fix some errors and typos

* Fix a missing import in RAG_ChatBot.py

* add colossalcloud LLM wrapper, correct issues in code review

---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Orion-Zheng <zheng_zian@u.nus.edu>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Orion-Zheng <zhengzian@u.nus.edu>
parent 3acbf6d4
name: Run colossalqa unit tests

on:
  pull_request:
    types: [synchronize, opened, reopened]
    paths:
      - 'applications/ColossalQA/colossalqa/**'
      - 'applications/ColossalQA/requirements.txt'
      - 'applications/ColossalQA/setup.py'
      - 'applications/ColossalQA/tests/**'
      - 'applications/ColossalQA/pytest.ini'

jobs:
  tests:
    name: Run colossalqa unit tests
    if: |
      github.event.pull_request.draft == false &&
      github.base_ref == 'main' &&
      github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
    runs-on: [self-hosted, gpu]
    container:
      image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
      volumes:
        - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
        - /data/scratch/llama-tiny:/data/scratch/llama-tiny
      options: --gpus all --rm
    timeout-minutes: 30
    defaults:
      run:
        shell: bash
    steps:
      - name: Checkout ColossalAI
        uses: actions/checkout@v2
      - name: Install colossalqa
        run: |
          cd applications/ColossalQA
          pip install -e .
      - name: Execute Unit Testing
        run: |
          cd applications/ColossalQA
          pytest tests/
        env:
          NCCL_SHM_DISABLE: 1
          MAX_JOBS: 8
          ZH_MODEL_PATH: bigscience/bloom-560m
          ZH_MODEL_NAME: bloom
          EN_MODEL_PATH: bigscience/bloom-560m
          EN_MODEL_NAME: bloom
          TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt
          TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt
          TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/*
          SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path
@@ -527,3 +527,28 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
---------------- LICENSE FOR LangChain TEAM ----------------
The MIT License
Copyright (c) Harrison Chase
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/.build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# IDE
.idea/
.vscode/
# macos
*.DS_Store
#data/
docs/.build
# pytorch checkpoint
*.pt
# sql
*.db
# wandb log
example/wandb/
example/ui/gradio/
example/vector_db_for_test
examples/awesome-chatgpt-prompts/
# ColossalQA - Langchain-based Document Retrieval Conversation System
## Table of Contents
- [Table of Contents](#table-of-contents)
- [Overall Implementation](#overall-implementation)
- [Install](#install)
- [How to Use](#how-to-use)
- Examples
- [A Simple Web UI Demo](examples/webui_demo/README.md)
- [Local Chinese Retrieval QA + Chat](examples/retrieval_conversation_zh.py)
- [Local English Retrieval QA + Chat](examples/retrieval_conversation_en.py)
- [Local Bi-lingual Retrieval QA + Chat](examples/retrieval_conversation_universal.py)
- [Experimental AI Agent Based on ChatGPT + Chat](examples/conversation_agent_chatgpt.py)
- Use cases
- [English customer service chatbot](examples/retrieval_conversation_en_customer_service.py)
- [Chinese customer service intent classification](examples/retrieval_intent_classification_zh_customer_service.py)
**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.**
## Overall Implementation
### High-level Design
![Alt text](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/diagram.png "Fig.1. Design of the document retrieval conversation system")
<p align="center">
Fig.1. Design of the document retrieval conversation system
</p>
Retrieval-based Question Answering (QA) is a crucial application of natural language processing that aims to find the most relevant answers based on the information from a corpus of text documents in response to user queries. Vector stores, which represent documents and queries as vectors in a high-dimensional space, have gained popularity for their effectiveness in retrieval QA tasks.
#### Step 1: Collect Data
A successful retrieval QA system starts with high-quality data. You need a collection of text documents that are relevant to your application. You may also need to manually design how your data will be presented to the language model.
#### Step 2: Split Data
Document data is usually too long to fit into the prompt due to the context length limitation of LLMs. Supporting documents need to be split into short chunks before constructing vector stores. In this demo, we use a neural text splitter for better performance.
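Below is a minimal sketch of this step using LangChain's generic `RecursiveCharacterTextSplitter` as a stand-in for the neural text splitter used in the demo; the chunk size and overlap are illustrative.
```python
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

# A toy long document; in practice these come from your document loader
docs = [Document(page_content="A very long supporting document ... " * 200)]

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = splitter.split_documents(docs)  # short chunks ready for embedding
print(len(chunks))
```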
#### Step 3: Construct Vector Stores
Choose an embedding function and embed your text chunks into high-dimensional vectors. Once you have vectors for your documents, you need to create a vector store. The vector store should efficiently index and retrieve documents based on vector similarity. In this demo, we use [Chroma](https://python.langchain.com/docs/integrations/vectorstores/chroma) and incrementally update the indexes of the vector store. Through incremental update, one can update and maintain a vector store without recalculating every embedding.
You are free to choose from the wide variety of [vector stores](https://python.langchain.com/docs/integrations/vectorstores/) supported by LangChain. However, incremental update only works with LangChain vector stores that support:
- Document addition by id (add_documents method with ids argument)
- Deletion by id (delete method with ids argument)
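Below is a minimal sketch of incremental indexing into Chroma using LangChain's generic indexing API (`SQLRecordManager` plus `index`); the collection name, record-manager database, and documents are illustrative, and this is not necessarily the exact mechanism used by ColossalQA's own retriever.
```python
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.vectorstores import Chroma

embedding = HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
vector_store = Chroma(collection_name="demo", embedding_function=embedding)

# The record manager tracks document hashes, so unchanged chunks are not re-embedded
record_manager = SQLRecordManager("chroma/demo", db_url="sqlite:///record_manager.db")
record_manager.create_schema()

docs = [
    Document(page_content="Colossal-AI is a unified deep learning system.", metadata={"source": "intro.txt"}),
]
# Re-running index() on unchanged documents does no redundant embedding work;
# cleanup="incremental" also removes stale chunks from updated sources
index(docs, record_manager, vector_store, cleanup="incremental", source_id_key="source")
```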
#### Step 4: Retrieve Relevant Text
Upon querying, we first run reference resolution on the user's input; the goal of this step is to resolve ambiguous references in the query, such as "this company" or "him". We then embed the resolved query with the same embedding function and query the vector store to retrieve the top-k most similar documents.
#### Step 5: Format Prompt
The prompt carries essential information, including the task description, conversation history, retrieved documents, and the user's query, for the LLM to generate a response. Please refer to this [README](./colossalqa/prompt/README.md) for more details.
#### Step 6: Inference
Pass the prompt to the LLM with additional generation arguments to get the agent's response. You can control generation with arguments such as temperature, top_k, top_p, and max_new_tokens. You can also define when to stop by passing stop substrings to the retrieval QA chain.
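Putting Steps 4-6 together, here is a hypothetical sketch with generic LangChain pieces; `vector_store` is assumed to be the store built in Step 3, `llm` is any of the LLM wrappers described below, and the question and prompt template are illustrative.
```python
from langchain.prompts import PromptTemplate

# Step 4: embed the (reference-resolved) query and retrieve the top-k chunks
retriever = vector_store.as_retriever(search_kwargs={"k": 3})
docs = retriever.get_relevant_documents("What does the company sell?")

# Step 5: format the prompt with the retrieved context and the user's query
prompt = PromptTemplate.from_template(
    "Use the context to answer the question.\nContext:\n{context}\n\nQuestion: {question}\nAnswer:"
)
context = "\n".join(doc.page_content for doc in docs)

# Step 6: generate, controlling decoding with extra generation arguments
answer = llm(prompt.format(context=context, question="What does the company sell?"), max_new_tokens=128)
```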
#### Step 7: Update Memory
We designed a memory module that automatically summarizes overlength conversations to fit the maximum context length of the LLM. In this step, we update the memory with the newly generated response. To fit into the context length of a given LLM, we summarize the overlength part of the historical conversation and present the rest in a round-based conversation format. Fig.2 shows how the memory is updated. Please refer to this [README](./colossalqa/prompt/README.md) for the dialogue format.
![Alt text](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/memory.png "Fig.2. Design of the memory module")
<p align="center">
Fig.2. Design of the memory module
</p>
### Supported Language Models (LLMs) and Embedding Models
Our platform accommodates two kinds of LLMs: API-accessible and locally run models. For the API-style LLMs, we support ChatGPT, Pangu, and models deployed through the vLLM API Server. For locally operated LLMs, we are compatible with any language model that can be initiated using [`transformers.AutoModel.from_pretrained`](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#transformers.AutoModel.from_pretrained). However, due to the dependence of retrieval-based QA on the language model's abilities in zero-shot learning, instruction following, and logical reasoning, smaller models are typically not advised. In our local demo, we utilize ChatGLM2 for Chinese and LLaMa2 for English. Modifying the base LLM requires corresponding adjustments to the prompts.
Here is some sample code for loading different types of LLMs.
```python
# For locally-run LLM
from colossalqa.local.llm import ColossalAPI, ColossalLLM
api = ColossalAPI('chatglm2', 'path_to_chatglm2_checkpoint')
llm = ColossalLLM(n=1, api=api)
# For LLMs running on the vLLM API Server
from colossalqa.local.llm import VllmAPI, VllmLLM
vllm_api = VllmAPI("Your_vLLM_Host", "Your_vLLM_Port")
llm = VllmLLM(n=1, api=vllm_api)
# For ChatGPT LLM
from langchain.llms import OpenAI
llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
# For Pangu LLM
# set up your authentication info
import os
from colossalqa.local.pangu_llm import Pangu
os.environ["URL"] = ""
os.environ["USERNAME"] = ""
os.environ["PASSWORD"] = ""
os.environ["DOMAIN_NAME"] = ""
llm = Pangu(id=1)
llm.set_auth_config()
```
Regarding embedding models, we support all models that can be loaded via ["langchain.embeddings.HuggingFaceEmbeddings"](https://api.python.langchain.com/en/latest/embeddings/langchain.embeddings.huggingface.HuggingFaceEmbeddings.html). The default embedding model used in this demo is ["moka-ai/m3e-base"](https://huggingface.co/moka-ai/m3e-base), which enables consistent text similarity computations in both Chinese and English.
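For example, the default embedding model can be loaded as follows (a minimal sketch; the model is downloaded from Hugging Face on first use):
```python
from langchain.embeddings import HuggingFaceEmbeddings

embedding = HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
vector = embedding.embed_query("文档检索问答")  # works for both Chinese and English text
print(len(vector))
```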
In the future, supported LLMs will also include models running on the Colossal-AI inference and serving framework.
## Install
Install colossalqa
```bash
# python==3.8.17
cd ColossalAI/applications/ColossalQA
pip install -e .
```
To use vLLM to provide LLM services via an API, please consult the official guide [here](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#api-server) to start the API server. It is important to set up a new virtual environment for installing vLLM, as there are currently some dependency conflicts between vLLM and ColossalQA when installed on the same machine.
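For reference, an API server can typically be started like this once vLLM is installed in its own environment (a sketch; the model path and port are illustrative, and the exact flags depend on your vLLM version):
```bash
# run inside the dedicated vLLM virtual environment
python -m vllm.entrypoints.api_server \
    --model /path/to/Llama-2-7b-hf \
    --host 0.0.0.0 \
    --port 8077
```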
## How to Use
### Collect Your Data
For the ChatGPT-based agent, we support document retrieval and simple SQL search.
If you want to run the demo locally, we provide a document-retrieval-based conversation system built upon LangChain. It accepts a wide range of documents. After collecting your data, put it under a folder.
Read the comments under ./colossalqa/data_loader for more details regarding supported data formats.
### Run The Script
We provide a simple Web UI demo of ColossalQA, enabling you to upload your files as a knowledge base and interact with them through a chat interface in your browser. More details can be found [here](examples/webui_demo/README.md)
![ColossalQA Demo](https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/colossalqa/ui.png)
We also provide scripts for a Chinese document-retrieval-based conversation system, an English document-retrieval-based conversation system, a bilingual document-retrieval-based conversation system, and an experimental AI agent with document retrieval and SQL query functionality. The bilingual one is a high-level wrapper for the other two classes. We provide different scripts for different languages because retrieval QA requires different embedding models, LLMs, and prompts for each language setting. For now, we use LLaMa2 for English retrieval QA and ChatGLM2 for Chinese retrieval QA for better performance.
To run the bilingual script:
```bash
python retrieval_conversation_universal.py \
--en_model_path /path/to/Llama-2-7b-hf \
--zh_model_path /path/to/chatglm2-6b \
--zh_model_name chatglm2 \
--en_model_name llama \
--sql_file_path /path/to/any/folder
```
To run retrieval_conversation_en.py:
```bash
python retrieval_conversation_en.py \
--model_path /path/to/Llama-2-7b-hf \
--model_name llama \
--sql_file_path /path/to/any/folder
```
To run retrieval_conversation_zh.py:
```bash
python retrieval_conversation_zh.py \
--model_path /path/to/chatglm2-6b \
--model_name chatglm2 \
--sql_file_path /path/to/any/folder
```
To run retrieval_conversation_chatgpt.py:
```bash
python retrieval_conversation_chatgpt.py \
--open_ai_key_path /path/to/plain/text/openai/key/file \
--sql_file_path /path/to/any/folder
```
To run conversation_agent_chatgpt.py:
```bash
python conversation_agent_chatgpt.py \
--open_ai_key_path /path/to/plain/text/openai/key/file
```
When you run a script, it will ask you to provide the path to your data during execution. You can also pass a glob path to load multiple files at once; please read this [guide](https://docs.python.org/3/library/glob.html) on how to define a glob path. Follow the instructions, provide all files for your retrieval conversation system, and then type "ESC" to finish loading documents. If CSV files are provided, please use "," as the delimiter and "\"" as the quotation mark. For JSON and JSONL files, the default format is:
```
{
"data":[
{"content":"XXX"},
{"content":"XXX"}
...
]
}
```
For other formats, please refer to [this document](https://python.langchain.com/docs/modules/data_connection/document_loaders/json) on how to define a schema for data loading. There are no other formatting constraints for document-type files. For table-type files we use pandas; please refer to [Pandas-Input/Output](https://pandas.pydata.org/pandas-docs/stable/reference/io.html) for file format details.
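For example, a custom JSON layout can usually be handled by passing a different `jq_schema` (and `content_key`) to the underlying JSON loader; the layout below is hypothetical.
```python
from langchain.document_loaders import JSONLoader

# Hypothetical layout: {"records": [{"text": "..."}, {"text": "..."}]}
loader = JSONLoader(file_path="my_data.json", jq_schema=".records[]", content_key="text")
documents = loader.load()
```
The same `jq_schema` and `content_key` keyword arguments can be passed through our DocumentLoader.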
We also support a key-value mode that utilizes a user-defined key to calculate the embeddings of the vector store. If a query matches a specific key, the value corresponding to that key will be used to generate the prompt. For instance, in the document below, "My coupon isn't working." will be used during indexing, whereas "Question: My coupon isn't working.\nAnswer: We apologize for ... apply it to?" will appear in the final prompt. This format is typically useful when the task involves carrying on a conversation with readily accessible conversation data, such as customer service or question answering.
```python
Document(page_content="My coupon isn't working.", metadata={'is_key_value_mapping': True, 'seq_num': 36, 'source': 'XXX.json', 'value': "Question: My coupon isn't working.\nAnswer:We apologize for the inconvenience. Can you please provide the coupon code and the product name or SKU you're trying to apply it to?"})
```
For now, we only support the key-value mode for JSON data files. You can run the script retrieval_conversation_en_customer_service.py with the following command:
```bash
python retrieval_conversation_en_customer_service.py \
--model_path /path/to/Llama-2-7b-hf \
--model_name llama \
--sql_file_path /path/to/any/folder
```
## The Plan
- [x] Build document retrieval QA tool
- [x] Add memory
- [x] Add demo for AI agent with SQL query
- [x] Add custom retriever for fast construction and retrieval (with incremental update)
## Reference
```bibtex
@software{Chase_LangChain_2022,
author = {Chase, Harrison},
month = oct,
title = {{LangChain}},
url = {https://github.com/hwchase17/langchain},
year = {2022}
}
```
```bibtex
@inproceedings{DBLP:conf/asru/ZhangCLLW21,
author = {Qinglin Zhang and
Qian Chen and
Yali Li and
Jiaqing Liu and
Wen Wang},
title = {Sequence Model with Self-Adaptive Sliding Window for Efficient Spoken
Document Segmentation},
booktitle = {{IEEE} Automatic Speech Recognition and Understanding Workshop, {ASRU}
2021, Cartagena, Colombia, December 13-17, 2021},
pages = {411--418},
publisher = {{IEEE}},
year = {2021},
url = {https://doi.org/10.1109/ASRU51503.2021.9688078},
doi = {10.1109/ASRU51503.2021.9688078},
timestamp = {Wed, 09 Feb 2022 09:03:04 +0100},
biburl = {https://dblp.org/rec/conf/asru/ZhangCLLW21.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
```bibtex
@misc{touvron2023llama,
title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
year={2023},
eprint={2307.09288},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
```bibtex
@article{zeng2022glm,
title={Glm-130b: An open bilingual pre-trained model},
author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
journal={arXiv preprint arXiv:2210.02414},
year={2022}
}
```
```bibtex
@inproceedings{du2022glm,
title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
pages={320--335},
year={2022}
}
```
"""
Custom SummarizerMixin base class and ConversationSummaryMemory class
Modified from Original Source
This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
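Usage:
    # Hypothetical example; `llm` can be any LangChain-compatible LLM wrapper
    memory = ConversationSummaryMemory(llm=llm, ai_prefix="Assistant")
    memory.save_context({"question": "What is ColossalQA?"}, {"output": "A document retrieval QA system."})
    print(memory.load_memory_variables({})["history"])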
"""
from __future__ import annotations
from typing import Any, Dict, List, Type
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import BaseChatMessageHistory, BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
class SummarizerMixin(BaseModel):
"""
Mixin for summarizer.
"""
human_prefix: str = "Human"
ai_prefix: str = "Assistant"
llm: BaseLanguageModel
prompt: BasePromptTemplate = SUMMARY_PROMPT
summary_message_cls: Type[BaseMessage] = SystemMessage
llm_kwargs: Dict = {}
def predict_new_summary(self, messages: List[BaseMessage], existing_summary: str, stop: List = []) -> str:
"""
Recursively summarize a conversation by generating a new summary using
the last round of conversation and the existing summary.
"""
new_lines = get_buffer_string(
messages,
human_prefix=self.human_prefix,
ai_prefix=self.ai_prefix,
)
chain = LLMChain(llm=self.llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
return chain.predict(summary=existing_summary, new_lines=new_lines, stop=stop)
class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""Conversation summarizer to chat memory."""
buffer: str = ""
memory_key: str = "history"
@classmethod
def from_messages(
cls,
llm: BaseLanguageModel,
chat_memory: BaseChatMessageHistory,
summarize_step: int = 2,
**kwargs: Any,
) -> ConversationSummaryMemory:
obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
for i in range(0, len(obj.chat_memory.messages), summarize_step):
obj.buffer = obj.predict_new_summary(obj.chat_memory.messages[i : i + summarize_step], obj.buffer)
return obj
@property
def memory_variables(self) -> List[str]:
"""Will always return list of memory variables."""
return [self.memory_key]
def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Return history buffer."""
if self.return_messages:
buffer: Any = [self.summary_message_cls(content=self.buffer)]
else:
buffer = self.buffer
return {self.memory_key: buffer}
@root_validator()
def validate_prompt_input_variables(cls, values: Dict) -> Dict:
"""Validate that prompt input variables are consistent."""
prompt_variables = values["prompt"].input_variables
expected_keys = {"summary", "new_lines"}
if expected_keys != set(prompt_variables):
raise ValueError(
"Got unexpected prompt input variables. The prompt expects "
f"{prompt_variables}, but it should have {expected_keys}."
)
return values
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
"""Save context from this conversation to buffer."""
super().save_context(inputs, outputs)
self.buffer = self.predict_new_summary(self.chat_memory.messages[-2:], self.buffer)
def clear(self) -> None:
"""Clear memory contents."""
super().clear()
self.buffer = ""
"""
Chain for question-answering against a vector database.
Modified from Original Source
This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
from __future__ import annotations
import copy
import inspect
from typing import Any, Dict, List, Optional
from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.prompts import PromptTemplate
from langchain.pydantic_v1 import Field
from langchain.schema import BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
class CustomBaseRetrievalQA(BaseRetrievalQA):
"""Base class for question-answering chains."""
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Initialize from LLM."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
document_prompt = kwargs.get(
"document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
)
combine_documents_chain = CustomStuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
callbacks=callbacks,
)
return cls(
combine_documents_chain=combine_documents_chain,
callbacks=callbacks,
**kwargs,
)
@classmethod
def from_chain_type(
cls,
llm: BaseLanguageModel,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Load chain from chain type."""
llm_kwargs = kwargs.pop("llm_kwargs", {})
_chain_type_kwargs = chain_type_kwargs or {}
combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
if accepts_run_manager:
docs = self._get_docs(question, run_manager=_run_manager)
else:
docs = self._get_docs(question) # type: ignore[call-arg]
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answers = []
if self.combine_documents_chain.memory is not None:
buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
self.combine_documents_chain.memory.buffered_history
), copy.deepcopy(self.combine_documents_chain.memory.summarized_history_temp)
else:
buffered_history_backup = None
summarized_history_temp_backup = None
answer = self.combine_documents_chain.run(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
if summarized_history_temp_backup is not None and buffered_history_backup is not None:
(
self.combine_documents_chain.memory.buffered_history,
self.combine_documents_chain.memory.summarized_history_temp,
) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
# if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
if self.combine_documents_chain.memory is not None:
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run get_relevant_text and llm on input query.
If chain has 'return_source_documents' as 'True', returns
the retrieved documents as well under the key 'source_documents'.
Example:
.. code-block:: python
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
if accepts_run_manager:
docs = await self._aget_docs(question, run_manager=_run_manager)
else:
docs = await self._aget_docs(question) # type: ignore[call-arg]
kwargs = {
k: v
for k, v in inputs.items()
if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
}
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
)
# if rejection_trigger_keywords is not given, return the response from LLM directly
rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None
if answer is None:
answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
return {self.output_key: answer}
class RetrievalQA(CustomBaseRetrievalQA):
"""Chain for question-answering against an index.
Example:
.. code-block:: python
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.faiss import FAISS
from langchain.vectorstores.base import VectorStoreRetriever
retriever = VectorStoreRetriever(vectorstore=FAISS(...))
retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
"""
retriever: BaseRetriever = Field(exclude=True)
def _get_docs(
self,
question: str,
*,
run_manager: CallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())
async def _aget_docs(
self,
question: str,
*,
run_manager: AsyncCallbackManagerForChainRun,
) -> List[Document]:
"""Get docs."""
return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())
@property
def _chain_type(self) -> str:
"""Return the chain type."""
return "retrieval_qa"
"""
Load question answering chains.
For now, only the stuff chain is modified
Modified from Original Source
This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, Mapping, Optional, Protocol
from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import stuff_prompt
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.prompt_template import BasePromptTemplate
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLanguageModel, **kwargs: Any) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_stuff_chain(
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> CustomStuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
if "llm_kwargs" in kwargs:
llm_kwargs = copy.deepcopy(kwargs["llm_kwargs"])
del kwargs["llm_kwargs"]
else:
llm_kwargs = {}
llm_chain = LLMChain(
llm=llm,
prompt=_prompt,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
llm_kwargs=llm_kwargs,
)
return CustomStuffDocumentsChain(
llm_chain=llm_chain,
document_variable_name=document_variable_name,
verbose=verbose,
callback_manager=callback_manager,
callbacks=callbacks,
**kwargs,
)
def load_qa_chain(
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> BaseCombineDocumentsChain:
"""Load question answering chain.
Args:
llm: Language Model to use in the chain.
chain_type: Type of document combining chain to use. Should be one of "stuff",
"map_reduce", "map_rerank", and "refine".
verbose: Whether chains should be run in verbose mode or not. Note that this
applies to all chains that make up the final chain.
callback_manager: Callback manager to use for the chain.
Returns:
A chain to use for question answering.
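Example:
    .. code-block:: python

        # Hypothetical usage; `llm` is any LangChain language model and
        # `docs` is a list of langchain Documents
        chain = load_qa_chain(llm, chain_type="stuff")
        answer = chain.run(input_documents=docs, question="What is Colossal-AI?")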
"""
loader_mapping: Mapping[str, LoadingCallable] = {"stuff": _load_stuff_chain}
if chain_type not in loader_mapping:
raise ValueError(f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}")
return loader_mapping[chain_type](llm, verbose=verbose, callback_manager=callback_manager, **kwargs)
"""
Chain that combines documents by stuffing into context
Modified from Original Source
This code is based on LangChain Ai's langchain, which can be found at
https://github.com/langchain-ai/langchain
The original code is licensed under the MIT license.
"""
import copy
from typing import Any, List
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema import format_document
class CustomStuffDocumentsChain(StuffDocumentsChain):
"""Chain that combines documents by stuffing into context.
This chain takes a list of documents and first combines them into a single string.
It does this by formatting each document into a string with the `document_prompt`
and then joining them together with `document_separator`. It then adds that new
string to the inputs with the variable name set by `document_variable_name`.
Those inputs are then passed to the `llm_chain`.
Example:
.. code-block:: python
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
# This controls how each document will be formatted. Specifically,
# it will be passed to `format_document` - see that function for more
# details.
document_prompt = PromptTemplate(
input_variables=["page_content"],
template="{page_content}"
)
document_variable_name = "context"
llm = OpenAI()
# The prompt here should take as an input variable the
# `document_variable_name`
prompt = PromptTemplate.from_template(
"Summarize this content: {context}"
)
llm_chain = LLMChain(llm=llm, prompt=prompt)
chain = StuffDocumentsChain(
llm_chain=llm_chain,
document_prompt=document_prompt,
document_variable_name=document_variable_name
)
"""
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
"""Construct inputs from kwargs and docs.
Format and then join all the documents together into one input with name
`self.document_variable_name`. Then pluck any additional variables
from **kwargs.
Args:
docs: List of documents to format and then join into single input
**kwargs: additional inputs to chain, will pluck any other required
arguments from here.
Returns:
dictionary of inputs to LLMChain
"""
# Format each document according to the prompt
# If the document is in the key-value format (i.e. it has 'is_key_value_mapping'=True and a 'value' field in its metadata),
# use the value to replace the key (the original page content)
doc_prefix = kwargs.get("doc_prefix", "Supporting Document")
docs_ = []
for id, doc in enumerate(docs):
doc_ = copy.deepcopy(doc)
if doc_.metadata.get("is_key_value_mapping", False) and "value" in doc_.metadata:
doc_.page_content = str(doc_.metadata["value"])
prefix = doc_prefix + str(id)
doc_.page_content = str(prefix + ":" + (" " if doc_.page_content[0] != " " else "") + doc_.page_content)
docs_.append(doc_)
doc_strings = [format_document(doc, self.document_prompt) for doc in docs_]
arg_list = ["stop", "temperature", "top_k", "top_p", "max_new_tokens"]
arg_list.extend(self.llm_chain.prompt.input_variables)
# Join the documents together to put them in the prompt.
inputs = {k: v for k, v in kwargs.items() if k in arg_list}
inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
return inputs
"""
Class for loading document type data
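Usage:
    # Hypothetical example; the glob path is illustrative
    loader = DocumentLoader([["./data/*.txt", "my_docs"]])
    docs = loader.all_data  # list of langchain Documents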
"""
import glob
from typing import List
from colossalqa.mylogging import get_logger
from langchain.document_loaders import (
JSONLoader,
PyPDFLoader,
TextLoader,
UnstructuredHTMLLoader,
UnstructuredMarkdownLoader,
)
from langchain.document_loaders.csv_loader import CSVLoader
logger = get_logger()
SUPPORTED_DATA_FORMAT = [".csv", ".json", ".html", ".md", ".pdf", ".txt", ".jsonl"]
class DocumentLoader:
"""
Load documents from different files into list of langchain Documents
"""
def __init__(self, files: List, **kwargs) -> None:
"""
Args:
files: list of files (list[file path, name])
**kwargs: keyword type arguments, useful for certain document types
"""
self.data = {}
self.kwargs = kwargs
for item in files:
path = item[0] if isinstance(item, list) else item
logger.info(f"Loading data from {path}")
self.load_data(path)
logger.info("Data loaded")
self.all_data = []
for key in self.data:
if isinstance(self.data[key], list):
for item in self.data[key]:
if isinstance(item, list):
self.all_data.extend(item)
else:
self.all_data.append(item)
def load_data(self, path: str) -> None:
"""
Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/
for specific format requirements.
Args:
path: path to a file
To load files with glob path, here are some examples.
Load all file from directory: folder1/folder2/*
Load all pdf file from directory: folder1/folder2/*.pdf
"""
files = []
# Handle glob expression
try:
files = glob.glob(path)
except Exception as e:
logger.error(e)
if len(files) == 0:
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
elif len(files) == 1:
path = files[0]
else:
for file in files:
self.load_data(file)
return
# Load data if the path is a file
logger.info(f"load {path}", verbose=True)
if path.endswith(".csv"):
# Load csv
loader = CSVLoader(file_path=path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith(".txt"):
# Load txt
loader = TextLoader(path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith("html"):
# Load html
loader = UnstructuredHTMLLoader(path, encoding="utf8")
data = loader.load()
self.data[path] = data
elif path.endswith("json"):
# Load json
loader = JSONLoader(
file_path=path,
jq_schema=self.kwargs.get("jq_schema", ".data[]"),
content_key=self.kwargs.get("content_key", "content"),
metadata_func=self.kwargs.get("metadata_func", None),
)
data = loader.load()
self.data[path] = data
elif path.endswith("jsonl"):
# Load jsonl
loader = JSONLoader(
file_path=path, jq_schema=self.kwargs.get("jq_schema", ".data[].content"), json_lines=True
)
data = loader.load()
self.data[path] = data
elif path.endswith(".md"):
# Load markdown
loader = UnstructuredMarkdownLoader(path)
data = loader.load()
self.data[path] = data
elif path.endswith(".pdf"):
# Load pdf
loader = PyPDFLoader(path)
data = loader.load_and_split()
self.data[path] = data
else:
if "." in path.split("/")[-1]:
raise ValueError(f"Unsupported file format {path}. Supported formats: {SUPPORTED_DATA_FORMAT}")
else:
# May be a directory; we strictly follow the glob path and will not load files in subdirectories
pass
'''
Class for loading table-type data. Please refer to Pandas-Input/Output for file format details.
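Usage:
    # Hypothetical example; the csv path and table name are illustrative
    table_loader = TableLoader([['./data/companies.csv', 'company']], sql_path='sqlite:///mydatabase.db')
    sql_path = table_loader.get_sql_path()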
'''
import os
import glob
import pandas as pd
from sqlalchemy import create_engine
from colossalqa.utils import drop_table
from colossalqa.mylogging import get_logger
logger = get_logger()
SUPPORTED_DATA_FORMAT = ['.csv','.xlsx', '.xls','.json','.html','.h5', '.hdf5','.parquet','.feather','.dta']
class TableLoader:
'''
Load tables from different files and serve a sql database for database operations
'''
def __init__(self, files: str,
sql_path:str='sqlite:///mydatabase.db',
verbose=False, **kwargs) -> None:
'''
Args:
files: list of files (list[file path, name])
sql_path: how to serve the sql database
**kwargs: keyword type arguments, useful for certain document types
'''
self.data = {}
self.verbose = verbose
self.sql_path = sql_path
self.kwargs = kwargs
self.sql_engine = create_engine(self.sql_path)
drop_table(self.sql_engine)
self.sql_engine = create_engine(self.sql_path)
for item in files:
path = item[0]
dataset_name = item[1]
if not os.path.exists(path):
raise FileNotFoundError(f"{path} doesn't exist")
if not any([path.endswith(i) for i in SUPPORTED_DATA_FORMAT]):
raise TypeError(f"{path} not supported. Supported type {SUPPORTED_DATA_FORMAT}")
logger.info("loading data", verbose=self.verbose)
self.load_data(path)
logger.info("data loaded", verbose=self.verbose)
self.to_sql(path, dataset_name)
def load_data(self, path):
'''
Load data and serve the data as sql database.
Data must be in pandas format
'''
files = []
# Handle glob expression
try:
files = glob.glob(path)
except Exception as e:
logger.error(e)
if len(files)==0:
raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
elif len(files)==1:
path = files[0]
else:
for file in files:
self.load_data(file)
if path.endswith('.csv'):
# Load csv
self.data[path] = pd.read_csv(path)
elif path.endswith('.xlsx') or path.endswith('.xls'):
# Load excel
self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
elif path.endswith('.json'):
# Load json
self.data[path] = pd.read_json(path)
elif path.endswith('.html'):
# Load html
html_tables = pd.read_html(path)
# Choose the desired table from the list of DataFrame objects
self.data[path] = html_tables[0] # You may need to adjust this index
elif path.endswith('.h5') or path.endswith('.hdf5'):
# Load h5
self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
elif path.endswith('.parquet'):
# Load parquet
self.data[path] = pd.read_parquet(path, engine='fastparquet')
elif path.endswith('.feather'):
# Load feather
self.data[path] = pd.read_feather(path)
elif path.endswith('.dta'):
# Load dta
self.data[path] = pd.read_stata(path)
else:
raise ValueError("Unsupported file format")
def to_sql(self, path, table_name):
'''
Serve the data as sql database.
'''
self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
return self.sql_path
def get_sql_path(self):
return self.sql_path
def __del__(self):
if self.sql_engine:
drop_table(self.sql_engine)
self.sql_engine.dispose()
del self.data
del self.sql_engine
"""
LLM wrapper for LLMs running on ColossalCloud Platform
Usage:
os.environ['URL'] = ""
os.environ['HOST'] = ""
gen_config = {
'max_new_tokens': 100,
# 'top_k': 2,
'top_p': 0.9,
'temperature': 0.5,
'repetition_penalty': 2,
}
llm = ColossalCloudLLM(n=1)
llm.set_auth_config()
resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
print(resp) # super-heavyweight awesome-natured yawning Australian creature!
"""
import json
from typing import Any, List, Mapping, Optional
import requests
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
class ColossalCloudLLM(LLM):
"""
A custom LLM class that integrates LLMs running on the ColossalCloud Platform
"""
n: int
gen_config: dict = None
auth_config: dict = None
valid_gen_para: list = ['max_new_tokens', 'top_k',
'top_p', 'temperature', 'repetition_penalty']
def __init__(self, gen_config=None, **kwargs):
"""
Args:
gen_config: config for generation,
max_new_tokens: 50 by default
top_k: (1, vocab_size)
top_p: (0, 1) if not None
temperature: (0, inf) if not None
repetition_penalty: (1, inf) if not None
"""
super(ColossalCloudLLM, self).__init__(**kwargs)
if gen_config is None:
self.gen_config = {"max_new_tokens": 50}
else:
assert "max_new_tokens" in gen_config, "max_new_tokens is a compulsory key in the gen config"
self.gen_config = gen_config
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"n": self.n}
@property
def _llm_type(self) -> str:
return 'ColossalCloudLLM'
def set_auth_config(self, **kwargs):
url = get_from_dict_or_env(kwargs, "url", "URL")
host = get_from_dict_or_env(kwargs, "host", "HOST")
auth_config = {}
auth_config['endpoint'] = url
auth_config['Host'] = host
self.auth_config = auth_config
def _call(self, prompt: str, stop=None, **kwargs: Any) -> str:
"""
Args:
prompt: The prompt to pass into the model.
stop: A list of strings to stop generation when encountered
Returns:
The string generated by the model
"""
# Update the generation arguments
for key, value in kwargs.items():
if key not in self.valid_gen_para:
raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
if key in self.gen_config:
self.gen_config[key] = value
resp_text = self.text_completion(prompt, self.gen_config, self.auth_config)
# TODO: This may cause excessive tokens count
if stop is not None:
for stopping_words in stop:
if stopping_words in resp_text:
resp_text = resp_text.split(stopping_words)[0]
return resp_text
def text_completion(self, prompt, gen_config, auth_config):
# Compulsory parameters
endpoint = auth_config.pop('endpoint')
max_new_tokens = gen_config.pop('max_new_tokens')
# Optional Parameters
optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty'] # Self.optional
gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
# Define the data payload
data = {
"max_new_tokens": max_new_tokens,
"history": [
{"instruction": prompt, "response": ""}
],
**gen_config
}
headers = {
"Content-Type": "application/json",
**auth_config # 'Host',
}
# Make the POST request
response = requests.post(endpoint, headers=headers, data=json.dumps(data))
response.raise_for_status() # raise error if return code is not 200(success)
# Check the response
return response.text
"""
API and LLM wrapper classes for running LLMs locally
Usage:
import os
model_path = os.environ.get("ZH_MODEL_PATH")
model_name = "chatglm2"
colossal_api = ColossalAPI(model_name, model_path)
llm = ColossalLLM(n=1, api=colossal_api)
TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
"""
from typing import Any, List, Mapping, Optional
import torch
from colossalqa.local.utils import get_response, post_http_request
from colossalqa.mylogging import get_logger
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoModelForCausalLM, AutoTokenizer
logger = get_logger()
class ColossalAPI:
"""
API for calling LLM.generate
"""
__instances = dict()
def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None:
"""
Configure the model
"""
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
return
else:
ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
self.model_type = model_type
self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
if ckpt_path is not None:
state_dict = torch.load(ckpt_path)
self.model.load_state_dict(state_dict)
self.model.to(torch.cuda.current_device())
# Configure the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model.eval()
@staticmethod
def get_api(model_type: str, model_path: str, ckpt_path: str = None):
if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
else:
return ColossalAPI(model_type, model_path, ckpt_path)
def generate(self, input: str, **kwargs) -> str:
"""
Generate response given the prompt
Args:
input: input string
**kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
Returns:
output: output string
"""
if self.model_type in ["chatglm", "chatglm2"]:
inputs = {
k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
}
else:
inputs = {
"input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
}
output = self.model.generate(**inputs, **kwargs)
output = output.cpu()
prompt_len = inputs["input_ids"].size(1)
response = output[0, prompt_len:]
output = self.tokenizer.decode(response, skip_special_tokens=True)
return output
class VllmAPI:
def __init__(self, host: str = "localhost", port: int = 8077) -> None:
# Configure the API for a model served over the web
self.host = host
self.port = port
self.url = f"http://{self.host}:{self.port}/generate"
def generate(self, input: str, **kwargs):
output = get_response(post_http_request(input, self.url, **kwargs))[0]
return output[len(input) :]
class ColossalLLM(LLM):
"""
Langchain LLM wrapper for a local LLM
"""
n: int
api: Any
kwargs = {"max_new_tokens": 100}
@property
def _llm_type(self) -> str:
return "custom"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
for k in self.kwargs:
if k not in kwargs:
kwargs[k] = self.kwargs[k]
generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
out = self.api.generate(prompt, **generate_args)
if isinstance(stop, list) and len(stop) != 0:
for stopping_words in stop:
if stopping_words in out:
out = out.split(stopping_words)[0]
logger.info(f"{prompt}{out}", verbose=self.verbose)
return out
@property
def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters."""
return {"n": self.n}
class VllmLLM(LLM):
"""
Langchain LLM wrapper for a local LLM
"""
n: int
api: Any
kwargs = {"max_new_tokens": 100}
@property
def _llm_type(self) -> str:
return "custom"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
for k in self.kwargs:
if k not in kwargs:
kwargs[k] = self.kwargs[k]
logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
out = self.api.generate(prompt, **generate_args)
if len(stop) != 0:
for stopping_words in stop:
if stopping_words in out:
out = out.split(stopping_words)[0]
logger.info(f"{prompt}{out}", verbose=self.verbose)
return out
def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
if "max_tokens" not in kwargs:
kwargs["max_tokens"] = 100
self.kwargs = kwargs
self.api = VllmAPI(host=host, port=port)
@property
def _identifying_params(self) -> Mapping[str, int]:
"""Get the identifying parameters."""
return {"n": self.n}
"""
LLM wrapper for Pangu
Usage:
# URL: in "Pangu Large Model Suite Management", click "Service Management" -> "Model List" -> "Copy Path" for the model you want to use
# USERNAME: Huawei Cloud console: the "IAM user name" under "My Credentials" -> "API Credentials", i.e. the name you use to log in to the IAM account
# PASSWORD: the IAM user's password
# DOMAIN_NAME: Huawei Cloud console: the account name under "My Credentials" -> "API Credentials", i.e. the main account that manages the IAM users
os.environ["URL"] = ""
os.environ["USERNAME"] = ""
os.environ["PASSWORD"] = ""
os.environ["DOMAIN_NAME"] = ""
pg = Pangu(id=1)
pg.set_auth_config()
res = pg('你是谁')  # e.g. "Hello, I am Huawei's Pangu large model. I can help you through our conversation. What would you like to ask me?"
"""
import http.client
import json
from typing import Any, List, Mapping, Optional
import requests
from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env
class Pangu(LLM):
"""
A custom LLM class that integrates pangu models
"""
n: int
gen_config: dict = None
auth_config: dict = None
def __init__(self, gen_config=None, **kwargs):
super(Pangu, self).__init__(**kwargs)
if gen_config is None:
self.gen_config = {"user": "User", "max_tokens": 50, "temperature": 0.95, "n": 1}
else:
self.gen_config = gen_config
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"n": self.n}
@property
def _llm_type(self) -> str:
return "pangu"
def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
"""
Args:
prompt: The prompt to pass into the model.
stop: A list of strings to stop generation when encountered
Returns:
The string generated by the model
"""
# Update the generation arguments
for key, value in kwargs.items():
if key in self.gen_config:
self.gen_config[key] = value
response = self.text_completion(prompt, self.gen_config, self.auth_config)
text = response["choices"][0]["text"]
if stop is not None:
for stopping_words in stop:
if stopping_words in text:
text = text.split(stopping_words)[0]
return text
def set_auth_config(self, **kwargs):
url = get_from_dict_or_env(kwargs, "url", "URL")
username = get_from_dict_or_env(kwargs, "username", "USERNAME")
password = get_from_dict_or_env(kwargs, "password", "PASSWORD")
domain_name = get_from_dict_or_env(kwargs, "domain_name", "DOMAIN_NAME")
region = url.split(".")[1]
auth_config = {}
auth_config["endpoint"] = url[url.find("https://") + 8 : url.find(".com") + 4]
auth_config["resource_path"] = url[url.find(".com") + 4 :]
auth_config["auth_token"] = self.get_latest_auth_token(region, username, password, domain_name)
self.auth_config = auth_config
def get_latest_auth_token(self, region, username, password, domain_name):
url = f"https://iam.{region}.myhuaweicloud.com/v3/auth/tokens"
payload = json.dumps(
{
"auth": {
"identity": {
"methods": ["password"],
"password": {"user": {"name": username, "password": password, "domain": {"name": domain_name}}},
},
"scope": {"project": {"name": region}},
}
}
)
headers = {"Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=payload)
return response.headers["X-Subject-Token"]
def text_completion(self, text, gen_config, auth_config):
conn = http.client.HTTPSConnection(auth_config["endpoint"])
payload = json.dumps(
{
"prompt": text,
"user": gen_config["user"],
"max_tokens": gen_config["max_tokens"],
"temperature": gen_config["temperature"],
"n": gen_config["n"],
}
)
headers = {
"X-Auth-Token": auth_config["auth_token"],
"Content-Type": "application/json",
}
conn.request("POST", auth_config["resource_path"], payload, headers)
res = conn.getresponse()
data = res.read()
data = json.loads(data.decode("utf-8"))
return data
def chat_model(self, messages, gen_config, auth_config):
conn = http.client.HTTPSConnection(auth_config["endpoint"])
payload = json.dumps(
{
"messages": messages,
"user": gen_config["user"],
"max_tokens": gen_config["max_tokens"],
"temperature": gen_config["temperature"],
"n": gen_config["n"],
}
)
headers = {
"X-Auth-Token": auth_config["auth_token"],
"Content-Type": "application/json",
}
conn.request("POST", auth_config["resource_path"], payload, headers)
res = conn.getresponse()
data = res.read()
data = json.loads(data.decode("utf-8"))
return data
"""
Generation utilities
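Usage:
    # Hypothetical example; assumes a vLLM API server is running at this address
    resp = post_http_request("Hello, world", "http://localhost:8077/generate", max_tokens=50)
    print(get_response(resp)[0])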
"""
import json
from typing import List
import requests
def post_http_request(
prompt: str, api_url: str, n: int = 1, max_tokens: int = 100, temperature: float = 0.0, stream: bool = False
) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": 1,
"use_beam_search": False,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": stream,
}
response = requests.post(api_url, headers=headers, json=pload, stream=True, timeout=3)
return response
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output