import gc
import json
import logging
import math
import os
import random
import sys
import traceback
import warnings
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np
import torch
import torch.distributed as dist
import transformers
from internvl.dist_utils import init_dist
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
                                          InternVisionModel, InternVLChatConfig,
                                          InternVLChatModel)
from internvl.patch import (concat_pad_data_collator,
                            replace_llama_rmsnorm_with_fused_rmsnorm,
                            replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
                                      IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
                                      IMG_START_TOKEN, QUAD_END_TOKEN,
                                      QUAD_START_TOKEN, REF_END_TOKEN,
                                      REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, TCSLoader,
                                    WeightedConcatDataset, build_transform,
                                    dynamic_preprocess, preprocess,
                                    preprocess_internlm, preprocess_mpt,
                                    preprocess_phi3)
from internvl.train.trainer_monkey_patch import replace_create_optimizer
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          HfArgumentParser, Trainer, TrainingArguments,
                          set_seed)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (enable_default_handler,
                                        enable_explicit_format, set_verbosity)

# Apply necessary patches for the transformers library
replace_llama_rmsnorm_with_fused_rmsnorm()
replace_train_sampler()

# Try to import petrel_client for image loading, fallback to PIL if unavailable
try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
    has_tcs_loader = True
except ImportError as E:
    print('petrel_client is not installed. Using PIL to load images.')
    has_tcs_loader = False

# Set constants for image processing and logging
IGNORE_INDEX = -100
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte

warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)

os.environ['TOKENIZERS_PARALLELISM'] = 'true'


@dataclass
class ModelArguments:
    """
    Arguments for specifying model, tokenizer, and configurations.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    vision_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    llm_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    mlp_path: Optional[str] = field(
        default=None,
        metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
    )
    freeze_llm: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the LLM decoder.'},
    )
    freeze_backbone: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
    )
    freeze_mlp: bool = field(
        default=False,
        metadata={'help': 'Set to True to freeze the MLP layers of the model.'},
    )
    unfreeze_vit_layers: int = field(
        default=0,
        metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
    )
    vision_select_layer: int = field(
        default=-1,
        metadata={'help': 'Specify the layer of ViT feature map to use. Default is the last layer.'},
    )
    use_backbone_lora: int = field(
        default=0,
        metadata={'help': 'Set the LoRA adapter rank for the backbone model. Default is 0.'}
    )
    use_llm_lora: int = field(
        default=0,
        metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
    )
    unfreeze_lm_head: bool = field(
        default=False,
        metadata={'help': "Set to True to unfreeze the language model's head."},
    )
    use_custom_trainer: bool = field(
        default=False,
        metadata={'help': 'Set to True to enable the use of a custom trainer.'},
    )
    grad_checkpoint: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to use gradient checkpointing.'},
    )
    drop_path_rate: float = field(
        default=0.0,
        metadata={'help': 'Set the drop path rate for the ViT model. Default is 0.'},
    )
    ps_version: str = field(
        default='v2',
        metadata={'help': 'Specify the version of pixel shuffle implementation. Default is `v2`. '
                          'Please use `v2` to fix the bug of transposed image.'}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments for specifying data input for training and evaluation.
    """
    max_seq_length: Optional[int] = field(
        default=2048,
        metadata={
            'help': (
                'The maximum total input sequence length after tokenization. Sequences longer '
                'than this will be truncated, sequences shorter will be padded.'
            )
        },
    )
    force_image_size: Optional[int] = field(
        default=448,
        metadata={'help': 'Set the desired size for the image. Default is 448.'},
    )
    down_sample_ratio: Optional[float] = field(
        default=0.5,
        metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
    )
    pad2square: Optional[bool] = field(
        default=False,
        metadata={'help': 'Pad the image to a square shape if set to True.'},
    )
    conv_style: Optional[str] = field(
        default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
    )
    meta_path: Optional[str] = field(
        default=None,
        metadata={'help': 'The path of the meta file of datasets.'},
    )
    use_data_resampling: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to use data resampling.'},
    )
    dynamic_image_size: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to use dynamic image size.'},
    )
    use_thumbnail: Optional[bool] = field(
        default=False,
        metadata={'help': 'Set to True to add a thumbnail image.'},
    )
    min_dynamic_patch: Optional[int] = field(
        default=1,
        metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
    )
    max_dynamic_patch: Optional[int] = field(
        default=12,
        metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
    )
    normalize_type: Optional[str] = field(
        default='imagenet',
        metadata={'help': 'The normalize type for the image. Default is imagenet.'},
    )
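# Usage sketch (not part of this module's execution path): the two dataclasses above
# are normally parsed together with transformers' TrainingArguments via HfArgumentParser,
# already imported at the top of this file. A minimal example, assuming a hypothetical
# launch such as `python train.py --model_name_or_path <ckpt> --meta_path <meta.json> --output_dir <dir>`:
#
#     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
#     model_args, data_args, training_args = parser.parse_args_into_dataclasses()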
class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        template_name,
        meta,
        tokenizer,
        tcs_loader,
        ds_name,
        num_image_token,
        image_size=224,
        is_train=True,
        pad2square=False,
        group_by_length=False,
        dynamic_image_size=False,
        use_thumbnail=False,
        min_dynamic_patch=1,
        max_dynamic_patch=6,
        min_num_frame=4,  # for video data
        max_num_frame=12,  # for video data
        sampling_method='rand',  # for video data
        repeat_time=1,
        normalize_type='imagenet',
        random_seed=0,
    ):
        super(LazySupervisedDataset, self).__init__()
        self.ds_name = ds_name
        self.tokenizer = tokenizer
        self.template_name = template_name
        self.num_image_token = num_image_token
        logger.info(f'[Dataset] num_image_token: {num_image_token}')
        logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
        logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
        logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')

        self.image_size = image_size
        self.is_train = is_train
        self.pad2square = pad2square
        self.max_num_frame = max_num_frame
        self.min_num_frame = min_num_frame
        self.sampling_method = sampling_method

        logger.info('Formatting inputs...Skip in lazy mode')
        assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'

        with open(meta['annotation'], 'r') as f:
            self.raw_data = f.readlines()
            if repeat_time < 1:
                # If repeat_time is less than 1, select a portion of the data
                self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
            if repeat_time > 1:
                assert isinstance(repeat_time, int)
                # Repeat the list if repeat_time is greater than 1
                self.raw_data = self.raw_data * repeat_time

        self.rng = np.random.default_rng(seed=random_seed)
        self.rng.shuffle(self.raw_data)
        gc.collect()
        self.root = meta['root']
        self.cached_data_dict = {}
        self.tcs_loader = tcs_loader
        self.group_by_length = group_by_length
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
        self.normalize_type = normalize_type

        # If the precomputed length does not exist, roughly estimate the length of
        # each sample to improve the efficiency of group_by_length.
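        # Note on the estimate below: tokenized lengths are cached by raw string length
        # in self.conv2length, and the cached value also adds the worst-case image token
        # budget, num_image_token * (max_dynamic_patch + use_thumbnail), so length-grouped
        # batching works with a rough upper bound per sample instead of retokenizing everything.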
        if self.group_by_length:
            self.conv2length = {}  # Using a dictionary to speed up token length calculation
            self.length = []
            for data_item in self.raw_data:
                data_item = json.loads(data_item)
                if 'length' in data_item:
                    token_length = data_item['length']  # Use precomputed length if available
                else:
                    # Compute token length using the tokenizer
                    conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
                    str_length = len(conversations)
                    if str_length not in self.conv2length:
                        token_length = tokenizer(
                            conversations, return_tensors='pt', padding=False, truncation=False,
                        ).input_ids.size(1)
                        self.conv2length[str_length] = token_length + num_image_token * (
                            max_dynamic_patch + use_thumbnail)
                    else:
                        token_length = self.conv2length[str_length]
                self.length.append(token_length)
        gc.collect()

    def __len__(self):
        return len(self.raw_data)

    def get_preprocess_function(self):
        # Select the appropriate preprocessing function based on the template name
        if self.template_name == 'Hermes-2':
            preprocess_function = preprocess_mpt
        elif self.template_name == 'internlm2-chat':
            preprocess_function = preprocess_internlm
        elif self.template_name == 'phi3-chat':
            preprocess_function = preprocess_phi3
        else:
            preprocess_function = preprocess
        return preprocess_function

    def load_image(self, image_path):
        # Load the image using tcs_loader if available, otherwise use PIL
        if self.tcs_loader is not None and 's3://' in image_path:
            return self.tcs_loader(image_path)
        return Image.open(image_path).convert('RGB')

    def get_image_path(self, image_path):
        if image_path.startswith('s3://'):  # for ceph
            image_path = self.root + image_path
        else:  # for local image
            image_path = os.path.join(self.root, image_path)
        return image_path

    def get_transform(self):
        # Build transformation function
        transform = build_transform(is_train=self.is_train, input_size=self.image_size,
                                    pad2square=self.pad2square, normalize_type=self.normalize_type)
        return transform

    def multi_modal_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Ensure the first conversation contains an image placeholder
        if '<image>' not in data_item['conversations'][0]['value']:
            data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']

        # Merge the image path
        image_path = self.get_image_path(data_item['image'])

        # Load the image using tcs_loader if available, otherwise use PIL
        image = self.load_image(image_path)

        if self.dynamic_image_size:  # If dynamic image size is enabled, preprocess the image dynamically
            images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
                                        image_size=self.image_size, use_thumbnail=self.use_thumbnail)
        else:  # Otherwise, use the original image as a single patch
            images = [image]

        # Apply the transformation to each image and stack the results into a tensor
        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)

        # Ensure that there is only one patch if dynamic image size is not enabled
        num_patches = pixel_values.size(0)
        if not self.dynamic_image_size:
            assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
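        # Token accounting: the preprocessing step below is asked to expand the image
        # placeholder into self.num_image_token * num_patches visual context tokens,
        # i.e. one block of num_image_token tokens per tile produced above.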
        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, [self.num_image_token * num_patches],
                                  group_by_length=self.group_by_length, ds_name=self.ds_name)

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
        )
        return ret

    def multi_modal_multi_image_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        images, num_tiles = [], []
        num_image = len(data_item['image'])
        for image_path in data_item['image']:
            # Merge the image path
            image_path = self.get_image_path(image_path)

            # Load the image using tcs_loader if available, otherwise use PIL
            image = self.load_image(image_path)

            if self.dynamic_image_size:  # If dynamic image size is enabled, preprocess the image dynamically
                image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
                                           max_num=self.max_dynamic_patch // num_image,
                                           image_size=self.image_size, use_thumbnail=self.use_thumbnail)
                images += image
                num_tiles.append(len(image))
            else:  # Otherwise, use the original image as a single patch
                images.append(image)
                num_tiles.append(1)

        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        num_patches = pixel_values.size(0)

        # Select the appropriate preprocessing function based on the template name
        preprocess_function = self.get_preprocess_function()

        # Preprocess the conversations and generate the return dictionary
        num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
        ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
                                  self.tokenizer, num_image_tokens,
                                  group_by_length=self.group_by_length, ds_name=self.ds_name,
                                  num_image=num_image)

        # Create the final return dictionary
        ret = dict(
            input_ids=ret['input_ids'][0],
            labels=ret['labels'][0],
            attention_mask=ret['attention_mask'][0],
            pixel_values=pixel_values,
            image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
        )
        return ret

    def video_get_item(self, data_item):
        # Build transformation function
        transform = self.get_transform()

        # Ensure the first conversation contains a video placeholder
        if '<video>' not in data_item['conversations'][0]['value']: