import io import os import copy import json import logging import torch import random from typing import List, Optional, Tuple, Union, Dict, Sequence from PIL import Image, ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True from vary.data.base_dataset import BaseDataset from vary.utils.constants import * from vary.utils import conversation as conversation_lib # 这段代码定义了一个名为 CaptionDataset 的类,它继承自 BaseDataset 类。 # 这个类主要用于处理和存储对话格式的数据集,以便于后续的模型训练。 class CaptionDataset(BaseDataset): """Conversation format dataset stage2 fine-tuning.""" """CaptionDataset 类的 __init__ 方法接收三个参数:datasets、tokenizer 和 multimodal_cfg。 其中,datasets 是数据集的名称,tokenizer 是一个预训练的分词器,multimodal_cfg 是一个字典,包含了多模态配置信息。""" def __init__(self, datasets, tokenizer, multimodal_cfg): # 调用父类初始化方法 super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg) # v0 version format conversation """在 __init__ 方法中,首先调用父类的初始化方法,然后设置默认的对话模板为 "default"。 然后,对每个数据集,读取其注释数据和图像路径,并将这些数据添加到 list_data_dict 和 list_image_path 列表中。 然后,将这两个列表打乱并重新组合,最后将这两个列表以及一些分词器生成的 token 保存为类的属性。""" # 设置默认对话模板 conversation_lib.default_conversation = conversation_lib.conv_templates["default"] # 开始格式化输入数据 logging.warning("Formatting inputs into conversation type: v0-fixed") logging.warning("Loading data...") # 初始化用于存储数据和图像路径的列表 list_data_dict = [] list_image_path = [] # 遍历数据集名称,加载数据 for name in datasets.split("+"): dataset = CONVERSATION_DATA[name] # in vary.utils # 从预定义的数据字典中获取数据集 # 加载注释数据 data_path = dataset['annotations'] data = json.load(open(data_path, "r")) list_data_dict.extend(data) # 获取图像路径,并与注释数据对应 image_path = dataset['images'] list_image_path.extend([image_path] * len(data)) # 输出数据加载信息 logging.warning(f"Data from {data_path} provide {len(data)} conversations.") # 确保数据和图像路径的数量一致 assert len(list_data_dict) == len(list_image_path) # 输出总数据量 logging.warning(f"{len(list_data_dict)} conversations in total.") # 打乱并重新组合数据和图像路径 a_new_list = list(zip(list_data_dict, list_image_path)) random.shuffle(a_new_list) list_data_dict_new, list_image_path_new = zip(*a_new_list) self.list_data_dict = list_data_dict_new self.list_image_path = list_image_path_new # 为图像标记生成对应的token ID self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids( [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) def multimodal_processor(self, sources): """ 处理多模态源数据,将图像标记替换为一定长度的图像标记序列。 :param sources: 包含多个源数据的列表,每个源数据是一个字典列表,其中第一个字典代表图像信息, 其他字典代表与图像相关联的文本信息。 :return: 经过图像标记替换处理后的源数据列表。 """ for source in sources: # 为每个源数据的图像信息设置默认的图像标记 source[0]['value'] = DEFAULT_IMAGE_TOKEN for sentence in source: # 根据配置,生成替换用的图像标记序列 replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len'] if self.multimodal_cfg['use_im_start_end']: # 如果使用图像开始和结束标记,则在图像标记序列两端添加开始和结束标记 replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN # 使用生成的图像标记序列替换句子中的默认图像标记 sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) return sources def _tokenize_fn(self, strings): """Tokenize a list of strings.""" """ 对字符串列表进行分词处理,以供模型输入。 此函数处理一个字符串列表,对每个字符串进行分词,填充或截断令牌到最大长度,并创建输入掩码。 它还处理特定的分词情况,如确保每个输入的最后一个令牌是特定令牌。 参数: - strings (list): 需要分词的字符串列表。 返回: - dict: 包含以下键值对的字典: - input_ids (list): 表示每个输入字符串的令牌ID列表。 - labels (list): 用作标签的令牌ID列表,通常与input_ids相同。 - input_ids_lens (list): 分词后的字符串长度列表。 - labels_lens (list): 标签长度列表。 """ # 对字符串列表中的每个字符串进行分词并准备模型输入 tokenized_list = [ self.tokenizer( text, return_tensors="pt", padding="longest", max_length=self.tokenizer.model_max_length, truncation=True, ) for text in strings ] # 从分词结果中提取输入ID和标签 input_ids = labels = [ tokenized.input_ids[0] for tokenized in tokenized_list ] # 确保每个输入的最后一个令牌是特定令牌(此处为2) for idx, ii in enumerate(input_ids): if ii[-1] != 2: input_ids[idx][-1] = 2 labels[idx][-1] = 2 # 计算每个输入的非填充令牌长度 input_ids_lens = labels_lens = [ tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list ] # 返回包含所有必需处理数据的字典 return dict( input_ids=input_ids, labels=labels, input_ids_lens=input_ids_lens, labels_lens=labels_lens, ) def _mask_targets(self, target, tokenized_lens, speakers): """ 隐藏目标序列中特定部分的数据。 参数: - target: 待处理的目标序列,通常是一个序列化的文本或音频表示。 - tokenized_lens: 分词后每个序列的长度列表。 - speakers: 对应于每个序列的说话者信息列表。 返回值: - 修改后的target,其中根据说话者信息和长度信息进行了掩码处理。 """ # cur_idx = 0 # 初始化当前索引为第一个序列的长度 cur_idx = tokenized_lens[0] tokenized_lens = tokenized_lens[1:] # 移除已经处理的第一个序列的长度 target[:cur_idx] = IGNORE_INDEX # 隐藏目标序列的起始部分 # 遍历剩余的序列长度和说话者信息,进行相应的掩码处理 for tokenized_len, speaker in zip(tokenized_lens, speakers): if speaker.lower() == "human": target[cur_idx:tokenized_len] = IGNORE_INDEX # 如果说话者是“human”,隐藏该部分目标序列 cur_idx += tokenized_len # 更新当前索引位置,准备处理下一个序列 def _add_speaker_and_signal(self, header, source, get_conversation=True): """Add speaker and start/end signal on each round.""" """ 在每个回合中添加说话者和开始/结束信号。 此方法处理一个句子列表,识别每个句子的说话者,并向句子添加开始和结束信号。 如果`get_conversation`为真,它将所有句子连接成一个完整的对话字符串。 参数: - header (str): 对话的初始消息或标题。 - source (list): 每个句子是一个字典,至少包含'from'键和'value'键。 'from'键指定说话者,'value'键指定所说的话。 - get_conversation (bool, 可选): 标志,表示是否将所有句子合并为一个单一的对话字符串。 默认为True。 返回: - str: 添加了说话者并带有适当开始和结束信号的最终对话字符串。 如果`get_conversation`为False,则返回None。 """ BEGIN_SIGNAL = "" END_SIGNAL = "\n" conversation = header # 初始化对话,附带标题 # 处理源列表中的每个句子 for sentence in source: from_str = sentence["from"] # 获取说话者 获取说话信息 # 将说话者名称映射到预定义的角色 if from_str.lower() == "human": from_str = conversation_lib.default_conversation.roles[0] else: from_str = conversation_lib.default_conversation.roles[1] sentence["value"] = sentence["value"] + END_SIGNAL # 向句子添加结束信号 if get_conversation: conversation += sentence["value"] # 若需要,将句子拼接到对话中 conversation += BEGIN_SIGNAL return conversation def token_processor(self, sources): """ Given a list of sources, each is a conversation list. This transform: 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 2. Concatenate conversations together; 3. Tokenize the concatenated conversation; 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. multimodal_processor 方法是用来处理多模态数据的。这个方法接收一个名为 sources 的参数,这个参数应该是一个包含多个数据源的列表。 每个数据源都应该是一个字典,包含 "value" 和 "sentence" 这两个键。 这个方法会遍历每个数据源,将 "value" 键的值替换为特定的 token,然后返回处理后的数据源列表。 """ """ 处理输入的对话列表,进行标记化和掩码处理。 参数: sources: 一个包含多个会话列表的列表,每个会话列表是一个字典列表,其中每个字典包含"value"和"sentence"键。 返回: 一个字典,包含输入ID和标签(被掩码处理的人类语言)的列表。 """ # add end signal and concatenate together # 为每个会话添加信号并在末尾连接 conversations = [] for source in sources: header = '' conversation = self._add_speaker_and_signal(header, source) # 添加说话者和信号到会话 conversations.append(conversation) # 对会话进行标记化 conversations_tokenized = self._tokenize_fn(conversations) # 标记化会话文本 input_ids = conversations_tokenized["input_ids"] # 获取输入ID # 创建目标的深拷贝,并掩码处理人类语言单词 targets = copy.deepcopy(input_ids) # 创建输入ID的深拷贝作为目标 for target, source in zip(targets, sources): # 遍历目标和源,进行掩码处理 tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"] # 计算每个说话者的标记长度 speakers = [sentence["from"] for sentence in source] # 获取每个句子的说话者 self._mask_targets(target, tokenized_lens, speakers) # 根据说话者信息掩码目标序列 return dict(input_ids=input_ids, labels=targets) # 返回处理后的输入和目标 def __getitem__(self, i) -> Dict[str, torch.Tensor]: # data = self.list_data_dict[i] """ 根据索引i获取数据集中的一个项目。 参数: - i: int, 数据集中项的索引。 返回值: - Dict[str, torch.Tensor]: 包含项目数据的字典,其中可能包括'image', 'image_high', 'input_ids', 'labels'键。 """ # 深拷贝数据以避免修改原始数据 data = copy.deepcopy(self.list_data_dict[i]) # 处理包含图像数据的项 if isinstance(data, dict): if 'image' in data: image_path = self.list_image_path[i] image_file = data['image'] # TODO this is a bug, because some json has wrong path # 尝试打开并处理图像 try: image = Image.open(image_path + image_file).convert('RGB') except: print(f'cannot identify image file {image_path+image_file}.') return self.__getitem__(0) # 尝试对图像进行进一步处理 try: image, image_high = self.image_processor(image) except: print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!') return self.__getitem__(0) conversations = self.multimodal_processor([data["conversations"]]) else: # 对于不包含图像数据的项 conversations = [data] # align with fastchat & llava here, put the conversation into a list for tokenization # 对会话数据进行标记化处理 data_dict = self.token_processor(conversations) # 简化数据字典结构 data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) # 如果数据包含图像,添加图像数据到字典中 if isinstance(data, dict) and 'image' in data: data_dict['image'] = [image] data_dict['image_high'] = [image_high] else: # 为不包含图像的数据项填充默认的图像张量 crop_size = self.multimodal_cfg['image_processor'].crop_size data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])] # TODO sam is 1024*1024 data_dict['image_high'] = [torch.zeros(3, 1024, 1024)] return data_dict