# Copyright 2025 Bytedance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 import json import random import numpy as np import torch from .data_utils import ( get_flattened_position_ids_extrapolate, get_flattened_position_ids_interpolate, len2weight, patchify, prepare_attention_mask_per_sample, ) from .dataset_info import DATASET_INFO, DATASET_REGISTRY from .transforms import ImageTransform from .video_utils import FrameSampler class DataConfig: def __init__( self, grouped_datasets, text_cond_dropout_prob=0.1, vit_cond_dropout_prob=0.4, vae_cond_dropout_prob=0.1, vae_image_downsample=16, max_latent_size=32, vit_patch_size=14, max_num_patch_per_side=70, ): self.grouped_datasets = grouped_datasets self.text_cond_dropout_prob = text_cond_dropout_prob self.vit_cond_dropout_prob = vit_cond_dropout_prob self.vit_patch_size = vit_patch_size self.max_num_patch_per_side = max_num_patch_per_side self.vae_cond_dropout_prob = vae_cond_dropout_prob self.vae_image_downsample = vae_image_downsample self.max_latent_size = max_latent_size class PackedDataset(torch.utils.data.IterableDataset): bos_token_id: int eos_token_id: int start_of_image: int end_of_image: int def __init__( self, data_config, tokenizer, special_tokens, local_rank, world_size, num_workers, expected_num_tokens=32768, max_num_tokens_per_sample=16384, max_num_tokens=36864, prefer_buffer_before=16384, max_buffer_size=50, interpolate_pos=False, use_flex=False, data_status=None, ): super().__init__() self.expected_num_tokens = expected_num_tokens self.max_num_tokens_per_sample = max_num_tokens_per_sample self.prefer_buffer_before = prefer_buffer_before self.max_num_tokens = max_num_tokens self.max_buffer_size = max_buffer_size self.tokenizer = tokenizer self.local_rank = local_rank self.world_size = world_size self.num_workers = num_workers self.use_flex = use_flex for k, v in special_tokens.items(): setattr(self, k, v) grouped_datasets, is_mandatory, grouped_weights = self.build_datasets( data_config.grouped_datasets, data_status ) self.grouped_datasets = grouped_datasets self.dataset_iters = [iter(dataset) for dataset in grouped_datasets] self.is_mandatory = is_mandatory self.grouped_weights = grouped_weights self.data_config = data_config self.interpolate_pos = interpolate_pos if self.interpolate_pos: self.get_flattened_position_ids = get_flattened_position_ids_interpolate else: self.get_flattened_position_ids = get_flattened_position_ids_extrapolate def build_datasets(self, datasets_metainfo, data_status): datasets = [] is_mandatory = [] grouped_weights = [] for grouped_dataset_name, dataset_args in datasets_metainfo.items(): is_mandatory.append(dataset_args.pop("is_mandatory", False)) grouped_weights.append(dataset_args.pop("weight", 0.0)) if "frame_sampler_args" in dataset_args.keys(): frame_sampler = FrameSampler(**dataset_args.pop("frame_sampler_args")) dataset_args["frame_sampler"] = frame_sampler if "image_transform_args" in dataset_args.keys(): transform = ImageTransform(**dataset_args.pop("image_transform_args")) dataset_args["transform"] = transform if "vit_image_transform_args" in dataset_args.keys(): vit_transform = ImageTransform( **dataset_args.pop("vit_image_transform_args") ) dataset_args["vit_transform"] = vit_transform if "dataset_names" in dataset_args.keys(): dataset_names = dataset_args.pop("dataset_names") else: dataset_names = DATASET_INFO[grouped_dataset_name].keys() if "num_used_data" not in dataset_args.keys(): dataset_args["num_used_data"] = [] append_num_used_data = True else: append_num_used_data = False dataset_args["data_dir_list"] = [] for item in dataset_names: if self.local_rank == 0: print(f"Preparing Dataset {grouped_dataset_name}/{item}") meta_info = DATASET_INFO[grouped_dataset_name][item] dataset_args["data_dir_list"].append(meta_info["data_dir"]) if append_num_used_data: dataset_args["num_used_data"].append(meta_info["num_total_samples"]) if "parquet_info_path" in meta_info.keys(): if "parquet_info" not in dataset_args.keys(): dataset_args["parquet_info"] = {} with open(meta_info["parquet_info_path"], "r") as f: parquet_info = json.load(f) dataset_args["parquet_info"].update(parquet_info) if "json_dir" in meta_info.keys(): # parquet/tar with json if "json_dir_list" not in dataset_args.keys(): dataset_args["json_dir_list"] = [meta_info["json_dir"]] else: dataset_args["json_dir_list"].append(meta_info["json_dir"]) if "jsonl_path" in meta_info.keys(): # jsonl with jpeg if "jsonl_path_list" not in dataset_args.keys(): dataset_args["jsonl_path_list"] = [meta_info["jsonl_path"]] else: dataset_args["jsonl_path_list"].append(meta_info["jsonl_path"]) resume_data_status = dataset_args.pop("resume_data_status", True) if ( data_status is not None and grouped_dataset_name in data_status.keys() and resume_data_status ): data_status_per_group = data_status[grouped_dataset_name] else: data_status_per_group = None dataset = DATASET_REGISTRY[grouped_dataset_name]( dataset_name=grouped_dataset_name, tokenizer=self.tokenizer, local_rank=self.local_rank, world_size=self.world_size, num_workers=self.num_workers, data_status=data_status_per_group, **dataset_args, ) datasets.append(dataset) return datasets, is_mandatory, grouped_weights def set_epoch(self, seed): for dataset in self.grouped_datasets: dataset.set_epoch(seed) def set_sequence_status(self): sequence_status = dict( curr=0, sample_lens=list(), packed_position_ids=list(), nested_attention_masks=list(), split_lens=list(), attn_modes=list(), packed_text_ids=list(), packed_text_indexes=list(), packed_label_ids=list(), ce_loss_indexes=list(), ce_loss_weights=list(), vae_image_tensors=list(), packed_latent_position_ids=list(), vae_latent_shapes=list(), packed_vae_token_indexes=list(), packed_timesteps=list(), mse_loss_indexes=list(), packed_vit_tokens=list(), vit_token_seqlens=list(), packed_vit_position_ids=list(), packed_vit_token_indexes=list(), ) return sequence_status def to_tensor(self, sequence_status): data = dict( sequence_length=sum(sequence_status["sample_lens"]), sample_lens=sequence_status["sample_lens"], packed_text_ids=torch.tensor(sequence_status["packed_text_ids"]), packed_text_indexes=torch.tensor(sequence_status["packed_text_indexes"]), packed_position_ids=torch.tensor(sequence_status["packed_position_ids"]), ) if not self.use_flex: data["nested_attention_masks"] = sequence_status["nested_attention_masks"] else: sequence_len = data["sequence_length"] pad_len = self.max_num_tokens - sequence_len data["split_lens"] = sequence_status["split_lens"] + [pad_len] data["attn_modes"] = sequence_status["attn_modes"] + ["causal"] data["sample_lens"] += [pad_len] # if the model has a convnet vae (e.g., as visual tokenizer) if len(sequence_status["vae_image_tensors"]) > 0: image_tensors = sequence_status.pop("vae_image_tensors") image_sizes = [item.shape for item in image_tensors] max_image_size = [max(item) for item in list(zip(*image_sizes))] padded_images = torch.zeros(size=(len(image_tensors), *max_image_size)) for i, image_tensor in enumerate(image_tensors): padded_images[ i, :, : image_tensor.shape[1], : image_tensor.shape[2] ] = image_tensor data["padded_images"] = padded_images data["patchified_vae_latent_shapes"] = sequence_status["vae_latent_shapes"] data["packed_latent_position_ids"] = torch.cat( sequence_status["packed_latent_position_ids"], dim=0 ) data["packed_vae_token_indexes"] = torch.tensor( sequence_status["packed_vae_token_indexes"] ) # if the model has a vit (e.g., as visual tokenizer) if len(sequence_status["packed_vit_tokens"]) > 0: data["packed_vit_tokens"] = torch.cat( sequence_status["packed_vit_tokens"], dim=0 ) data["packed_vit_position_ids"] = torch.cat( sequence_status["packed_vit_position_ids"], dim=0 ) data["packed_vit_token_indexes"] = torch.tensor( sequence_status["packed_vit_token_indexes"] ) data["vit_token_seqlens"] = torch.tensor( sequence_status["vit_token_seqlens"] ) # if the model is required to perform visual generation if len(sequence_status["packed_timesteps"]) > 0: data["packed_timesteps"] = torch.tensor(sequence_status["packed_timesteps"]) data["mse_loss_indexes"] = torch.tensor(sequence_status["mse_loss_indexes"]) # if the model is required to perform text generation if len(sequence_status["packed_label_ids"]) > 0: data["packed_label_ids"] = torch.tensor(sequence_status["packed_label_ids"]) data["ce_loss_indexes"] = torch.tensor(sequence_status["ce_loss_indexes"]) data["ce_loss_weights"] = torch.tensor(sequence_status["ce_loss_weights"]) return data def __iter__(self): total_weights = sum(self.grouped_weights) assert total_weights > 0.0 group_cumprobs = [ sum(self.grouped_weights[: i + 1]) / total_weights for i in range(len(self.grouped_weights)) ] sequence_status = self.set_sequence_status() batch_data_indexes = [] buffer = [] video_buffer = [] # Separate buffer for extremely long video samples while True: # Ensure at least one sample from each group if sequence_status["curr"] == 0: if len(video_buffer) > 0: sample = video_buffer.pop(0) num_tokens = sample["num_tokens"] + 2 * len(sample["sequence_plan"]) sequence_status = self.pack_sequence(sample, sequence_status) batch_data_indexes.append(sample["data_indexes"]) else: for group_index, group_iter in enumerate(self.dataset_iters): if self.is_mandatory[group_index]: while True: sample = next(group_iter) # if a sample is too long, put it in video buffer num_tokens = sample["num_tokens"] + 2 * len( sample["sequence_plan"] ) if num_tokens > self.max_num_tokens_per_sample: if len(video_buffer) < self.max_buffer_size: video_buffer.append(sample) print( f"Added sample with length {num_tokens} to video_buffer (size: {len(video_buffer)})" ) else: print( f"video_buffer full, skip a sample with length {num_tokens}" ) break elif num_tokens < self.max_num_tokens_per_sample: sequence_status = self.pack_sequence( sample, sequence_status ) batch_data_indexes.append(sample["data_indexes"]) break if sequence_status["curr"] >= self.expected_num_tokens: data = self.to_tensor(sequence_status) data["batch_data_indexes"] = batch_data_indexes print( f"Yielding {len(sequence_status['sample_lens'])} 3D data with length {sum(sequence_status['sample_lens'])}, length of each sample: {sequence_status['sample_lens']}" ) yield data sequence_status = self.set_sequence_status() batch_data_indexes = [] if sequence_status["curr"] < self.prefer_buffer_before and len(buffer) > 0: sample = buffer.pop(0) sample_from_buffer = True else: # sample normally across all groups n = random.random() group_index = 0 for i, cumprob in enumerate(group_cumprobs): if n < cumprob: group_index = i break sample = next(self.dataset_iters[group_index]) sample_from_buffer = False # if a sample is too long, skip it num_tokens = sample["num_tokens"] + 2 * len(sample["sequence_plan"]) if num_tokens > self.max_num_tokens_per_sample: if len(video_buffer) < self.max_buffer_size: video_buffer.append(sample) print( f"Added sample with length {num_tokens} to video_buffer (size: {len(video_buffer)})" ) else: print(f"video_buffer full, skip a sample with length {num_tokens}") continue if sequence_status["curr"] + num_tokens > self.max_num_tokens: if len(buffer) < self.max_buffer_size and not sample_from_buffer: buffer.append(sample) else: # print(f"Yielding data with length {sum(sequence_status['sample_lens'])}") data = self.to_tensor(sequence_status) data["batch_data_indexes"] = batch_data_indexes yield data sequence_status = self.set_sequence_status() batch_data_indexes = [] continue sequence_status = self.pack_sequence(sample, sequence_status) batch_data_indexes.append(sample["data_indexes"]) if sequence_status["curr"] >= self.expected_num_tokens: data = self.to_tensor(sequence_status) data["batch_data_indexes"] = batch_data_indexes yield data sequence_status = self.set_sequence_status() batch_data_indexes = [] def pack_sequence(self, sample, sequence_status): image_tensor_list = sample["image_tensor_list"] text_ids_list = sample["text_ids_list"] sequence_plan = sample["sequence_plan"] split_lens, attn_modes = list(), list() curr = sequence_status["curr"] curr_rope_id = 0 sample_lens = 0 for item in sequence_plan: split_start = item.get("split_start", True) if split_start: curr_split_len = 0 if item["type"] == "text": text_ids = text_ids_list.pop(0) if ( item["enable_cfg"] == 1 and random.random() < self.data_config.text_cond_dropout_prob ): continue shifted_text_ids = [self.bos_token_id] + text_ids sequence_status["packed_text_ids"].extend(shifted_text_ids) sequence_status["packed_text_indexes"].extend( range(curr, curr + len(shifted_text_ids)) ) if item["loss"] == 1: sequence_status["ce_loss_indexes"].extend( range(curr, curr + len(shifted_text_ids)) ) sequence_status["ce_loss_weights"].extend( [len2weight(len(shifted_text_ids))] * len(shifted_text_ids) ) sequence_status["packed_label_ids"].extend( text_ids + [self.eos_token_id] ) curr += len(shifted_text_ids) curr_split_len += len(shifted_text_ids) # add a <|im_end|> token sequence_status["packed_text_ids"].append(self.eos_token_id) sequence_status["packed_text_indexes"].append(curr) if item["special_token_loss"] == 1: # <|im_end|> may have loss sequence_status["ce_loss_indexes"].append(curr) sequence_status["ce_loss_weights"].append(1.0) sequence_status["packed_label_ids"].append( item["special_token_label"] ) curr += 1 curr_split_len += 1 # update sequence status attn_modes.append("causal") sequence_status["packed_position_ids"].extend( range(curr_rope_id, curr_rope_id + curr_split_len) ) curr_rope_id += curr_split_len elif item["type"] == "vit_image": image_tensor = image_tensor_list.pop(0) if ( item["enable_cfg"] == 1 and random.random() < self.data_config.vit_cond_dropout_prob ): curr_rope_id += 1 continue # add a <|startofimage|> token sequence_status["packed_text_ids"].append(self.start_of_image) sequence_status["packed_text_indexes"].append(curr) curr += 1 curr_split_len += 1 # preprocess image vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size) num_img_tokens = vit_tokens.shape[0] sequence_status["packed_vit_token_indexes"].extend( range(curr, curr + num_img_tokens) ) curr += num_img_tokens curr_split_len += num_img_tokens sequence_status["packed_vit_tokens"].append(vit_tokens) sequence_status["vit_token_seqlens"].append(num_img_tokens) sequence_status["packed_vit_position_ids"].append( self.get_flattened_position_ids( image_tensor.size(1), image_tensor.size(2), self.data_config.vit_patch_size, max_num_patches_per_side=self.data_config.max_num_patch_per_side, ) ) # add a <|endofimage|> token sequence_status["packed_text_ids"].append(self.end_of_image) sequence_status["packed_text_indexes"].append(curr) if item["special_token_loss"] == 1: # <|endofimage|> may have loss sequence_status["ce_loss_indexes"].append(curr) sequence_status["ce_loss_weights"].append(1.0) sequence_status["packed_label_ids"].append( item["special_token_label"] ) curr += 1 curr_split_len += 1 # update sequence status attn_modes.append("full") sequence_status["packed_position_ids"].extend( [curr_rope_id] * curr_split_len ) curr_rope_id += 1 elif item["type"] == "vae_image": image_tensor = image_tensor_list.pop(0) if ( item["enable_cfg"] == 1 and random.random() < self.data_config.vae_cond_dropout_prob ): # FIXME fix vae dropout in video2video setting. curr_rope_id += 1 continue # add a <|startofimage|> token sequence_status["packed_text_ids"].append(self.start_of_image) sequence_status["packed_text_indexes"].append(curr) curr += 1 curr_split_len += 1 # preprocess image sequence_status["vae_image_tensors"].append(image_tensor) sequence_status["packed_latent_position_ids"].append( self.get_flattened_position_ids( image_tensor.size(1), image_tensor.size(2), self.data_config.vae_image_downsample, max_num_patches_per_side=self.data_config.max_latent_size, ) ) H, W = image_tensor.shape[1:] h = H // self.data_config.vae_image_downsample w = W // self.data_config.vae_image_downsample sequence_status["vae_latent_shapes"].append((h, w)) num_img_tokens = w * h sequence_status["packed_vae_token_indexes"].extend( range(curr, curr + num_img_tokens) ) if item["loss"] == 1: sequence_status["mse_loss_indexes"].extend( range(curr, curr + num_img_tokens) ) if split_start: timestep = np.random.randn() else: timestep = float("-inf") sequence_status["packed_timesteps"].extend([timestep] * num_img_tokens) curr += num_img_tokens curr_split_len += num_img_tokens # add a <|endofimage|> token sequence_status["packed_text_ids"].append(self.end_of_image) sequence_status["packed_text_indexes"].append(curr) # <|endofimage|> may have loss if item["special_token_loss"] == 1: sequence_status["ce_loss_indexes"].append(curr) sequence_status["ce_loss_weights"].append(1.0) sequence_status["packed_label_ids"].append( item["special_token_label"] ) curr += 1 curr_split_len += 1 # update sequence status if split_start: if item["loss"] == 1 and "frame_delta" not in item.keys(): attn_modes.append("noise") else: attn_modes.append("full") sequence_status["packed_position_ids"].extend( [curr_rope_id] * (num_img_tokens + 2) ) if "frame_delta" in item.keys(): curr_rope_id += item["frame_delta"] elif item["loss"] == 0: curr_rope_id += 1 if item.get("split_end", True): split_lens.append(curr_split_len) sample_lens += curr_split_len sequence_status["curr"] = curr sequence_status["sample_lens"].append(sample_lens) # prepare attention mask if not self.use_flex: sequence_status["nested_attention_masks"].append( prepare_attention_mask_per_sample(split_lens, attn_modes) ) else: sequence_status["split_lens"].extend(split_lens) sequence_status["attn_modes"].extend(attn_modes) return sequence_status class SimpleCustomBatch: def __init__(self, batch): data = batch[0] self.batch_data_indexes = data["batch_data_indexes"] self.sequence_length = data["sequence_length"] self.sample_lens = data["sample_lens"] self.packed_text_ids = data["packed_text_ids"] self.packed_text_indexes = data["packed_text_indexes"] self.packed_position_ids = data["packed_position_ids"] self.use_flex = "nested_attention_masks" not in data.keys() if self.use_flex: self.split_lens = data["split_lens"] self.attn_modes = data["attn_modes"] else: self.nested_attention_masks = data["nested_attention_masks"] if "padded_images" in data.keys(): self.padded_images = data["padded_images"] self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"] self.packed_latent_position_ids = data["packed_latent_position_ids"] self.packed_vae_token_indexes = data["packed_vae_token_indexes"] if "packed_vit_tokens" in data.keys(): self.packed_vit_tokens = data["packed_vit_tokens"] self.packed_vit_position_ids = data["packed_vit_position_ids"] self.packed_vit_token_indexes = data["packed_vit_token_indexes"] self.vit_token_seqlens = data["vit_token_seqlens"] if "packed_timesteps" in data.keys(): self.packed_timesteps = data["packed_timesteps"] self.mse_loss_indexes = data["mse_loss_indexes"] if "packed_label_ids" in data.keys(): self.packed_label_ids = data["packed_label_ids"] self.ce_loss_indexes = data["ce_loss_indexes"] self.ce_loss_weights = data["ce_loss_weights"] def pin_memory(self): self.packed_text_ids = self.packed_text_ids.pin_memory() self.packed_text_indexes = self.packed_text_indexes.pin_memory() self.packed_position_ids = self.packed_position_ids.pin_memory() if not self.use_flex: self.nested_attention_masks = [ item.pin_memory() for item in self.nested_attention_masks ] if hasattr(self, "padded_images"): self.padded_images = self.padded_images.pin_memory() self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory() self.packed_latent_position_ids = ( self.packed_latent_position_ids.pin_memory() ) if hasattr(self, "packed_timesteps"): self.packed_timesteps = self.packed_timesteps.pin_memory() self.mse_loss_indexes = self.mse_loss_indexes.pin_memory() if hasattr(self, "packed_vit_tokens"): self.packed_vit_tokens = self.packed_vit_tokens.pin_memory() self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory() self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory() self.vit_token_seqlens = self.vit_token_seqlens.pin_memory() if hasattr(self, "packed_label_ids"): self.packed_label_ids = self.packed_label_ids.pin_memory() self.ce_loss_indexes = self.ce_loss_indexes.pin_memory() self.ce_loss_weights = self.ce_loss_weights.pin_memory() return self def cuda(self, device): self.packed_text_ids = self.packed_text_ids.to(device) self.packed_text_indexes = self.packed_text_indexes.to(device) self.packed_position_ids = self.packed_position_ids.to(device) if not self.use_flex: self.nested_attention_masks = [ item.to(device) for item in self.nested_attention_masks ] if hasattr(self, "padded_images"): self.padded_images = self.padded_images.to(device) self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device) self.packed_latent_position_ids = self.packed_latent_position_ids.to(device) if hasattr(self, "packed_timesteps"): self.packed_timesteps = self.packed_timesteps.to(device) self.mse_loss_indexes = self.mse_loss_indexes.to(device) if hasattr(self, "packed_vit_tokens"): self.packed_vit_tokens = self.packed_vit_tokens.to(device) self.packed_vit_position_ids = self.packed_vit_position_ids.to(device) self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device) self.vit_token_seqlens = self.vit_token_seqlens.to(device) if hasattr(self, "packed_label_ids"): self.packed_label_ids = self.packed_label_ids.to(device) self.ce_loss_indexes = self.ce_loss_indexes.to(device) self.ce_loss_weights = self.ce_loss_weights.to(device) return self def to_dict(self): data = dict( sequence_length=self.sequence_length, sample_lens=self.sample_lens, packed_text_ids=self.packed_text_ids, packed_text_indexes=self.packed_text_indexes, packed_position_ids=self.packed_position_ids, batch_data_indexes=self.batch_data_indexes, ) if not self.use_flex: data["nested_attention_masks"] = self.nested_attention_masks else: data["split_lens"] = self.split_lens data["attn_modes"] = self.attn_modes if hasattr(self, "padded_images"): data["padded_images"] = self.padded_images data["patchified_vae_latent_shapes"] = self.patchified_vae_latent_shapes data["packed_latent_position_ids"] = self.packed_latent_position_ids data["packed_vae_token_indexes"] = self.packed_vae_token_indexes if hasattr(self, "packed_vit_tokens"): data["packed_vit_tokens"] = self.packed_vit_tokens data["packed_vit_position_ids"] = self.packed_vit_position_ids data["packed_vit_token_indexes"] = self.packed_vit_token_indexes data["vit_token_seqlens"] = self.vit_token_seqlens if hasattr(self, "packed_timesteps"): data["packed_timesteps"] = self.packed_timesteps data["mse_loss_indexes"] = self.mse_loss_indexes if hasattr(self, "packed_label_ids"): data["packed_label_ids"] = self.packed_label_ids data["ce_loss_indexes"] = self.ce_loss_indexes data["ce_loss_weights"] = self.ce_loss_weights return data def collate_wrapper(): def collate_fn(batch): return SimpleCustomBatch(batch) return collate_fn