import random

import numpy as np
import torch

from lightx2v.models.schedulers.scheduler import BaseScheduler


def set_seed(seed=42):
    """Seed the torch (CPU and all CUDA devices), NumPy and Python RNGs.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    autotuner. All of these are process-global settings.

    Args:
        seed: Seed value passed to every RNG. Defaults to 42.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids for the patch grid of an image.

    The patch at (row, col) receives id ``row * max_num_patches_per_side + col``.
    Rows and columns are not clamped, so grids larger than
    ``max_num_patches_per_side`` produce ids beyond ``max_num_patches_per_side ** 2``.

    Args:
        img_h: Image height, in the same unit as ``patch_size``.
        img_w: Image width, in the same unit as ``patch_size``.
        patch_size: Side length of one patch; sizes are floor-divided by it.
        max_num_patches_per_side: Row stride used when flattening the grid.

    Returns:
        1-D integer tensor with ``(img_h // patch_size) * (img_w // patch_size)`` entries.
    """
    row_ids = torch.arange(0, img_h // patch_size)
    col_ids = torch.arange(0, img_w // patch_size)
    grid_ids = row_ids[:, None] * max_num_patches_per_side + col_ids
    return grid_ids.flatten()


def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return position ids obtained by bucketing fractional patch coordinates.

    Each axis is mapped to fractional coordinates ``0, 1/n, 2/n, ...`` (``n`` being
    the number of patches along that axis) and then bucketized against
    boundaries spaced ``1 / max_num_patches_per_side`` apart. The resulting
    bucket indices are flattened with row stride ``max_num_patches_per_side``.

    Args:
        img_h: Image height, in the same unit as ``patch_size``.
        img_w: Image width, in the same unit as ``patch_size``.
        patch_size: Side length of one patch; sizes are floor-divided by it.
        max_num_patches_per_side: Number of buckets per axis and row stride.

    Returns:
        1-D integer tensor with ``(img_h // patch_size) * (img_w // patch_size)`` entries.
    """
    num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
    bucket_width = 1 / max_num_patches_per_side
    boundaries = torch.arange(bucket_width, 1.0, bucket_width)

    def bucket_axis(num_patches):
        # The 1 - 1e-6 upper bound keeps arange from emitting a trailing 1.0.
        fractional = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractional, boundaries, right=True)

    row_buckets = bucket_axis(num_patches_h)
    col_buckets = bucket_axis(num_patches_w)
    return (row_buckets[:, None] * max_num_patches_per_side + col_buckets).flatten()


class BagelScheduler(BaseScheduler):
    """Scheduler that builds packed generation inputs and advances latents.

    It holds the shifted timestep schedule, prepares the packed index/noise
    tensors for image generation (with and without CFG), and applies the
    per-step latent update.
    """

    def __init__(self, config):
        """Read scheduler settings from ``config`` and reset the run state.

        Args:
            config: Configuration object supporting attribute access, item
                access and ``.get``. Must provide ``interpolate_pos``,
                ``latent_patch_size``, ``vae_config.downsample``,
                ``vae_config.z_channels``, ``max_latent_size_update`` and
                ``inference_hyper["timestep_shift"]``; ``infer_steps`` is
                optional and defaults to 50.
        """
        super().__init__(config)
        self.config = config

        # Stored as a plain function on the instance, so it is called without `self`.
        self.get_flattened_position_ids = (
            get_flattened_position_ids_interpolate if self.config.interpolate_pos else get_flattened_position_ids_extrapolate
        )

        self.latent_patch_size = config.latent_patch_size
        # Image size covered by one latent patch: VAE downsample factor times the patch size.
        self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
        self.max_latent_size = config["max_latent_size_update"]
        self.latent_channel = config.vae_config.z_channels
        self.infer_steps = self.config.get("infer_steps", 50)
        self.timestep_shift = config["inference_hyper"]["timestep_shift"]

        self.cache_dic = None
        self.current = None
        self.prepare()

    def set_timesteps(self):
        """Build the shifted timestep schedule and the per-step sizes.

        ``infer_steps`` points are spaced linearly from 1 to 0 on the CPU and
        warped by ``shift * t / (1 + (shift - 1) * t)``. ``self.dts`` holds the
        differences between consecutive points and ``self.timesteps`` drops the
        final point, so both have ``infer_steps - 1`` entries.
        """
        shift = self.timestep_shift
        linear = torch.linspace(1, 0, self.infer_steps, device="cpu")
        shifted = shift * linear / (1 + (shift - 1) * linear)
        self.dts = shifted[:-1] - shifted[1:]
        self.timesteps = shifted[:-1]

    def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
        """Build the packed tensors needed to generate one image per sample.

        For every sample the packed sequence is laid out as: the sample's
        existing key/value slots, a start-of-image token, ``h * w`` latent
        tokens and an end-of-image token, where ``h`` and ``w`` are the image
        sides floor-divided by ``self.latent_downsample``.

        Args:
            curr_kvlens: Per-sample count of already cached key/value entries.
            curr_rope: Per-sample position id assigned to all new tokens of
                that sample.
            image_sizes: Per-sample ``(H, W)`` image size.
            new_token_ids: Mapping with ``"start_of_image"`` and
                ``"end_of_image"`` token ids.

        Returns:
            Dict of tensors: packed text ids and their query indexes, initial
            noise, VAE position ids and their query indexes, sequence lengths,
            position ids, key/value lengths, and the packed indexes of the new
            tokens and of the cached key/value entries.
        """
        text_ids, text_query_slots = [], []
        vae_position_ids, vae_query_slots, init_noises = [], [], []
        position_ids, seqlens, new_token_slots = [], [], []
        kv_slots = []

        noise_width = self.latent_channel * self.latent_patch_size**2
        # `packed_cursor` walks the full packed sequence (cached + new tokens);
        # `query_cursor` walks only the newly added tokens.
        packed_cursor = 0
        query_cursor = 0
        for (H, W), kv_len, rope_position in zip(image_sizes, curr_kvlens, curr_rope):
            kv_slots.extend(range(packed_cursor, packed_cursor + kv_len))
            packed_cursor += kv_len

            # Start-of-image marker.
            text_ids.append(new_token_ids["start_of_image"])
            text_query_slots.append(query_cursor)
            new_token_slots.append(packed_cursor)
            packed_cursor += 1
            query_cursor += 1

            vae_position_ids.append(self.get_flattened_position_ids(H, W, self.latent_downsample, max_num_patches_per_side=self.max_latent_size))

            num_image_tokens = (H // self.latent_downsample) * (W // self.latent_downsample)
            # NOTE(review): the global RNGs are re-seeded to 42 before every draw, so each
            # image's noise starts from the same RNG state; `self.generator` from prepare()
            # is not used here. Confirm this is intentional.
            set_seed()
            init_noises.append(torch.randn(num_image_tokens, noise_width))
            vae_query_slots.extend(range(query_cursor, query_cursor + num_image_tokens))
            new_token_slots.extend(range(packed_cursor, packed_cursor + num_image_tokens))
            packed_cursor += num_image_tokens
            query_cursor += num_image_tokens

            # End-of-image marker.
            text_ids.append(new_token_ids["end_of_image"])
            text_query_slots.append(query_cursor)
            new_token_slots.append(packed_cursor)
            packed_cursor += 1
            query_cursor += 1

            # Every new token of this sample (two markers + latents) shares one position id.
            sample_len = num_image_tokens + 2
            position_ids.extend([rope_position] * sample_len)
            seqlens.append(sample_len)

        return {
            "packed_text_ids": torch.tensor(text_ids, dtype=torch.long),
            "packed_text_indexes": torch.tensor(text_query_slots, dtype=torch.long),
            "packed_init_noises": torch.cat(init_noises, dim=0),
            "packed_vae_position_ids": torch.cat(vae_position_ids, dim=0),
            "packed_vae_token_indexes": torch.tensor(vae_query_slots, dtype=torch.long),
            "packed_seqlens": torch.tensor(seqlens, dtype=torch.int),
            "packed_position_ids": torch.tensor(position_ids, dtype=torch.long),
            "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "packed_indexes": torch.tensor(new_token_slots, dtype=torch.long),
            "packed_key_value_indexes": torch.tensor(kv_slots, dtype=torch.long),
        }

    def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
        """Build the packed index tensors for a classifier-free-guidance branch.

        Uses the same packed layout as ``prepare_vae_latent`` (cached
        key/value slots followed by ``h * w + 2`` new tokens per sample) but
        only emits position ids and index tensors, under ``cfg_``-prefixed keys.

        Args:
            curr_kvlens: Per-sample count of already cached key/value entries.
            curr_rope: Per-sample position id assigned to all new tokens of
                that sample.
            image_sizes: Per-sample ``(H, W)`` image size.

        Returns:
            Dict with ``cfg_packed_position_ids``, ``cfg_key_values_lens``,
            ``cfg_packed_query_indexes`` and ``cfg_packed_key_value_indexes``.
        """
        position_ids, query_slots, kv_slots = [], [], []

        cursor = 0
        for (H, W), kv_len, rope_position in zip(image_sizes, curr_kvlens, curr_rope):
            kv_slots.extend(range(cursor, cursor + kv_len))
            cursor += kv_len

            num_image_tokens = (H // self.latent_downsample) * (W // self.latent_downsample)
            # Start marker, latent tokens and end marker occupy consecutive packed slots.
            sample_len = num_image_tokens + 2
            query_slots.extend(range(cursor, cursor + sample_len))
            cursor += sample_len

            position_ids.extend([rope_position] * sample_len)

        return {
            "cfg_packed_position_ids": torch.tensor(position_ids, dtype=torch.long),
            "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
            "cfg_packed_query_indexes": torch.tensor(query_slots, dtype=torch.long),
            "cfg_packed_key_value_indexes": torch.tensor(kv_slots, dtype=torch.long),
        }

    def prepare(self):
        """Reset per-run state: generator, timestep schedule and cached tensors."""
        # NOTE(review): this generator is not consumed anywhere in the code visible here.
        self.generator = torch.Generator().manual_seed(42)
        self.set_timesteps()

        self.generation_input = None
        self.generation_input_cfg_text = None
        self.generation_input_cfg_image = None

        self.latents = None
        self.noise_pred = None

    def step_pre(self, step_index):
        """Record the index of the step about to run.

        Args:
            step_index: Index into ``self.dts`` used by ``step_post``.
        """
        self.step_index = step_index

    def step_post(self):
        """Update the latents by subtracting ``noise_pred`` scaled by the current step size.

        ``noise_pred`` is moved to the latents' device before the update.
        """
        prediction = self.noise_pred.to(self.latents.device)
        step_size = self.dts[self.step_index]
        self.latents = self.latents - prediction * step_size