import functools
import inspect
import json
import math
import os
from math import prod
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from torch import nn
from torch.nn import functional as F

from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM

try:
    from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda

    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = PLATFORM == "cuda"
except ImportError:
    TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False


def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image sequence length to the timestep-shift parameter ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``;
    lengths outside that range are extrapolated, not clamped.
    """
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    mu = image_seq_len * m + b
    return mu


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
    """Retrieve latents from VAE encoder output.

    Samples from ``encoder_output.latent_dist`` ("sample"), takes its mode ("argmax"),
    or falls back to ``encoder_output.latents``.

    Raises:
        AttributeError: if none of those attributes is available for the requested mode.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
    device: Optional[Union[str, "torch.device"]] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`.

    When passing a list of generators, you can seed each batch size individually. If CPU generators
    are passed, the tensor is always created on the CPU and then moved to `device`.

    Raises:
        ValueError: if a CUDA generator is used to produce a tensor for a different device type.
    """
    if isinstance(device, str):
        device = torch.device(device)
    batch_size = shape[0]
    layout = layout or torch.strided
    device = device or torch.device("cpu")
    # The tensor is sampled on the target device unless a CPU generator forces CPU sampling below.
    rand_device = device

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # `device` is a torch.device here; a torch.device never compares equal to a plain
            # string, so compare its `.type` to actually skip the notice on MPS.
            if device.type != "mps":
                print(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        # tuple(...) so that a list-typed `shape` also works here.
        shape = (1,) + tuple(shape[1:])
        latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int = 256,
    flip_sin_to_cos: bool = True,
    downscale_freq_shift: float = 0,
    scale: float = 1000,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Uses the `sgl_kernel` CUDA kernel when it is importable and the platform is CUDA, otherwise the
    PyTorch implementation below.

    Args:
        timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int): the dimension of the output.
        flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float): Controls the delta between frequencies between dimensions.
        scale (float): Scaling factor applied to the embeddings.
        max_period (int): Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    if TIMESTEP_EMBEDDING_CUDA_AVAILABLE:
        return timestep_embedding_cuda(
            timesteps,
            embedding_dim,
            flip_sin_to_cos=flip_sin_to_cos,
            downscale_freq_shift=downscale_freq_shift,
            scale=scale,
            max_period=max_period,
        )
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class ZEmbedRope(nn.Module):
    """Rotary position embedding over (frame, height, width) axes using complex frequency tables.

    Tables are precomputed for positions ``0..4095`` (``pos_freqs``) and ``-4096..-1``
    (``neg_freqs``); axis ``i`` contributes ``axes_dim[i] // 2`` complex columns.
    """

    def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat(
            [
                self.rope_params(pos_index, self.axes_dim[0], self.theta),
                self.rope_params(pos_index, self.axes_dim[1], self.theta),
                self.rope_params(pos_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        self.neg_freqs = torch.cat(
            [
                self.rope_params(neg_index, self.axes_dim[0], self.theta),
                self.rope_params(neg_index, self.axes_dim[1], self.theta),
                self.rope_params(neg_index, self.axes_dim[2], self.theta),
            ],
            dim=1,
        )
        # Per-shape cache of video frequency tables, keyed by (idx, frame, height, width).
        self.rope_cache = {}

        # DO NOT use register_buffer for the frequency tables: it causes the complex numbers to lose their imaginary part.
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """Return complex rotary factors ``exp(i * index * theta^(-2k/dim))`` for ``k = 0 .. dim/2 - 1``.

        Args:
            index: 1-D tensor of token position indices, e.g. ``[0, 1, 2, 3]``.
            dim: rotary dimension for this axis; must be even.
            theta: frequency base.

        Returns:
            Complex tensor of shape ``[len(index), dim // 2]``.
        """
        assert dim % 2 == 0
        freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
        freqs = torch.polar(torch.ones_like(freqs), freqs)
        return freqs

    def forward(self, video_fhw, txt_seq_lens, device):
        """Build rotary tables for the image tokens and the text tokens that follow them.

        Args:
            video_fhw: ``(frame, height, width)`` or a list of them. If a list is given, only its
                first element is used (which may itself be a list of shapes).
            txt_seq_lens: number of text tokens (used as a slice length).
            device: device to place the tables on.

        Returns:
            ``[vid_freqs, txt_freqs]``: image tables concatenated over all shapes, and text tables
            taken from ``pos_freqs`` starting right after the largest image position index.
        """
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        if isinstance(video_fhw, list):
            video_fhw = video_fhw[0]
        if not isinstance(video_fhw, list):
            video_fhw = [video_fhw]

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            # `frame` must be part of the key: the table's size and content depend on it.
            rope_key = (idx, frame, height, width)
            if not torch.compiler.is_compiling():
                if rope_key not in self.rope_cache:
                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
                video_freq = self.rope_cache[rope_key]
            else:
                # Avoid mutating Python state while being traced by torch.compile.
                video_freq = self._compute_video_freqs(frame, height, width, idx)
            video_freq = video_freq.to(device)
            vid_freqs.append(video_freq)

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = txt_seq_lens
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return [vid_freqs, txt_freqs]

    def _compute_video_freqs(self, frame, height, width, idx=0):
        """Return the ``[frame * height * width, sum(axes_dim) // 2]`` complex table for one image shape.

        The frame axis uses positions ``idx .. idx + frame - 1``. With ``scale_rope`` the height and
        width axes are centred (negative positions for the first half, non-negative for the rest);
        otherwise they use ``0 .. n - 1``. Caching is done by ``forward`` via ``rope_cache``.
        """
        seq_lens = frame * height * width
        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
        if self.scale_rope:
            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
        else:
            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
        return freqs.clone().contiguous()


class RopeEmbedder:
    """Look up complex rotary factors per axis from integer position ids and concatenate them."""

    def __init__(
        self,
        theta: float = 256.0,
        axes_dims: List[int] = (32, 48, 48),
        axes_lens: List[int] = (1024, 512, 512),
    ):
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
        # Lazily built on the first call and moved to the ids' device.
        self.freqs_cis = None

    @staticmethod
    def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
        """Return one complex64 table of shape ``[e, d // 2]`` per ``(d, e)`` in ``zip(dim, end)``.

        Angles are computed in float64 on the CPU, then cast to float32 before ``torch.polar``.
        """
        with torch.device("cpu"):
            freqs_cis = []
            for d, e in zip(dim, end):
                # Compute base frequencies: [1/theta^0, 1/theta^(2/d), 1/theta^(4/d), ...]
                freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
                # Compute timestep positions: [0, 1, 2, ..., e-1]
                timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
                # Outer product: [e, d//2]
                freqs = torch.outer(timestep, freqs).float()
                # Convert to complex: polar(1, angle) = e^(i*angle)
                freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)
                freqs_cis.append(freqs_cis_i)

            return freqs_cis

    def __call__(self, ids: torch.Tensor):
        """Map ``ids`` of shape ``[N, num_axes]`` to ``[N, sum(axes_dims) // 2]`` complex factors."""
        assert ids.ndim == 2
        assert ids.shape[-1] == len(self.axes_dims)
        device = ids.device

        if self.freqs_cis is None:
            self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
        elif self.freqs_cis[0].device != device:
            self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

        result = [self.freqs_cis[i][ids[:, i]] for i in range(len(self.axes_dims))]
        return torch.cat(result, dim=-1)


class ZImageScheduler(BaseScheduler):
    """Flow-matching Euler scheduler wrapper for Z-Image inference.

    Handles latent initialisation, the sigma/timestep schedule (including i2i ``strength``
    truncation), rotary embeddings for the transformer (optionally sharded for sequence
    parallelism), and the per-step timestep embedding and latent update.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        scheduler_dir = os.path.join(config["model_path"], "scheduler")
        self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_dir)
        with open(os.path.join(scheduler_dir, "scheduler_config.json"), "r") as f:
            self.scheduler_config = json.load(f)
        self.dtype = torch.bfloat16
        self.sample_guide_scale = self.config["sample_guide_scale"]
        self.zero_cond_t = config.get("zero_cond_t", False)
        if self.config.get("seq_parallel", False):
            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
        else:
            self.seq_p_group = None
        self.pos_embed = ZEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)

        # Initialize RopeEmbedder for generating freqs_cis from position IDs (used in pre_infer)
        rope_theta = config.get("rope_theta", 256.0)
        axes_dims = config.get("axes_dims", [32, 48, 48])
        axes_lens = config.get("axes_lens", [1024, 512, 512])
        self.rope_embedder = RopeEmbedder(
            theta=rope_theta,
            axes_dims=axes_dims,
            axes_lens=axes_lens,
        )

    @staticmethod
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        """Pack ``[B, C, H, W]`` latents into ``[B, (H/2)*(W/2), 4*C]`` tokens of 2x2 patches."""
        latents = latents.view(batch_size, num_channels_latents, 1, 1, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 6, 3, 5, 7, 1)
        latents = latents.reshape(batch_size, 1 * (height // 2) * (width // 2), 1 * 2 * 2 * num_channels_latents)
        return latents

    @staticmethod
    def _unpack_latents(latents, height, width, vae_scale_factor):
        """Inverse of packing: ``[B, N, C]`` tokens back to ``[B, C // 4, 1, h, w]`` latents."""
        batch_size, num_patches, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (vae_scale_factor * 2))
        width = 2 * (int(width) // (vae_scale_factor * 2))

        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)

        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)

        return latents

    @staticmethod
    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
        """Return ``[height * width, 3]`` ids: channel 0 is zero, channel 1 the row, channel 2 the column.

        ``batch_size`` is accepted for signature compatibility and not used.
        """
        latent_image_ids = torch.zeros(height, width, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels)

        return latent_image_ids.to(device=device, dtype=dtype)

    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Create an int32 coordinate grid of shape ``[*size, len(size)]`` starting at ``start`` (default 0)."""
        if start is None:
            start = (0 for _ in size)

        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

    def prepare_latents(self, input_info):
        """Sample initial noise latents of shape ``input_info.target_shape`` (``[B, C, H, W]``).

        Also stores ``input_info`` (read later by ``set_timesteps`` and ``prepare``), builds
        position ids for an ``(H // 2) x (W // 2)`` grid, and resets ``noise_pred``.

        Raises:
            ValueError: if ``target_shape`` is not 4-D.
        """
        self.input_info = input_info
        shape = input_info.target_shape
        if len(shape) != 4:
            raise ValueError(f"target_shape must be 4D [B, C, H, W], got {len(shape)}D: {shape}")
        _, _, height, width = shape
        latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
        latent_image_ids = self._prepare_latent_image_ids(1, height // 2, width // 2, AI_DEVICE, self.dtype)

        self.latents = latents
        self.latent_image_ids = latent_image_ids
        self.noise_pred = None

    def generate_freqs_cis_from_position_ids(self, position_ids: torch.Tensor, device: torch.device = None) -> torch.Tensor:
        """Rotary factors for ``position_ids``.

        With ``rope_type == "flashinfer"`` (the default), the result is the float tensor
        ``cat([real, imag], -1)``. Otherwise the complex tensor is returned unchanged.
        """
        if device is None:
            device = position_ids.device
        freqs_cis = self.rope_embedder(position_ids.to(device))
        rope_type = self.config.get("rope_type", "flashinfer")
        if rope_type == "flashinfer":
            freqs_cis = torch.cat([freqs_cis.real, freqs_cis.imag], dim=-1).float()
        return freqs_cis

    def get_timesteps(self, num_inference_steps, strength, device):
        """Drop the first ``(1 - strength)`` fraction of timesteps (img2img).

        Returns:
            The truncated timesteps and the remaining number of steps.
        """
        init_timestep = min(num_inference_steps * strength, num_inference_steps)

        t_start = int(max(num_inference_steps - init_timestep, 0))
        timesteps = self.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start

    def set_timesteps(self):
        """Configure the scheduler's sigma schedule and store ``timesteps``/``infer_steps``.

        For the ``i2i`` task with a ``strength`` attribute on ``input_info``, the schedule is
        truncated by ``get_timesteps``.

        Raises:
            ValueError: if ``strength`` is outside ``[0, 1]`` or leaves fewer than one step.
        """
        infer_steps = self.config["infer_steps"]
        sigmas = np.linspace(1.0, 1 / infer_steps, infer_steps)
        # Latents are unpacked [B, C, H, W] (enforced in prepare_latents), so shape[1] is the
        # channel count. The token sequence length matches the (H // 2) x (W // 2) grid used for
        # latent_image_ids.
        image_seq_len = (self.latents.shape[2] // 2) * (self.latents.shape[3] // 2)
        mu = calculate_shift(
            image_seq_len,
            self.scheduler_config.get("base_image_seq_len", 256),
            self.scheduler_config.get("max_image_seq_len", 4096),
            self.scheduler_config.get("base_shift", 0.5),
            self.scheduler_config.get("max_shift", 1.15),
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            infer_steps,
            AI_DEVICE,
            sigmas=sigmas,
            mu=mu,
        )

        self.timesteps = timesteps
        self.infer_steps = num_inference_steps

        # Adjust timesteps based on strength for i2i task
        if self.config["task"] == "i2i" and hasattr(self.input_info, "strength"):
            strength = self.input_info.strength
            if strength < 0.0 or strength > 1.0:
                raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

            timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, AI_DEVICE)

            if num_inference_steps < 1:
                raise ValueError(
                    f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
                    f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
                )

            self.timesteps = timesteps
            self.infer_steps = num_inference_steps

        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        self._num_timesteps = len(timesteps)
        self.num_warmup_steps = num_warmup_steps

    def prepare(self, input_info):
        """Prepare all per-request state: RNG, latents, timesteps, rotary embeddings, modulation index."""
        self.generator = torch.Generator(device=AI_DEVICE).manual_seed(input_info.seed)
        self.prepare_latents(input_info)
        self.set_timesteps()

        if self.config["task"] == "i2i" and getattr(input_info, "image_encoder_output", None) is not None:
            self._apply_i2i_noise(input_info)

        self.image_rotary_emb = self._build_rotary_emb(input_info.txt_seq_lens[0])
        if self.config["enable_cfg"]:
            self.negative_image_rotary_emb = self._build_rotary_emb(input_info.txt_seq_lens[1])

        self.modulate_index = self._build_modulate_index() if self.zero_cond_t else None

    def _apply_i2i_noise(self, input_info):
        """Replace ``self.latents`` with the encoded input image(s) noised to the first timestep.

        Does nothing when ``strength`` (default 0.6) is not positive or there are no image latents.

        Raises:
            ValueError: if the image batch cannot be tiled evenly to the latent batch size.
        """
        strength = getattr(input_info, "strength", 0.6)
        if not strength > 0.0:
            return
        image_latents_list = [item["image_latents"] for item in input_info.image_encoder_output]
        if not image_latents_list:
            return
        image_latents = torch.cat(image_latents_list, dim=0) if len(image_latents_list) > 1 else image_latents_list[0]

        batch_size = self.latents.shape[0]
        if batch_size > image_latents.shape[0]:
            if batch_size % image_latents.shape[0] != 0:
                raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.")
            additional_image_per_prompt = batch_size // image_latents.shape[0]
            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)

        _, _, height, width = self.latents.shape
        if image_latents.shape[2:] != (height, width):
            image_latents = F.interpolate(image_latents, size=(height, width), mode="bilinear", align_corners=False)

        latent_timestep = self.timesteps[:1].repeat(self.latents.shape[0])
        noise = randn_tensor(self.latents.shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype)
        if image_latents.shape[1] != self.latents.shape[1]:
            # Tile image-latent channels up to the latent channel count (e.g. 64 // 16 = 4).
            repeat_factor = self.latents.shape[1] // image_latents.shape[1]
            image_latents = image_latents.repeat(1, repeat_factor, 1, 1)
        self.latents = self.scheduler.scale_noise(image_latents, latent_timestep, noise)

    def _build_rotary_emb(self, txt_seq_len):
        """Return ``[image_emb, text_emb]`` rotary tables for ``self.input_info.image_shapes``.

        With ``rope_type == "flashinfer"`` (the default), each complex table becomes the real tensor
        ``cat([real, imag], -1)``. Under sequence parallelism, the image table is padded and sharded
        along the token dimension.
        """
        rotary_emb = self.pos_embed(self.input_info.image_shapes, txt_seq_len, device=AI_DEVICE)
        if self.config.get("rope_type", "flashinfer") == "flashinfer":
            rotary_emb = [torch.cat([freqs.real.contiguous(), freqs.imag.contiguous()], dim=-1) for freqs in rotary_emb]
        if self.seq_p_group is not None:
            rotary_emb[0] = self._shard_for_seq_parallel(rotary_emb[0], dim=0)
        return rotary_emb

    def _build_modulate_index(self):
        """Per-token index for each sample: 0 for tokens of its first image shape, 1 for all later shapes."""
        modulate_index = torch.tensor(
            [[0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:]) for sample in self.input_info.image_shapes],
            device=AI_DEVICE,
            dtype=torch.int,
        )
        if self.seq_p_group is not None:
            modulate_index = self._shard_for_seq_parallel(modulate_index, dim=1)
        return modulate_index

    def _shard_for_seq_parallel(self, tensor, dim):
        """Zero-pad a 2-D ``tensor`` along ``dim`` to a multiple of the seq-parallel world size and return this rank's chunk."""
        world_size = dist.get_world_size(self.seq_p_group)
        cur_rank = dist.get_rank(self.seq_p_group)
        seqlen = tensor.shape[dim]
        padding_size = (world_size - (seqlen % world_size)) % world_size
        if padding_size > 0:
            # F.pad lists (left, right) pairs starting from the LAST dimension.
            pad = (0, 0) * (tensor.dim() - 1 - dim) + (0, padding_size)
            tensor = F.pad(tensor, pad)
        return torch.chunk(tensor, world_size, dim=dim)[cur_rank]

    def step_pre(self, step_index):
        """Compute the bf16 sinusoidal embedding of ``1000 - t`` for the current step.

        With ``zero_cond_t``, a second row embedding timestep 0 is appended.
        """
        super().step_pre(step_index)
        timestep_value = self.timesteps[self.step_index].item()
        timestep_input = torch.tensor([1000.0 - timestep_value], device=AI_DEVICE, dtype=torch.float32)
        if self.zero_cond_t:
            timestep_input = torch.cat([timestep_input, timestep_input * 0], dim=0)
        timesteps_proj_float32 = get_timestep_embedding(timestep_input, scale=1.0)
        self.timesteps_proj = timesteps_proj_float32.to(torch.bfloat16)

    def step_post(self):
        """Advance ``self.latents`` one Euler step using the negated, float32 ``self.noise_pred``.

        NOTE(review): the negation presumably adapts the model's output sign convention to the
        scheduler; confirm against the transformer.
        """
        noise_pred = -self.noise_pred
        noise_pred = noise_pred.to(torch.float32)
        latents = self.latents
        t = self.timesteps[self.step_index]
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        self.latents = latents