import os import nltk nltk.data.path.append('/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data') os.environ["CUDA_VISIBLE_DEVICES"] = '1' os.environ['NLTK_DATA'] = '/home/zhangwq/project/whl/nltk/nltk_data-gh-pages/nltk_data' import base64 import argparse import uuid import re import io from PIL import Image from IPython.display import HTML, display from langchain_experimental.open_clip import OpenCLIPEmbeddings from langchain.schema.runnable import RunnableLambda, RunnablePassthrough from langchain.retrievers.multi_vector import MultiVectorRetriever from langchain.schema.document import Document from langchain.storage import InMemoryStore from langchain_community.vectorstores import FAISS from unstructured.partition.pdf import partition_pdf from langchain_core.messages import HumanMessage from langchain_community.chat_models import ChatVertexAI from langchain.schema.output_parser import StrOutputParser from langchain_core.runnables import RunnableLambda from loguru import logger def plt_img_base64(img_base64): # Create an HTML img tag with the base64 string as the source image_html = f'' # Display the image by rendering the HTML display(HTML(image_html)) def multi_modal_rag_chain(retriever): """ Multi-modal RAG chain """ # Multi-modal LLM model = ChatVertexAI( temperature=0, model_name="gemini-pro-vision", max_output_tokens=1024 ) # RAG pipeline chain = ( { "context": retriever | RunnableLambda(split_image_text_types), "question": RunnablePassthrough(), } | RunnableLambda(img_prompt_func) | model | StrOutputParser() ) return chain def img_prompt_func(data_dict): """ Join the context into a single string """ formatted_texts = "\n".join(data_dict["context"]["texts"]) messages = [] # Adding the text for analysis text_message = { "type": "text", "text": ( "You are an AI scientist tasking with providing factual answers.\n" "You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n" "Use this information to provide answers related to the user question. \n" f"User-provided question: {data_dict['question']}\n\n" "Text and / or tables:\n" f"{formatted_texts}" ), } messages.append(text_message) # Adding image(s) to the messages if present if data_dict["context"]["images"]: for image in data_dict["context"]["images"]: image_message = { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}, } messages.append(image_message) return [HumanMessage(content=messages)] def split_image_text_types(docs): """ Split base64-encoded images and texts """ b64_images = [] texts = [] for doc in docs: # Check if the document is of type Document and extract page_content if so if isinstance(doc, Document): doc = doc.page_content if looks_like_base64(doc) and is_image_data(doc): doc = resize_base64_image(doc, size=(1300, 600)) b64_images.append(doc) else: texts.append(doc) if len(b64_images) > 0: return {"images": b64_images[:1], "texts": []} return {"images": b64_images, "texts": texts} def resize_base64_image(base64_string, size=(128, 128)): """ Resize an image encoded as a Base64 string """ # Decode the Base64 string img_data = base64.b64decode(base64_string) img = Image.open(io.BytesIO(img_data)) # Resize the image resized_img = img.resize(size, Image.LANCZOS) # Save the resized image to a bytes buffer buffered = io.BytesIO() resized_img.save(buffered, format=img.format) # Encode the resized image to Base64 return base64.b64encode(buffered.getvalue()).decode("utf-8") def is_image_data(b64data): """ Check if the base64 data is an image by looking at the start of the data """ image_signatures = { b"\xFF\xD8\xFF": "jpg", b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": "png", b"\x47\x49\x46\x38": "gif", b"\x52\x49\x46\x46": "webp", } try: header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes for sig, format in image_signatures.items(): if header.startswith(sig): return True return False except Exception: return False def looks_like_base64(sb): """Check if the string looks like base64""" return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None def create_multi_vector_retriever( vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images ): """ Create retriever that indexes summaries, but returns raw images or texts """ # Initialize the storage layer store = InMemoryStore() id_key = "doc_id" # Create the multi-vector retriever retriever = MultiVectorRetriever( vectorstore=vectorstore, docstore=store, id_key=id_key, ) # Helper function to add documents to the vectorstore and docstore def add_documents(retriever, doc_summaries, doc_contents): doc_ids = [str(uuid.uuid4()) for _ in doc_contents] summary_docs = [ Document(page_content=s, metadata={id_key: doc_ids[i]}) for i, s in enumerate(doc_summaries) ] retriever.vectorstore.add_documents(summary_docs) retriever.docstore.mset(list(zip(doc_ids, doc_contents))) # Add texts, tables, and images # Check that text_summaries is not empty before adding if text_summaries: add_documents(retriever, text_summaries, texts) # Check that table_summaries is not empty before adding if table_summaries: add_documents(retriever, table_summaries, tables) # Check that image_summaries is not empty before adding if image_summaries: add_documents(retriever, image_summaries, images) return retriever def extract_elements_from_pdf(file_path: str, image_output_dir_path: str): pdf_list = [os.path.join(file_path, file) for file in os.listdir(file_path) if file.endswith('.pdf')] tables = [] texts = [] raw_pdf_elements = partition_pdf( filename=pdf_list[0], extract_images_in_pdf=True, infer_table_structure=True, chunking_strategy="by_title", max_characters=4000, new_after_n_chars=3800, combine_text_under_n_chars=2000, image_output_dir_path=image_output_dir_path, ) for element in raw_pdf_elements: if "unstructured.documents.elements.Table" in str(type(element)): tables.append(str(element)) elif "unstructured.documents.elements.CompositeElement" in str(type(element)): texts.append(str(element)) return texts, tables class Summary: def __init__(self): pass def encode_image(self, image_path): """Getting the base64 string""" with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") def image_summarize(self, img_base64, prompt): """Make image summary""" model = ChatVertexAI(model_name="gemini-pro-vision", max_output_tokens=1024) msg = model( [ HumanMessage( content=[ {"type": "text", "text": prompt}, { "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}, }, ] ) ] ) return msg.content def generate_img_summaries(self, path): """ Generate summaries and base64 encoded strings for images path: Path to list of .jpg files extracted by Unstructured """ # Store base64 encoded images img_base64_list = [] # Store image summaries image_summaries = [] # Prompt prompt = """You are an assistant tasked with summarizing images for retrieval. \ These summaries will be embedded and used to retrieve the raw image. \ Give a concise summary of the image that is well optimized for retrieval.""" # Apply to images for img_file in sorted(os.listdir(path)): if img_file.endswith(".jpg"): img_path = os.path.join(path, img_file) base64_image = self.encode_image(img_path) img_base64_list.append(base64_image) image_summaries.append(self.image_summarize(base64_image, prompt)) return img_base64_list, image_summaries def generate_text_summaries(self, texts, tables): text_summaries = texts table_summaries = tables return text_summaries, table_summaries def parse_args(): """Parse command-line arguments.""" parser = argparse.ArgumentParser() parser.add_argument( '--file_path', type=str, default='/home/zhangwq/data/art_test/pdf', help='') parser.add_argument( '--image_output_dir_path', default='/home/zhangwq/data/art_test', help='') parser.add_argument( '--query', default='compare and contrast between mistral and llama2 across benchmarks and explain the reasoning in detail', help='') args = parser.parse_args() return args if __name__ == '__main__': args = parse_args() summary = Summary() texts, tables = extract_elements_from_pdf(file_path=args.file_path, image_output_dir_path=args.image_output_dir_path) text_summaries, table_summaries = summary.generate_text_summaries(texts, tables) img_base64_list, image_summaries = summary.generate_img_summaries(args.file_path) embeddings = OpenCLIPEmbeddings( model_name="/home/zhangwq/model/CLIP_VIT", checkpoint="laion2b_s34b_b88k") embeddings.client = embeddings.client.half() vectorstore = FAISS(collection_name="mm_rag_mistral", embedding_function=embeddings) # Create retriever retriever = create_multi_vector_retriever( vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, img_base64_list, ) chain_multimodal_rag = multi_modal_rag_chain(retriever) docs = retriever.get_relevant_documents(args.query, limit=1) logger.info(docs[0]) chain_multimodal_rag.invoke(args.query)