import io import os import copy import json import logging import torch import random from typing import List, Optional, Tuple, Union, Dict, Sequence from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from vary.data.base_dataset import BaseDataset from vary.utils.constants import * from vary.utils import conversation as conversation_lib class ConversationDataset(BaseDataset): """Conversation format dataset stage2 fine-tuning.""" """ 这段代码是一个名为ConversationDataset的Python类,它继承自BaseDataset类。 这个类主要用于处理对话格式的数据集,特别是在多模态配置(multimodal_cfg)的情况下。 这个方法中,首先调用了父类的构造函数来初始化一些基本的属性。然后,它设置了默认的对话模板,并打印了一些日志信息。 接着,它遍历datasets,加载数据和图片路径,并将它们添加到列表中。最后,它将数据和图片路径的列表打乱,并将它们赋值给类的属性。 """ def __init__(self, datasets, tokenizer, multimodal_cfg): # 调用父类的构造函数 super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg) # v0 version format conversation # 初始化会话格式为mpt conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"] # 输出警告信息,说明正在格式化输入为mpt-fixed类型的会话 logging.warning("Formatting inputs into conversation type: mpt-fixed") logging.warning("Loading data...") # 初始化数据字典列表和图片路径列表 list_data_dict = [] list_image_path = [] # 遍历数据集名称,加载每个数据集 for name in datasets.split("+"): # for name in vary_data_dict[name_all]: dataset = CONVERSATION_DATA[name] data_path = dataset['annotations'] data = json.load(open(data_path, "r")) # 加载注释数据 list_data_dict.extend(data) # 将数据添加到列表中 image_path = dataset['images'] list_image_path.extend([image_path] * len(data)) # 为每个数据项添加图片路径 # 输出每个数据集提供的会话数量 logging.warning(f"Data from {data_path} provide {len(data)} conversations.") # 确保数据项和图片路径数量一致 assert len(list_data_dict) == len(list_image_path) # 输出总共有多少会话 logging.warning(f"{len(list_data_dict)} conversations in total.") # 打乱数据和图片路径的组合,并重新分配 a_new_list = list(zip(list_data_dict, list_image_path)) random.shuffle(a_new_list) list_data_dict_new, list_image_path_new = zip(*a_new_list) self.list_data_dict = list_data_dict_new self.list_image_path = list_image_path_new # 设置图像标记相关的令牌 self.im_patch_token = 151859 self.im_start_token = 151857 self.im_end_token = 151858 def multimodal_processor(self, sources): """ 这段代码是一个名为`ConversationDataset`的Python类,它继承自`BaseDataset`类。 这个类主要用于处理对话格式的数据集,特别是在多模态配置(`multimodal_cfg`)的情况下。 multimodal_processor方法:这个方法用于处理多模态的源数据。它遍历sources,并检查每个源数据的第一个元素的'value'是否包含默认的图片标记。 如果不包含,它会抛出一个断言错误。 """ """ 处理多模态源数据的方法。 此方法遍历输入的源数据列表,对每条源数据进行处理。首先,检查是否在多模态配置中指定了在对话前添加图片标记(`sep_image_conv_front`), 如果指定且第一个元素的'value'中不包含默认的图片标记,则抛出断言错误。然后,对于每个源数据中的句子, 将默认图片标记替换为一定长度的图片标记字符串,如果在多模态配置中启用了使用图片开始和结束标记(`use_im_start_end`), 则在图片标记字符串前后添加图片开始和结束标记。 参数: - sources: 包含多模态源数据的列表。每个源数据是一个列表,其中第一个元素是特殊的源标记,后续元素是具体的句子信息。 返回值: - 处理后的多模态源数据列表。 """ for source in sources: # 检查是否在对话前添加图片标记,并处理图片标记 if self.multimodal_cfg['sep_image_conv_front']: assert DEFAULT_IMAGE_TOKEN in source[0]['value'] source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value'] # 重新格式化源标记,添加角色信息 for sentence in source: # 准备替换的图片标记字符串 replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len'] # if self.multimodal_cfg['use_im_start_end']: replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN # 如果启用了使用图片开始和结束标记,格式化图片标记字符串 sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def _tokenize_fn(self, strings): """Tokenize a list of strings.""" """ 对字符串列表进行分词处理。 此函数接受一个字符串列表,并使用实例关联的分词器对其进行分词。 返回一个字典,其中包含分词后的输入、标签以及它们的长度。 参数: - strings (list): 需要分词的字符串列表。 返回: - dict: 包含以下键值对的字典: - input_ids (list): 分词后输入ID的列表。 - labels (list): 分词后标签ID的列表(如果适用,否则与input_ids相同)。 - input_ids_lens (list): 分词后输入的长度列表。 - labels_lens (list): 分词后标签的长度列表。 """ # 对列表中的每个字符串进行分词并存储结果 tokenized_list = [ self.tokenizer( text, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True, ) for text in strings ] # 从每个分词后的字符串中提取input_ids input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] # 计算每个分词字符串的长度(不包括填充令牌) input_ids_lens = labels_lens = [ tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] # 返回包含所有相关分词数据的字典 return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(self, target, tokenized_lens, speakers): """ 隐藏目标序列中特定部分的数据。 参数: - target: 待处理的目标序列,通常是一个序列化的文本或音频表示。 - tokenized_lens: 分词后每个序列的长度列表。 - speakers: 对应于每个序列的说话者信息列表。 返回值: - 修改后的target,其中特定部分被隐藏。 """ # cur_idx = 0 # 初始化当前索引为第一个序列的长度 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度 target[:cur_idx] = IGNORE_INDEX # 隐藏目标序列的起始部分 # 遍历剩余的序列长度和说话者信息,隐藏相应部分的目标序列 for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker.lower() == "human": target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX # 隐藏human说话者的部分,偏移量+2可能用于标识说话者 cur_idx += tokenized_len # 更新当前索引以处理下一个序列 def token_processor(self, sources): """ 处理对话源数据,将其转换为适用于模型输入的格式,包括标记化和目标掩码处理。 参数: sources - 一个包含多个对话来源的列表,每个对话来源自身是一个列表,其中包含了多个句子字典, 每个字典包含"from"和"value"键,分别指明句子的来源("human"或"gpt")和句子内容。 返回值: 一个字典,包含两个键: - input_ids: 经过标记化处理后的输入序列ID,格式为PyTorch张量。 - labels: 对应的掩码处理后的目标序列ID,用于模型训练,格式同input_ids。 """ # 复制默认对话对象并定义角色 conv = conversation_lib.default_conversation.copy() roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates # 应用对话模板 conversations = [] for i, source in enumerate(sources): # 如果首个消息不是来自人类,则跳过 if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] # 初始化对话消息 conv.messages = [] for j, sentence in enumerate(source): role = roles[sentence["from"]] # 确保角色匹配预期顺序 assert role == conv.roles[j % 2], f"{i}" conv.append_message(role, sentence["value"]) conversations.append(conv.get_prompt()) # Tokenize conversations # 对话内容标记化 input_ids = self.tokenizer( conversations, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True, ).input_ids # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) targets = input_ids.clone() # 确保使用的分隔符风格匹配预期 assert conv.sep_style == conversation_lib.SeparatorStyle.MPT # 目标掩码处理 sep = conv.sep + conv.roles[1] for conversation, target in zip(conversations, targets): # 计算非填充token的总数 total_len = int(target.ne(self.tokenizer.pad_token_id).sum()) rounds = conversation.split(conv.sep) re_rounds = [conv.sep.join(rounds[:3])] # 系统 + 用户 + gpt for conv_idx in range(3, len(rounds), 2): re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # 用户 + gpt cur_len = 0 target[:cur_len] = IGNORE_INDEX for i, rou in enumerate(re_rounds): if rou == "": break parts = rou.split(sep) if len(parts) != 2: break parts[0] += sep round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids) instruction_len = len(self.tokenizer(parts[0]).input_ids) target[cur_len : cur_len + instruction_len] = IGNORE_INDEX cur_len += round_len target[cur_len:] = IGNORE_INDEX # 警告处理:标记化结果与预期长度不匹配 if cur_len < self.tokenizer.model_max_length: if cur_len != total_len: target[:] = IGNORE_INDEX print( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) return dict( input_ids=input_ids, labels=targets, ) def __getitem__(self, i) -> Dict[str, torch.Tensor]: # data = self.list_data_dict[i] """ 根据索引i获取数据项,并进行预处理。如果数据项包含图像,则同时进行图像预处理。 参数: i: 索引值,用于指定要获取的数据项。 返回值: 一个字典,包含经过处理的文本和图像数据,键包括input_ids, labels, image, image_high等。 """ # 深拷贝数据项以避免修改原始数据 data = copy.deepcopy(self.list_data_dict[i]) # 检查数据项是否为字典,并处理其中的图像数据 if isinstance(data, dict): if 'image' in data: image_path = self.list_image_path[i] image_file = data['image'] # 尝试打开并转换图像 try: image = Image.open(image_path + image_file).convert('RGB') except: print(f'cannot identify image file {image_path + image_file}.') return self.__getitem__(0) # 尝试打开并转换图像 try: image, image_1 = self.image_processor(image) except: print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!') return self.__getitem__(0) # 对对话数据进行预处理 conversations = self.multimodal_processor([data["conversations"]]) else: # 如果数据项不是字典,则直接将其作为对话处理 conversations = [data] # 使用token处理器处理对话数据 # align with fastchat & llava here, put the conversation into a list for tokenization data_dict = self.token_processor(conversations) # 精简处理后的数据字典,只保留input_ids和labels data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # 如果数据项包含图像,将图像数据添加到数据字典中 if isinstance(data, dict) and 'image' in data: data_dict['image'] = [image] data_dict['image_high'] = [image_1] else: # 如果数据项不包含图像,为图像数据添加占位符 crop_size = self.multimodal_cfg['image_processor'].crop_size data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])] data_dict['image_high'] = [torch.zeros(3, 1024, 1024)] return data_dict