"""Build FAISS feature databases from a directory of mixed-format documents."""
import argparse
import configparser
import hashlib
import json
import os
import re
import shutil
import time
from multiprocessing import Pool
from typing import List

import fitz
import pandas as pd
import textract
from BCEmbedding.tools.langchain import BCERerank
from bs4 import BeautifulSoup
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.faiss import FAISS
from loguru import logger
from torch.cuda import empty_cache

from .retriever import CacheRetriever, Retriever


class DocumentName:
    """Bookkeeping for a single source document and its preprocessed copy."""

    def __init__(self, directory: str, name: str, category: str):
        self.directory = directory
        self.prefix = name.replace('/', '_')
        self.basename = os.path.basename(name)
        self.origin_path = os.path.join(directory, name)
        self.copy_path = ''
        self._category = category
        self.status = True
        self.message = ''

    def __str__(self):
        return '{},{},{},{}\n'.format(self.basename, self.copy_path,
                                      self.status, self.message)


class DocumentProcessor:
    """Detect file types and extract plain text from supported formats."""

    def __init__(self):
        self.image_suffix = ['.jpg', '.jpeg', '.png', '.bmp']
        self.md_suffix = '.md'
        self.text_suffix = ['.txt', '.text']
        self.excel_suffix = ['.xlsx', '.xls', '.csv']
        self.pdf_suffix = '.pdf'
        self.ppt_suffix = '.pptx'
        self.html_suffix = ['.html', '.htm', '.shtml', '.xhtml']
        self.word_suffix = ['.docx', '.doc']
        self.json_suffix = '.json'

    def md5(self, filepath: str):
        # Despite the name, this returns the first 8 hex characters of a
        # SHA-256 digest; it is only used to build a stable preprocessed
        # filename.
        hash_object = hashlib.sha256()
        with open(filepath, 'rb') as file:
            chunk_size = 8192
            while chunk := file.read(chunk_size):
                hash_object.update(chunk)
        return hash_object.hexdigest()[0:8]

    def summarize(self, files: list):
        success = 0
        skip = 0
        failed = 0

        for file in files:
            if file.status:
                success += 1
            elif file.message.startswith('skip'):
                skip += 1
            else:
                logger.info('{} failed, reason: {}'.format(
                    file.origin_path, file.message))
                failed += 1
        logger.info('parsed {} files: {} succeeded, {} skipped, {} failed'.format(
            len(files), success, skip, failed))

    def read_file_type(self, filepath: str):
        filepath = filepath.lower()
        if filepath.endswith(self.pdf_suffix):
            return 'pdf'
        if filepath.endswith(self.md_suffix):
            return 'md'
        if filepath.endswith(self.ppt_suffix):
            return 'ppt'
        if filepath.endswith(self.json_suffix):
            return 'json'
        for suffix in self.image_suffix:
            if filepath.endswith(suffix):
                return 'image'
        for suffix in self.text_suffix:
            if filepath.endswith(suffix):
                return 'text'
        for suffix in self.word_suffix:
            if filepath.endswith(suffix):
                return 'word'
        for suffix in self.excel_suffix:
            if filepath.endswith(suffix):
                return 'excel'
        for suffix in self.html_suffix:
            if filepath.endswith(suffix):
                return 'html'
        return None

    def scan_directory(self, repo_dir: str):
        documents = []
        for directory, _, names in os.walk(repo_dir):
            for name in names:
                category = self.read_file_type(name)
                if category is not None:
                    documents.append(
                        DocumentName(directory=directory,
                                     name=name,
                                     category=category))
        return documents

    def read(self, filepath: str):
        file_type = self.read_file_type(filepath)
        text = ''

        if not os.path.exists(filepath):
            return text, None

        try:
            if file_type == 'md' or file_type == 'text':
                text = []
                with open(filepath, encoding='utf-8') as f:
                    txt = f.read()
                    cleaned_txt = re.sub(r'\n\s*\n', '\n\n', txt)
                    text.append(cleaned_txt)

            elif file_type == 'pdf':
                text += self.read_pdf(filepath)
                text = re.sub(r'\n\s*\n', '\n\n', text)

            elif file_type == 'excel':
                # treat each row as a question/answer pair
                text = []
                if filepath.endswith('.csv'):
                    df = pd.read_csv(filepath, header=None)
                else:
                    df = pd.read_excel(filepath, header=None)
                for row in df.index.values:
                    doc = dict()
                    doc['Que'] = df.iloc[row, 0]
                    doc['Ans'] = df.iloc[row, 1]
                    text.append(str(doc))
                # text += self.read_excel(filepath)

            elif file_type == 'word' or file_type == 'ppt':
                # https://stackoverflow.com/questions/36001482/read-doc-file-with-python
                # https://textract.readthedocs.io/en/latest/installation.html
                text = textract.process(filepath).decode('utf8')
                text = re.sub(r'\n\s*\n', '\n\n', text)
                if file_type == 'ppt':
                    text = text.replace('\n', ' ')

            elif file_type == 'html':
                with open(filepath, encoding='utf-8') as f:
                    soup = BeautifulSoup(f.read(), 'html.parser')
                    text += soup.text

            elif filepath.endswith('.json'):
                # read the JSON file line by line
                with open(filepath, 'r', encoding='utf-8') as file:
                    text = file.readlines()

        except Exception as e:
            logger.error((filepath, str(e)))
            return '', e
        return text, None

    def read_pdf(self, filepath: str):
        # load pdf and serialize tables
        text = ''
        with fitz.open(filepath) as pages:
            for page in pages:
                text += page.get_text()
                tables = page.find_tables()
                for table in tables:
                    tablename = '_'.join(
                        filter(lambda x: x is not None and 'Col' not in x,
                               table.header.names))
                    pan = table.to_pandas()
                    json_text = pan.dropna(axis=1).to_json(force_ascii=False)
                    text += tablename
                    text += '\n'
                    text += json_text
                    text += '\n'
        return text


def read_and_save(file: DocumentName, file_opr: DocumentProcessor):
    try:
        if os.path.exists(file.copy_path):
            # already exists, return
            logger.info('{} already processed, output file: {}, skip load'.
                        format(file.origin_path, file.copy_path))
            return

        logger.info('reading {}, would save to {}'.format(
            file.origin_path, file.copy_path))
        content, error = file_opr.read(file.origin_path)
        if error is not None:
            logger.error('{} load error: {}'.format(file.origin_path,
                                                    str(error)))
            return

        if content is None or len(content) < 1:
            logger.warning('{} empty, skip save'.format(file.origin_path))
            return

        cleaned_content = re.sub(r'\n\s*\n', '\n\n', content)
        with open(file.copy_path, 'w', encoding='utf-8') as f:
            f.write(os.path.splitext(file.basename)[0] + '\n')
            f.write(cleaned_content)
    except Exception as e:
        logger.error(f'Error in read_and_save: {e}')


class FeatureDataBase:
    """Build FAISS feature databases from preprocessed documents."""

    def __init__(self, embeddings: HuggingFaceEmbeddings, reranker: BCERerank,
                 reject_throttle: float) -> None:
        # logger.debug('loading text2vec model..')
        self.embeddings = embeddings
        self.reranker = reranker
        self.compression_retriever = None
        self.rejecter = None
        self.retriever = None
        self.reject_throttle = reject_throttle if reject_throttle else -1
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=768,
                                                            chunk_overlap=32)

    def get_documents(self, text, file):
        # if len(text) <= 1:
        #     return []
        chunks = self.text_splitter.create_documents(text)
        documents = []
        for chunk in chunks:
            # `source` is for returning references
            # `read` is for the LLM response
            chunk.metadata = {
                'source': file.basename,
                'read': file.origin_path
            }
            documents.append(chunk)
        return documents

    def register_response(self, files: list, work_dir: str,
                          file_opr: DocumentProcessor):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        for i, file in enumerate(files):
            if not file.status:
                continue
            # read each preprocessed file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            file.message = str(text[0])
            # file.message = str(len(text))
            # logger.info('{} content length {}'.format(
            #     file._category, len(text)))
            document = self.get_documents(text, file)
            documents += document
            logger.debug(
                'Positive pipeline {}/{}.. register 《{}》 and split {} documents'
                .format(i + 1, len(files), file.basename, len(document)))

        logger.debug('Positive pipeline register {} documents into database...'.
                     format(len(documents)))
        time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        time_after_register = time.time()
        logger.debug('Positive pipeline take time: {} '.format(
            time_after_register - time_before_register))

    def register_reject(self, files: list, work_dir: str,
                        file_opr: DocumentProcessor):
        feature_dir = os.path.join(work_dir, 'db_reject')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        for i, file in enumerate(files):
            if not file.status:
                continue
            text, error = file_opr.read(file.copy_path)
            if len(text) < 1:
                continue
            if error is not None:
                continue
            document = self.get_documents(text, file)
            documents += document
            logger.debug(
                'Negative pipeline {}/{}.. register 《{}》 and split {} documents'
                .format(i + 1, len(files), file.basename, len(document)))

        if len(documents) < 1:
            return
        logger.debug('Negative pipeline register {} documents into database...'.
                     format(len(documents)))
        time_before_register = time.time()
        vs = FAISS.from_documents(documents, self.embeddings)
        vs.save_local(feature_dir)
        time_after_register = time.time()
        logger.debug('Negative pipeline take time: {} '.format(
            time_after_register - time_before_register))

    def preprocess(self, files: list, work_dir: str,
                   file_opr: DocumentProcessor):
        preproc_dir = os.path.join(work_dir, 'preprocess')
        if not os.path.exists(preproc_dir):
            os.makedirs(preproc_dir)

        pool = Pool(processes=16)
        for idx, file in enumerate(files):
            if not os.path.exists(file.origin_path):
                file.status = False
                file.message = 'skip not exist'
                continue

            if file._category == 'image':
                file.status = False
                file.message = 'skip image'

            elif file._category in ['pdf', 'word', 'ppt', 'html']:
                # read pdf/word/ppt/html files and save them as plain text
                md5 = file_opr.md5(file.origin_path)
                file.copy_path = os.path.join(preproc_dir,
                                              '{}.text'.format(md5))
                pool.apply_async(read_and_save, args=(file, file_opr))

            elif file._category in ['md', 'text']:
                # copy text files into the preprocess dir under a new name
                file.copy_path = os.path.join(
                    preproc_dir,
                    file.origin_path.replace('/', '_')[-84:])
                try:
                    shutil.copy(file.origin_path, file.copy_path)
                    file.status = True
                    file.message = 'preprocessed'
                except Exception as e:
                    file.status = False
                    file.message = str(e)

            elif file._category in ['json', 'excel']:
                file.status = True
                file.copy_path = file.origin_path
                file.message = 'preprocessed'

            else:
                file.status = False
                file.message = 'skip unknown format'
        pool.close()
        logger.debug('waiting for preprocess read finish..')
        pool.join()

        # check the result of the async read_and_save workers
        for file in files:
            if file._category in ['pdf', 'word', 'ppt', 'html']:
                if os.path.exists(file.copy_path):
                    file.status = True
                    file.message = 'preprocessed'
                else:
                    file.status = False
                    file.message = 'read error'

    def initialize(self, files: list, work_dir: str,
                   file_opr: DocumentProcessor):
        self.preprocess(files=files, work_dir=work_dir, file_opr=file_opr)
        self.register_response(files=files,
                               work_dir=work_dir,
                               file_opr=file_opr)
        # self.register_reject(files=files, work_dir=work_dir, file_opr=file_opr)

    def merge_db_response(self, faiss: FAISS, files: list, work_dir: str,
                          file_opr: DocumentProcessor):
        feature_dir = os.path.join(work_dir, 'db_response')
        if not os.path.exists(feature_dir):
            os.makedirs(feature_dir)

        documents = []
        for i, file in enumerate(files):
            logger.debug('{}/{}.. register 《{}》 into database...'.format(
                i + 1, len(files), file.basename))
            if not file.status:
                continue
            # read each preprocessed file
            text, error = file_opr.read(file.copy_path)
            if error is not None:
                file.status = False
                file.message = str(error)
                continue
            file.message = str(text[0])
            # file.message = str(len(text))
            # logger.info('{} content length {}'.format(
            #     file._category, len(text)))
            documents += self.get_documents(text, file)

        vs = FAISS.from_documents(documents, self.embeddings)
        faiss.merge_from(vs)
        faiss.save_local(feature_dir)


def test_reject(retriever: Retriever):
    """Simple test of the reject pipeline."""
    # sample queries, kept in Chinese to match the knowledge base language
    real_questions = [
        '姚明是谁?',
        'CBBA是啥?',
        '差多少嘞?',
        'cnn 的全称是什么?',
        'transformer啥意思?',
        '成都有什么好吃的推荐?',
        '树博士是什么?',
        '白马非马啥意思?',
        'mmpose 如何安装?',
        '今天天气如何?',
        '写一首五言律诗?',
        '先有鸡还是先有蛋?',
        '如何在Gromacs中进行蛋白质的动态模拟?',
        'wy-vSphere 7 海光平台兼容补丁?',
        '在Linux系统中,如何进行源码包的安装?'
    ]
    for example in real_questions:
        relative, _ = retriever.is_relative(example)
        if relative:
            logger.warning(f'process query: {example}')
            retriever.query(example)
            empty_cache()
        else:
            logger.error(f'reject query: {example}')
            empty_cache()


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Feature store for processing directories.')
    parser.add_argument('--work_dir',
                        type=str,
                        default='/home/chat_demo/work_dir/',
                        help='Working directory.')
    parser.add_argument('--repo_dir',
                        type=str,
                        default='/home/chat_demo/work_dir/jifu/original',
                        help='Directory of documents to read.')
    parser.add_argument('--config_path',
                        default='/home/chat_demo/config.ini',
                        help='Path to the config file.')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    log_file_path = os.path.join(args.work_dir, 'application.log')
    logger.add(log_file_path, rotation='10MB', compression='zip')

    config = configparser.ConfigParser()
    config.read(args.config_path)
    embedding_model_path = config['feature_database']['embedding_model_path']
    reranker_model_path = config['feature_database']['reranker_model_path']
    reject_throttle = float(config['feature_database']['reject_throttle'])

    cache = CacheRetriever(embedding_model_path=embedding_model_path,
                           reranker_model_path=reranker_model_path)
    fs_init = FeatureDataBase(embeddings=cache.embeddings,
                              reranker=cache.reranker,
                              reject_throttle=reject_throttle)

    # walk all files in the repo dir and build the feature database
    file_opr = DocumentProcessor()
    files = file_opr.scan_directory(repo_dir=args.repo_dir)
    fs_init.initialize(files=files, work_dir=args.work_dir, file_opr=file_opr)
    file_opr.summarize(files)
    del fs_init

    retriever = cache.get(reject_throttle=reject_throttle,
                          work_dir=args.work_dir)

    # with open(os.path.join(args.work_dir, 'sample', 'positive.json')) as f:
    #     positive_sample = json.load(f)
    # with open(os.path.join(args.work_dir, 'sample', 'negative.json')) as f:
    #     negative_sample = json.load(f)

    with open(os.path.join(args.work_dir, 'jifu', 'positive.txt'),
              'r',
              encoding='utf-8') as file:
        positive_sample = [line.strip() for line in file]

    with open(os.path.join(args.work_dir, 'jifu', 'negative.txt'),
              'r',
              encoding='utf-8') as file:
        negative_sample = [line.strip() for line in file]

    reject_throttle = retriever.update_throttle(
        work_dir=args.work_dir,
        config_path=args.config_path,
        positive_sample=positive_sample,
        negative_sample=negative_sample)

    cache.pop('default')

    # test
    retriever = cache.get(reject_throttle=reject_throttle,
                          work_dir=args.work_dir)
    test_reject(retriever)
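
# Illustrative usage (a sketch; every path, the package name, and the
# reject_throttle value below are placeholders, not shipped configuration).
# The __main__ block expects a config.ini with a [feature_database] section
# providing the three keys read above, e.g.:
#
#   [feature_database]
#   embedding_model_path = /path/to/embedding/model
#   reranker_model_path = /path/to/reranker/model
#   reject_throttle = 0.3
#
# Because of the relative import of CacheRetriever/Retriever, the script is
# meant to be launched as a module inside its package, for example:
#
#   python -m <package>.feature_store \
#       --work_dir /home/chat_demo/work_dir/ \
#       --repo_dir /home/chat_demo/work_dir/jifu/original \
#       --config_path /home/chat_demo/config.ini
#
# Only the flag names and config keys come from parse_args() and the
# __main__ block itself.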