import functools import inspect import json import math import os from math import prod from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.distributed as dist from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from torch import nn from torch.nn import functional as F from lightx2v.models.schedulers.scheduler import BaseScheduler from lightx2v_platform.base.global_var import AI_DEVICE, PLATFORM try: from sgl_kernel.elementwise import timestep_embedding as timestep_embedding_cuda TIMESTEP_EMBEDDING_CUDA_AVAILABLE = PLATFORM == "cuda" except ImportError: TIMESTEP_EMBEDDING_CUDA_AVAILABLE = False def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError(f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.") scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps def randn_tensor( shape: Union[Tuple, List], generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, device: Optional[Union[str, "torch.device"]] = None, dtype: Optional["torch.dtype"] = None, layout: Optional["torch.layout"] = None, ): """A helper function to create random tensors on the desired `device` with the desired `dtype`. When passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor is always created on the CPU. """ # device on which tensor is created defaults to device if isinstance(device, str): device = torch.device(device) rand_device = device batch_size = shape[0] layout = layout or torch.strided device = device or torch.device("cpu") if generator is not None: gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" if device != "mps": print( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" f" slightly speed up this function by passing a generator that was created on the {device} device." ) elif gen_device_type != device.type and gen_device_type == "cuda": raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") # make sure generator list of length 1 is treated like a non-list if isinstance(generator, list) and len(generator) == 1: generator = generator[0] if isinstance(generator, list): shape = (1,) + shape[1:] latents = [torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) for i in range(batch_size)] latents = torch.cat(latents, dim=0).to(device) else: latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents def get_timestep_embedding( timesteps: torch.Tensor, embedding_dim: int = 256, flip_sin_to_cos: bool = True, downscale_freq_shift: float = 0, scale: float = 1000, max_period: int = 10000, ) -> torch.Tensor: """ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. Args timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False) downscale_freq_shift (float): Controls the delta between frequencies between dimensions scale (float): Scaling factor applied to the embeddings. max_period (int): Controls the maximum frequency of the embeddings Returns torch.Tensor: an [N x dim] Tensor of positional embeddings. """ if TIMESTEP_EMBEDDING_CUDA_AVAILABLE: return timestep_embedding_cuda( timesteps, embedding_dim, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift, scale=scale, max_period=max_period, ) assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb # concat sine and cosine embeddings emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) # flip sine and cosine embeddings if flip_sin_to_cos: emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb class QwenEmbedRope(nn.Module): def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim pos_index = torch.arange(4096) neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), self.rope_params(pos_index, self.axes_dim[2], self.theta), ], dim=1, ) self.neg_freqs = torch.cat( [ self.rope_params(neg_index, self.axes_dim[0], self.theta), self.rope_params(neg_index, self.axes_dim[1], self.theta), self.rope_params(neg_index, self.axes_dim[2], self.theta), ], dim=1, ) self.rope_cache = {} # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): """ Args: index: [0, 1, 2, 3] 1D Tensor representing the position index of the token """ assert dim % 2 == 0 freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs def forward(self, video_fhw, txt_seq_lens, device): """ Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: txt_length: [bs] a list of 1 integers representing the length of the text """ if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] if not isinstance(video_fhw, list): video_fhw = [video_fhw] vid_freqs = [] max_vid_index = 0 for idx, fhw in enumerate(video_fhw): frame, height, width = fhw rope_key = f"{idx}_{height}_{width}" if not torch.compiler.is_compiling(): if rope_key not in self.rope_cache: self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) video_freq = self.rope_cache[rope_key] else: video_freq = self._compute_video_freqs(frame, height, width, idx) video_freq = video_freq.to(device) vid_freqs.append(video_freq) if self.scale_rope: max_vid_index = max(height // 2, width // 2, max_vid_index) else: max_vid_index = max(height, width, max_vid_index) max_len = txt_seq_lens txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return [vid_freqs, txt_freqs] @functools.lru_cache(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) else: freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() class QwenEmbedLayer3DRope(nn.Module): def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): super().__init__() self.theta = theta self.axes_dim = axes_dim pos_index = torch.arange(4096) neg_index = torch.arange(4096).flip(0) * -1 - 1 self.pos_freqs = torch.cat( [ self.rope_params(pos_index, self.axes_dim[0], self.theta), self.rope_params(pos_index, self.axes_dim[1], self.theta), self.rope_params(pos_index, self.axes_dim[2], self.theta), ], dim=1, ) self.neg_freqs = torch.cat( [ self.rope_params(neg_index, self.axes_dim[0], self.theta), self.rope_params(neg_index, self.axes_dim[1], self.theta), self.rope_params(neg_index, self.axes_dim[2], self.theta), ], dim=1, ) self.scale_rope = scale_rope def rope_params(self, index, dim, theta=10000): """ Args: index: [0, 1, 2, 3] 1D Tensor representing the position index of the token """ assert dim % 2 == 0 freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))) freqs = torch.polar(torch.ones_like(freqs), freqs) return freqs def forward(self, video_fhw, txt_seq_lens, device): """ Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: txt_length: [bs] a list of 1 integers representing the length of the text """ if self.pos_freqs.device != device: self.pos_freqs = self.pos_freqs.to(device) self.neg_freqs = self.neg_freqs.to(device) if isinstance(video_fhw, list): video_fhw = video_fhw[0] if not isinstance(video_fhw, list): video_fhw = [video_fhw] vid_freqs = [] max_vid_index = 0 layer_num = len(video_fhw) - 1 for idx, fhw in enumerate(video_fhw): frame, height, width = fhw if idx != layer_num: video_freq = self._compute_video_freqs(frame, height, width, idx) else: ### For the condition image, we set the layer index to -1 video_freq = self._compute_condition_freqs(frame, height, width) video_freq = video_freq.to(device) vid_freqs.append(video_freq) if self.scale_rope: max_vid_index = max(height // 2, width // 2, max_vid_index) else: max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) max_len = txt_seq_lens txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @functools.lru_cache(maxsize=None) def _compute_video_freqs(self, frame, height, width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) else: freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() @functools.lru_cache(maxsize=None) def _compute_condition_freqs(self, frame, height, width): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) freqs_frame = freqs_neg[0][-1:].view(frame, 1, 1, -1).expand(frame, height, width, -1) if self.scale_rope: freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0) freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) else: freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1) return freqs.clone().contiguous() class QwenImageScheduler(BaseScheduler): def __init__(self, config): super().__init__(config) self.config = config self.is_layered = config.get("layered", False) if self.is_layered: self.layers = config.get("layers", 4) scheduler_path = config.get("scheduler_path", os.path.join(config["model_path"], "scheduler")) self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(scheduler_path) with open(os.path.join(config["model_path"], "scheduler", "scheduler_config.json"), "r") as f: self.scheduler_config = json.load(f) self.dtype = torch.bfloat16 self.sample_guide_scale = self.config["sample_guide_scale"] self.zero_cond_t = config.get("zero_cond_t", False) if self.config["seq_parallel"]: self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") else: self.seq_p_group = None self.use_layer3d_rope = config.get("use_layer3d_rope", False) if self.use_layer3d_rope: self.pos_embed = QwenEmbedLayer3DRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True) else: self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True) @staticmethod def _pack_latents(latents, batchsize, num_channels_latents, height, width, layers=None): if not layers: latents = latents.view(batchsize, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batchsize, (height // 2) * (width // 2), num_channels_latents * 4) else: latents = latents.view(batchsize, layers, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 1, 3, 5, 2, 4, 6) latents = latents.reshape(batchsize, layers * (height // 2) * (width // 2), num_channels_latents * 4) return latents @staticmethod def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. height = 2 * (int(height) // (vae_scale_factor * 2)) width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) return latents @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height, width, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_ids = latent_image_ids.reshape(latent_image_id_height * latent_image_id_width, latent_image_id_channels) return latent_image_ids.to(device=device, dtype=dtype) def prepare_latents(self, input_info): self.input_info = input_info shape = input_info.target_shape width, height = shape[-1], shape[-2] latents = randn_tensor(shape, generator=self.generator, device=AI_DEVICE, dtype=self.dtype) if self.is_layered: latents = self._pack_latents(latents, 1, self.config.get("num_channels_latents", 16), height, width, self.layers + 1) else: latents = self._pack_latents(latents, 1, self.config.get("num_channels_latents", 16), height, width) latent_image_ids = self._prepare_latent_image_ids(1, height // 2, width // 2, AI_DEVICE, self.dtype) self.latents = latents self.latent_image_ids = latent_image_ids self.noise_pred = None def set_timesteps(self): sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"]) image_seq_len = self.latents.shape[1] if self.is_layered: base_seqlen = 256 * 256 / 16 / 16 image_seq_len = self.latents.shape[1] // 5 mu = (image_seq_len / base_seqlen) ** 0.5 else: mu = calculate_shift( image_seq_len, self.scheduler_config.get("base_image_seq_len", 256), self.scheduler_config.get("max_image_seq_len", 4096), self.scheduler_config.get("base_shift", 0.5), self.scheduler_config.get("max_shift", 1.15), ) num_inference_steps = self.config["infer_steps"] timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, num_inference_steps, AI_DEVICE, sigmas=sigmas, mu=mu, ) self.timesteps = timesteps self.infer_steps = num_inference_steps num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) self.num_warmup_steps = num_warmup_steps def prepare(self, input_info): if self.config["task"] == "i2i": self.generator = torch.Generator().manual_seed(input_info.seed) elif self.config["task"] == "t2i": self.generator = torch.Generator(device=AI_DEVICE).manual_seed(input_info.seed) self.prepare_latents(input_info) self.set_timesteps() self.image_rotary_emb = self.pos_embed(self.input_info.image_shapes, input_info.txt_seq_lens[0], device=AI_DEVICE) if self.config.get("rope_type", "flashinfer") == "flashinfer": cos_half_img = self.image_rotary_emb[0].real.contiguous() sin_half_img = self.image_rotary_emb[0].imag.contiguous() cos_half_txt = self.image_rotary_emb[1].real.contiguous() sin_half_txt = self.image_rotary_emb[1].imag.contiguous() self.image_rotary_emb[0] = torch.cat([cos_half_img, sin_half_img], dim=-1) self.image_rotary_emb[1] = torch.cat([cos_half_txt, sin_half_txt], dim=-1) if self.seq_p_group is not None: world_size = dist.get_world_size(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group) seqlen = self.image_rotary_emb[0].shape[0] padding_size = (world_size - (seqlen % world_size)) % world_size if padding_size > 0: self.image_rotary_emb[0] = F.pad(self.image_rotary_emb[0], (0, 0, 0, padding_size)) self.image_rotary_emb[0] = torch.chunk(self.image_rotary_emb[0], world_size, dim=0)[cur_rank] if self.config["enable_cfg"]: self.negative_image_rotary_emb = self.pos_embed(self.input_info.image_shapes, input_info.txt_seq_lens[1], device=AI_DEVICE) if self.config.get("rope_type", "flashinfer") == "flashinfer": cos_half_img = self.negative_image_rotary_emb[0].real.contiguous() sin_half_img = self.negative_image_rotary_emb[0].imag.contiguous() cos_half_txt = self.negative_image_rotary_emb[1].real.contiguous() sin_half_txt = self.negative_image_rotary_emb[1].imag.contiguous() self.negative_image_rotary_emb[0] = torch.cat([cos_half_img, sin_half_img], dim=-1) self.negative_image_rotary_emb[1] = torch.cat([cos_half_txt, sin_half_txt], dim=-1) if self.seq_p_group is not None: world_size = dist.get_world_size(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group) seqlen = self.negative_image_rotary_emb[0].shape[0] padding_size = (world_size - (seqlen % world_size)) % world_size if padding_size > 0: self.negative_image_rotary_emb[0] = F.pad(self.negative_image_rotary_emb[0], (0, 0, 0, padding_size)) self.negative_image_rotary_emb[0] = torch.chunk(self.negative_image_rotary_emb[0], world_size, dim=0)[cur_rank] if self.zero_cond_t: self.modulate_index = torch.tensor([[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in self.input_info.image_shapes], device=AI_DEVICE, dtype=torch.int) if self.seq_p_group is not None: world_size = dist.get_world_size(self.seq_p_group) cur_rank = dist.get_rank(self.seq_p_group) seqlen = self.modulate_index.shape[1] padding_size = (world_size - (seqlen % world_size)) % world_size if padding_size > 0: self.modulate_index = F.pad(self.modulate_index, (0, padding_size)) self.modulate_index = torch.chunk(self.modulate_index, world_size, dim=1)[cur_rank] else: self.modulate_index = None def step_pre(self, step_index): super().step_pre(step_index) timestep_input = torch.tensor([self.timesteps[self.step_index]], device=AI_DEVICE, dtype=self.dtype) / 1000 if self.zero_cond_t: timestep_input = torch.cat([timestep_input, timestep_input * 0], dim=0) self.timesteps_proj = get_timestep_embedding(timestep_input).to(torch.bfloat16) def step_post(self): # compute the previous noisy sample x_t -> x_t-1 t = self.timesteps[self.step_index] latents = self.scheduler.step(self.noise_pred, t, self.latents, return_dict=False)[0] self.latents = latents