import argparse
import fitz
import re
import os
import time
import pandas as pd
import hashlib
import textract
import shutil
import configparser
import json
from multiprocessing import Pool
from typing import List
from loguru import logger
from BCEmbedding.tools.langchain import BCERerank
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from torch.cuda import empty_cache
from bs4 import BeautifulSoup
from elastic_keywords_search import ElasticKeywordsSearch
from retriever import Retriever


def check_envs(args):
    if all(isinstance(item, int) for item in args.DCU_ID):
        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, args.DCU_ID))
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {args.DCU_ID}")
    else:
        logger.error(f"The --DCU_ID argument must be a list of integers, but got {args.DCU_ID}")
        raise ValueError("The --DCU_ID argument must be a list of integers")


class DocumentName:

    def __init__(self, directory: str, name: str, category: str):
        self.directory = directory
        self.prefix = name.replace('/', '_')
        self.basename = os.path.basename(name)
        self.origin_path = os.path.join(directory, name)
        self.copy_path = ''
        self._category = category
        self.status = True
        self.message = ''

    def __str__(self):
        return '{},{},{},{}\n'.format(self.basename, self.copy_path,
                                      self.status, self.message)


class DocumentProcessor:

    def __init__(self):
        self.image_suffix = ['.jpg', '.jpeg', '.png', '.bmp']
        self.md_suffix = '.md'
        self.text_suffix = ['.txt', '.text']
        self.excel_suffix = ['.xlsx', '.xls', '.csv']
        self.pdf_suffix = '.pdf'
        self.ppt_suffix = '.pptx'
        self.html_suffix = ['.html', '.htm', '.shtml', '.xhtml']
        self.word_suffix = ['.docx', '.doc']
        self.json_suffix = '.json'

    def md5(self, filepath: str):
        # Despite the name, this returns the first 8 hex chars of a SHA-256
        # digest; it is only used as a short fingerprint for output filenames.
        hash_object = hashlib.sha256()
        with open(filepath, 'rb') as file:
            chunk_size = 8192
            while chunk := file.read(chunk_size):
                hash_object.update(chunk)
        return hash_object.hexdigest()[0:8]

    def summarize(self, files: list):
        success = 0
        skip = 0
        failed = 0
        for file in files:
            if file.status:
                success += 1
            elif file.message.startswith('skip'):
                skip += 1
            else:
                logger.info('{} failed to parse, error: {}'.format(
                    file.origin_path, file.message))
                failed += 1
        logger.info('Parsed {} files: {} succeeded, {} skipped, {} failed'.format(
            len(files), success, skip, failed))

    def read_file_type(self, filepath: str):
        filepath = filepath.lower()
        if filepath.endswith(self.pdf_suffix):
            return 'pdf'
        if filepath.endswith(self.md_suffix):
            return 'md'
        if filepath.endswith(self.ppt_suffix):
            return 'ppt'
        if filepath.endswith(self.json_suffix):
            return 'json'
        for suffix in self.image_suffix:
            if filepath.endswith(suffix):
                return 'image'
        for suffix in self.text_suffix:
            if filepath.endswith(suffix):
                return 'text'
        for suffix in self.word_suffix:
            if filepath.endswith(suffix):
                return 'word'
        for suffix in self.excel_suffix:
            if filepath.endswith(suffix):
                return 'excel'
        for suffix in self.html_suffix:
            if filepath.endswith(suffix):
                return 'html'
        return None

    def scan_directory(self, repo_dir: str):
        documents = []
        for directory, _, names in os.walk(repo_dir):
            for name in names:
                category = self.read_file_type(name)
                if category is not None:
                    documents.append(
                        DocumentName(directory=directory, name=name, category=category))
        return documents

    def read(self, filepath: str):
        """Read a file and return a (content, error) tuple."""
        file_type = self.read_file_type(filepath)
        text = ''
        if not os.path.exists(filepath):
            # keep the same (content, error) shape as the normal return path
            return text, None
        try:
            if file_type == 'md' or file_type == 'text':
                text = []
                with open(filepath) as f:
                    txt = f.read()
                    cleaned_txt = re.sub(r'\n\s*\n', '\n\n', txt)
                    text.append(cleaned_txt)
            elif file_type == 'pdf':
                text += self.read_pdf(filepath)
                text = re.sub(r'\n\s*\n', '\n\n', text)
            elif file_type == 'excel':
                text += self.read_excel(filepath)
            elif file_type == 'word' or file_type == 'ppt':
                # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
                # https://textract.readthedocs.io/en/latest/installation.html
                text = textract.process(filepath).decode('utf8')
                text = re.sub(r'\n\s*\n', '\n\n', text)
                if file_type == 'ppt':
                    text = text.replace('\n', ' ')
            elif file_type == 'html':
                with open(filepath) as f:
                    soup = BeautifulSoup(f.read(), 'html.parser')
                    text += soup.text
            elif filepath.endswith('.json'):
                # open the JSON file and read all of its lines
                with open(filepath, 'r', encoding='utf-8') as file:
                    text = file.readlines()
        except Exception as e:
            logger.error((filepath, str(e)))
            return '', e
        return text, None

    def read_excel(self, filepath: str):
        table = None
        if filepath.endswith('.csv'):
            table = pd.read_csv(filepath)
        else:
            table = pd.read_excel(filepath)
        if table is None:
            return ''
        json_text = table.dropna(axis=1).to_json(force_ascii=False)
        return json_text

    def read_pdf(self, filepath: str):
        # load pdf and serialize table
        text = ''
        with fitz.open(filepath) as pages:
            for page in pages:
                text += page.get_text()
                tables = page.find_tables()
                for table in tables:
                    tablename = '_'.join(
                        filter(lambda x: x is not None and 'Col' not in x,
                               table.header.names))
                    pan = table.to_pandas()
                    json_text = pan.dropna(axis=1).to_json(force_ascii=False)
                    text += tablename
                    text += '\n'
                    text += json_text
                    text += '\n'
        return text


def read_and_save(file: DocumentName, file_opr: DocumentProcessor):
    try:
        if os.path.exists(file.copy_path):
            # already exists, return
            logger.info('{} already processed, output file: {}, skip load'.format(
                file.origin_path, file.copy_path))
            return
        logger.info('reading {}, would save to {}'.format(file.origin_path, file.copy_path))
        content, error = file_opr.read(file.origin_path)
        if error is not None:
            logger.error('{} load error: {}'.format(file.origin_path, str(error)))
            return
        if content is None or len(content) < 1:
            logger.warning('{} empty, skip save'.format(file.origin_path))
            return
        cleaned_content = re.sub(r'\n\s*\n', '\n\n', content)
        with open(file.copy_path, 'w') as f:
            f.write(os.path.splitext(file.basename)[0] + '\n')
            f.write(cleaned_content)
    except Exception as e:
        logger.error(f"Error in read_and_save: {e}")


class FeatureDataBase:

    def __init__(self, embeddings: HuggingFaceEmbeddings, reranker: BCERerank,
                 reject_throttle=-1) -> None:
        # logger.debug('loading text2vec model..')
        self.embeddings = embeddings
        self.reranker = reranker
        self.compression_retriever = None
        self.rejecter = None
        self.retriever = None
        self.reject_throttle = reject_throttle if reject_throttle else -1
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1068, chunk_overlap=32)

    def get_documents(self, text, file):
        # if len(text) <= 1:
        #     return []
        chunks = self.text_splitter.create_documents(text)
        documents = []
        for chunk in chunks:
            # `source` is for return references
            # `read` is for LLM response
            chunk.metadata = {'source': file.basename, 'read': file.origin_path}
            documents.append(chunk)
        return documents

    def build_database(self, files: list, work_dir: str,
                       file_opr: DocumentProcessor, elastic_search=None):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)
        documents = []
        texts_for_es = []
        metadatas_for_es = []
        ids_for_es = []
        for i, file in enumerate(files):
            if not file.status:
                continue
            # read each preprocessed file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            file.message = str(text[0])
            texts_for_es.append(text[0])
            metadatas_for_es.append({'source': file.basename, 'read': file.origin_path})
            ids_for_es.append(str(i))
            document = self.get_documents(text, file)
            documents += document
            logger.debug('Positive pipeline {}/{}.. register 《{}》 and split {} documents'.format(
                i + 1, len(files), file.basename, len(document)))
        if elastic_search is not None:
            logger.debug('ES database pipeline register {} documents into database...'.format(
                len(texts_for_es)))
            es_time_before_register = time.time()
            elastic_search.add_texts(texts_for_es, metadatas=metadatas_for_es, ids=ids_for_es)
            es_time_after_register = time.time()
            logger.debug('ES database pipeline take time: {} '.format(
                es_time_after_register - es_time_before_register))
        logger.debug('Vector database pipeline register {} documents into database...'.format(
            len(documents)))
        ve_time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        ve_time_after_register = time.time()
        logger.debug('Vector database pipeline take time: {} '.format(
            ve_time_after_register - ve_time_before_register))

    def preprocess(self, files: list, work_dir: str, file_opr: DocumentProcessor):
        preproc_dir = os.path.join(work_dir, 'preprocess')
        if not os.path.exists(preproc_dir):
            os.makedirs(preproc_dir)
        pool = Pool(processes=16)
        for idx, file in enumerate(files):
            if not os.path.exists(file.origin_path):
                file.status = False
                file.message = 'skip not exist'
                continue
            if file._category == 'image':
                file.status = False
                file.message = 'skip image'
            elif file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
                # read pdf/word/ppt/html/excel files and save them as plain text
                md5 = file_opr.md5(file.origin_path)
                file.copy_path = os.path.join(preproc_dir, '{}.text'.format(md5))
                pool.apply_async(read_and_save, args=(file, file_opr))
            elif file._category in ['md', 'text']:
                # copy text files into the preprocess dir under a flattened name
                file.copy_path = os.path.join(
                    preproc_dir,
                    file.origin_path.replace('/', '_')[-84:])
                try:
                    shutil.copy(file.origin_path, file.copy_path)
                    file.status = True
                    file.message = 'preprocessed'
                except Exception as e:
                    file.status = False
                    file.message = str(e)
            elif file._category in ['json']:
                file.status = True
                file.copy_path = file.origin_path
                file.message = 'preprocessed'
            else:
                file.status = False
                file.message = 'skip unknown format'
        pool.close()
        logger.debug('waiting for preprocess read finish..')
        pool.join()
        # check the results of the async read_and_save jobs
        for file in files:
            if file._category in ['pdf', 'word', 'ppt', 'html', 'excel']:
                if os.path.exists(file.copy_path):
                    file.status = True
                    file.message = 'preprocessed'
                else:
                    file.status = False
                    file.message = 'read error'

    def initialize(self, files: list, work_dir: str, file_opr: DocumentProcessor,
                   elastic_search=None):
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.build_database(files=files, work_dir=work_dir, file_opr=file_opr,
                            elastic_search=elastic_search)

    def merge_db_response(self, faiss: FAISS, files: list, work_dir: str,
                          file_opr: DocumentProcessor):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)
        documents = []
        for i, file in enumerate(files):
            logger.debug('{}/{}.. register 《{}》 into database...'.format(
                i + 1, len(files), file.basename))
            if not file.status:
                continue
            # read each preprocessed file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            logger.info('{} blocks read, first block: {}'.format(len(text), str(text[0])))
            file.message = str(text[0])
            # file.message = str(len(text))
            # logger.info('{} content length {}'.format(
            #     file._category, len(text)))
            documents += self.get_documents(text, file)
        if documents:
            vs = FAISS.from_documents(documents, self.embeddings)
            if faiss:
                faiss.merge_from(vs)
                faiss.save_local(feature_dir)
            else:
                vs.save_local(feature_dir)


def test_reject(retriever: Retriever):
    """Simple test reject pipeline."""
    real_questions = [
        '姚明是谁?',
        'CBBA是啥?',
        '差多少嘞?',
        'cnn 的全称是什么?',
        'transformer啥意思?',
        '成都有什么好吃的推荐?',
        '树博士是什么?',
        '白马非马啥意思?',
        'mmpose 如何安装?',
        '今天天气如何?',
        '写一首五言律诗?',
        '先有鸡还是先有蛋?',
        '如何在Gromacs中进行蛋白质的动态模拟?',
        'wy-vSphere 7 海光平台兼容补丁?',
        '在Linux系统中,如何进行源码包的安装?'
    ]
    for example in real_questions:
        relative, _ = retriever.is_relative(example)
        if relative:
            logger.warning(f'process query: {example}')
            retriever.query(example)
            empty_cache()
        else:
            logger.error(f'reject query: {example}')
            empty_cache()


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Feature store for processing directories.')
    parser.add_argument(
        '--work_dir', type=str, default='',
        help='Working directory for preprocessed files, databases and logs.')
    parser.add_argument(
        '--repo_dir', type=str, default='',
        help='Directory of documents to ingest.')
    parser.add_argument(
        '--config_path', default='./ai/rag/config.ini',
        help='Path to the config .ini file.')
    parser.add_argument(
        '--DCU_ID', type=int, nargs='+', default=[7],
        help='DCU device IDs, exported as CUDA_VISIBLE_DEVICES.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')
    check_envs(args)
    config = configparser.ConfigParser()
    config.read(args.config_path)

    # only init vector retriever
    embedding_model_path = config.get('rag', 'embedding_model_path') or None
    reranker_model_path = config.get('rag', 'reranker_model_path') or None
    if embedding_model_path and reranker_model_path:
        embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1,
                'normalize_embeddings': True
            })
        embeddings.client = embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': int(config['rag']['vector_top_k']),
            'device': 'cuda',
            'use_fp16': True
        }
        reranker = BCERerank(**reranker_args)
        fs_init = FeatureDataBase(embeddings=embeddings, reranker=reranker)
    else:
        raise ValueError('embedding_model_path and reranker_model_path must be set in the [rag] section')

    # init es retriever, drop_old means build a new index or update the existing 'index_name'
    es_url = config.get('rag', 'es_url')
    index_name = config.get('rag', 'index_name')
    elastic_search = ElasticKeywordsSearch(
        elasticsearch_url=es_url, index_name=index_name, drop_old=True)

    # walk all files in repo dir
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr,
                       elastic_search=elastic_search)
    file_opr.summarize(files)

    # del fs_init
    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
    #     positive_sample = json.load(f)
    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
    #     negative_sample = json.load(f)
    #
    # with open(os.path.join(args.work_dir, 'sample', 'positive.txt'), 'r', encoding='utf-8') as file:
    #     positive_sample = []
    #     for line in file:
    #         positive_sample.append(line.strip())
    #
    # with open(os.path.join(args.work_dir, 'sample', 'negative.txt'), 'r', encoding='utf-8') as file:
    #     negative_sample = []
    #     for line in file:
    #         negative_sample.append(line.strip())
    #
    # test_reject(retriever)
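
# Usage sketch (illustrative only): the [rag] keys below are the ones this script
# reads via configparser; the paths, values and the script filename are placeholder
# assumptions, not part of this repository.
#
#   [rag]
#   embedding_model_path = /models/bce-embedding-base_v1
#   reranker_model_path = /models/bce-reranker-base_v1
#   vector_top_k = 5
#   es_url = http://127.0.0.1:9200
#   index_name = rag_docs
#
#   python feature_store.py --repo_dir ./docs --work_dir ./workdir \
#       --config_path ./ai/rag/config.ini --DCU_ID 0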