import torch import transformers from dataclasses import dataclass, field from vary.utils.constants import * """ DataCollatorForSupervisedDataset 是一个数据整理类,它的主要任务是将一批数据整理成一个字典格式的批次(batch)。这个类需要一个 transformers.PreTrainedTokenizer 对象作为初始化参数。 它的 __call__ 方法会被调用来处理数据,这个方法接收一个名为 instances 的参数,这个参数应该是一个包含多个数据实例的列表。每个实例都应该是一个字典,包含 "input_ids"、"labels"、"image" 和 "image_high" 这四个键。 这个方法会将这些数据整理成一个字典,其中包含 "input_ids"、"labels"、"attention_mask" 和 "images" 这四个键,每个键对应的值都是一个批次的数据。 """ @dataclass class DataCollatorForSupervisedDataset(object): tokenizer: transformers.PreTrainedTokenizer """ 将输入的实例列表转换为一个统一的批次数据字典。 参数: instances: 一个包含多个实例的列表,每个实例是一个字典,应包含`input_ids`, `labels`, `image`, 和 `image_high`键。 返回: 一个字典,包含批处理后的`input_ids`, `labels`, `attention_mask`, 和 `images`。其中`images`包含了原始图片和高分辨率图片。 """ def __call__(self, instances): # 从实例中提取`input_ids`和`labels`,并确保它们是张量 input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) # 提取图片,并将它们堆叠起来形成张量 images = [torch.stack(instance['image']) for instance in instances] # 将原始图片和高分辨率图片配对 images_high = [torch.stack(instance['image_high']) for instance in instances] images = list(zip(images, images_high)) # 对`input_ids`和`labels`进行序列填充,以便能够在同一批次中处理不同长度的输入 input_ids = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX) # 构建批次数据字典,包括输入标识、标签、注意力掩码和图片 batch = dict( input_ids=input_ids, labels=labels, attention_mask=input_ids.ne(self.tokenizer.pad_token_id), images=images, ) return batch """ make_supervised_data_module 是一个函数,它的任务是创建一个监督学习的数据模块。这个函数接收四个参数: interleave、with_box、tokenizer 和 data_args。这个函数首先根据 data_args.conversation_version 的值来决定使用哪个数据集类,然后使用这个类和其他参数来创建一个数据集对象。 """ def make_supervised_data_module(interleave, with_box, tokenizer, data_args): """ 根据给定的参数配置,创建一个监督学习的数据模块。 参数: - interleave: 一个控制数据混洗方式的参数。 - with_box: 一个标志,指示是否在数据中包含边界框信息。 - tokenizer: 用于文本序列的标记化器。 - data_args: 包含各种数据加载和处理配置的数据 arguments 对象。 返回值: - 一个字典,包含训练数据集和数据胶合器(用于数据批处理)。 """ # 根据对话版本选择对应的Dataset类 if data_args.conversation_version == 'mpt': from vary.data.conversation_dataset_qwen import ConversationDataset dataset_cls = ConversationDataset elif data_args.conversation_version == 'opt': from vary.data.caption_opt import CaptionDataset dataset_cls = CaptionDataset # 初始化训练数据集 train_dataset = dataset_cls( tokenizer=tokenizer, datasets=data_args.datasets, multimodal_cfg=dict( sep_image_conv_front=data_args.sep_image_conv_front, image_token_len=data_args.image_token_len, image_aspect_ratio=data_args.image_aspect_ratio, use_im_start_end=data_args.use_im_start_end, image_processor=data_args.image_processor, image_processor_high = data_args.image_processor_high, box_limit=data_args.box_limit, ) ) # 创建数据胶合器 data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) # 返回配置好的训练数据集和数据胶合器 return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)