from __future__ import annotations

import argparse
import ast
import uuid
from abc import ABC
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple

import jieba.analyse
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
from loguru import logger

if TYPE_CHECKING:
    from elasticsearch import Elasticsearch  # noqa: F401


def _default_text_mapping() -> Dict:
    """Default index mapping: a single full-text `text` field."""
    return {'properties': {'text': {'type': 'text'}}}


# Prompt used to extract keywords from a (Chinese) question with an LLM.
# The model is asked to return the keywords as a Python-style list literal.
DEFAULT_PROMPT = PromptTemplate(
    input_variables=['question'],
    template="""分析给定Question,提取Question中包含的KeyWords,输出列表形式
Examples:
Question: 达梦公司在过去三年中的流动比率如下:2021年:3.74倍;2020年:2.82倍;2019年:2.05倍。
KeyWords: ['过去三年', '流动比率', '2021', '3.74', '2020', '2.82', '2019', '2.05']
----------------
Question: {question}
KeyWords: """,
)


class ElasticKeywordsSearch(VectorStore, ABC):
    """Keyword (BM25) retrieval on Elasticsearch, exposed through the VectorStore interface.

    Query keywords are extracted by an optional LLM chain or, as a fallback, by jieba.
    """

    def __init__(
        self,
        elasticsearch_url: str,
        index_name: str,
        drop_old: Optional[bool] = False,
        *,
        ssl_verify: Optional[Dict[str, Any]] = None,
        llm_chain: Optional[LLMChain] = None,
    ):
        try:
            import elasticsearch
        except ImportError:
            raise ImportError('Could not import elasticsearch python package. '
                              'Please install it with `pip install elasticsearch`.')
        self.index_name = index_name
        self.llm_chain = llm_chain
        self.drop_old = drop_old
        _ssl_verify = ssl_verify or {}
        self.elasticsearch_url = elasticsearch_url
        self.ssl_verify = _ssl_verify
        try:
            self.client = elasticsearch.Elasticsearch(elasticsearch_url, **_ssl_verify)
        except ValueError as e:
            logger.error(f'Your elasticsearch client string is mis-formatted. Got error: {e}')
            raise
        if drop_old:
            try:
                self.client.indices.delete(index=index_name)
            except elasticsearch.exceptions.NotFoundError:
                logger.info(f"Index '{index_name}' not found, nothing to delete.")
            except Exception as e:
                logger.error(f"Error occurred while trying to delete index '{index_name}': {e}")
        logger.info(f'ElasticKeywordsSearch initialized with URL: {elasticsearch_url} and index: {index_name}')

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        refresh_indices: bool = True,
        **kwargs: Any,
    ) -> List[str]:
        """Index the given texts and return their document ids."""
        try:
            from elasticsearch.exceptions import NotFoundError
            from elasticsearch.helpers import bulk
        except ImportError:
            raise ImportError('Could not import elasticsearch python package. '
                              'Please install it with `pip install elasticsearch`.')
        texts = list(texts)
        requests = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        mapping = _default_text_mapping()
        # Check whether the index already exists; recreate it when drop_old is set.
        try:
            self.client.indices.get(index=self.index_name)
            if texts and self.drop_old:
                self.client.indices.delete(index=self.index_name)
                self.create_index(self.client, self.index_name, mapping)
        except NotFoundError:
            # TODO would be nice to create index before embedding,
            # just to save expensive steps for last
            self.create_index(self.client, self.index_name, mapping)
        for i, text in enumerate(texts):
            metadata = metadatas[i] if metadatas else {}
            request = {
                '_op_type': 'index',
                '_index': self.index_name,
                'text': text,
                'metadata': metadata,
                '_id': ids[i],
            }
            requests.append(request)
        bulk(self.client, requests)
        if refresh_indices:
            self.client.indices.refresh(index=self.index_name)
        return ids

    def similarity_search(self,
                          query: str,
                          k: int = 4,
                          query_strategy: str = 'match_phrase',
                          must_or_should: str = 'should',
                          **kwargs: Any) -> List[Document]:
        if k == 0:  # pm need to control
            return []
        docs_and_scores = self.similarity_search_with_score(query,
                                                            k=k,
                                                            query_strategy=query_strategy,
                                                            must_or_should=must_or_should,
                                                            **kwargs)
        documents = [d[0] for d in docs_and_scores]
        return documents

    @staticmethod
    def _relevance_score_fn(distance: float) -> float:
        """Normalize the distance to a score on a scale [0, 1]."""
        # TODO: normalize the es score on a scale [0, 1]
        return distance

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        return self._relevance_score_fn

    def similarity_search_with_score(self,
                                     query: str,
                                     k: int = 4,
                                     query_strategy: str = 'match_phrase',
                                     must_or_should: str = 'should',
                                     **kwargs: Any) -> List[Tuple[Document, float]]:
        if k == 0:  # pm need to control
            return []
        assert must_or_should in ['must', 'should'], 'only support must and should.'
        # Extract keywords with the LLM chain when available, otherwise fall back to jieba.
        if self.llm_chain:
            keywords_str = self.llm_chain.run(query)
            logger.info('llm search keywords: {}'.format(keywords_str))
            try:
                keywords = ast.literal_eval(keywords_str)
                if not isinstance(keywords, list):
                    raise ValueError('Keywords extracted by llm are not a list.')
            except Exception as e:
                logger.warning('Failed to parse llm keywords ({}), falling back to jieba.'.format(e))
                keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
        else:
            keywords = jieba.analyse.extract_tags(query, topK=10, withWeight=False)
            logger.info('jieba search keywords: {}'.format(keywords))

        # Build a bool query whose clauses match each keyword against the `text` field.
        match_query = {'bool': {must_or_should: []}}
        for key in keywords:
            match_query['bool'][must_or_should].append({query_strategy: {'text': key}})

        response = self.client_search(self.client, self.index_name, match_query, size=k)
        hits = response['hits']['hits']
        docs_and_scores = [(Document(page_content=hit['_source']['text'],
                                     metadata={**hit['_source']['metadata'],
                                               'relevance_score': hit['_score']}),
                            hit['_score']) for hit in hits]
        return docs_and_scores

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        index_name: Optional[str] = None,
        refresh_indices: bool = True,
        llm: Optional[BaseLLM] = None,
        prompt: Optional[PromptTemplate] = DEFAULT_PROMPT,
        drop_old: Optional[bool] = False,
        **kwargs: Any,
    ) -> ElasticKeywordsSearch:
        """Build the index from texts. `embedding` is unused but kept for VectorStore compatibility."""
        elasticsearch_url = get_from_dict_or_env(kwargs, 'elasticsearch_url', 'ELASTICSEARCH_URL')
        if 'elasticsearch_url' in kwargs:
            del kwargs['elasticsearch_url']
        index_name = index_name or uuid.uuid4().hex
        if llm:
            llm_chain = LLMChain(llm=llm, prompt=prompt)
            vectorsearch = cls(elasticsearch_url, index_name, llm_chain=llm_chain, drop_old=drop_old, **kwargs)
        else:
            vectorsearch = cls(elasticsearch_url, index_name, drop_old=drop_old, **kwargs)
        vectorsearch.add_texts(texts, metadatas=metadatas, ids=ids, refresh_indices=refresh_indices)
        return vectorsearch

    def create_index(self, client: Any, index_name: str, mapping: Dict) -> None:
        # Elasticsearch 8+ accepts `mappings` as a keyword argument; older clients expect `body`.
        version_num = int(client.info()['version']['number'].split('.')[0])
        if version_num >= 8:
            client.indices.create(index=index_name, mappings=mapping)
        else:
            client.indices.create(index=index_name, body={'mappings': mapping})

    def client_search(self, client: Any, index_name: str, script_query: Dict, size: int) -> Any:
        version_num = int(client.info()['version']['number'].split('.')[0])
        if version_num >= 8:
            response = client.search(index=index_name, query=script_query, size=size, timeout='5s')
        else:
            response = client.search(index=index_name, body={'query': script_query, 'size': size}, timeout='5s')
        return response

    def delete(self, **kwargs: Any) -> None:
        # TODO: Check if this can be done in bulk
        self.client.indices.delete(index=self.index_name)
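
# --- Usage sketch (commented out; illustrative only) -------------------------
# A minimal sketch of building an index with `from_texts`, using an LLM chain
# for keyword extraction. `FakeListLLM` stands in for a real model, and the
# texts, index name, and URL below are assumptions for illustration only.
#
#   from langchain.llms.fake import FakeListLLM
#
#   store = ElasticKeywordsSearch.from_texts(
#       texts=['第一段文本', '第二段文本'],
#       embedding=None,  # embeddings are not used by keyword search
#       elasticsearch_url='http://127.0.0.1:9200',
#       index_name='demo_keywords',
#       llm=FakeListLLM(responses=["['文本']"]),  # hypothetical stand-in LLM
#   )
#   print(store.similarity_search('第一段文本', k=1))
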

def read_text(filepath):
    """Read a UTF-8 text file and return its contents."""
    with open(filepath, encoding='utf-8') as f:
        txt = f.read()
    return txt


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Keyword search over an Elasticsearch knowledge base.')
    parser.add_argument('--elasticsearch_url', type=str, default='http://127.0.0.1:9200')
    parser.add_argument('--index_name', type=str, default='dcu_knowledge_base')
    parser.add_argument('--query', type=str, default='介绍下K100_AI?')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    elastic_search = ElasticKeywordsSearch(
        elasticsearch_url=args.elasticsearch_url,
        index_name=args.index_name,
        drop_old=False,
    )
    # Example ingestion (commented out): read preprocessed text files and index them.
    # texts = []
    # file_list = ['/home/zhangwq/data/doc_new/preprocess/dbaa1604.text',
    #              '/home/zhangwq/data/doc_new/preprocess/a8e2e50d.text',
    #              '/home/zhangwq/data/doc_new/preprocess/a3fdf916.text',
    #              '/home/zhangwq/data/doc_new/preprocess/9d2683f3.text',
    #              '/home/zhangwq/data/doc_new/preprocess/2584c250.text']
    # for file in file_list:
    #     text = read_text(file)
    #     texts.append(text)
    # metadatas = [
    #     {"source": "白皮书-K100.pdf", "type": "text"},
    #     {"source": "DCU人工智能基础软件系统DAS1.0介绍.pdf", "type": "text"},
    #     {"source": "202404-DCU优势测试项.pdf", "type": "text"},
    #     {"source": "202301-达芬奇架构简介.pdf", "type": "text"},
    #     {"source": "曙光DCU在大模型方面的布局与应用.docx", "type": "text"},
    # ]
    # ids = ["doc1", "doc2", "doc3", "doc4", "doc5"]
    # elastic_search.add_texts(texts, metadatas=metadatas, ids=ids)

    search_results = elastic_search.similarity_search_with_score(args.query, k=5)
    for result in search_results:
        logger.debug('Query: {}\nDoc: {}\nScore: {}'.format(args.query, result[0], result[1]))