import math import torch import comfy.ldm.common_dit import comfy.model_management as mm from torch import Tensor from einops import repeat from typing import Optional from unittest.mock import patch from comfy.ldm.flux.layers import timestep_embedding, apply_mod from comfy.ldm.lightricks.model import precompute_freqs_cis from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords from comfy.ldm.wan.model import sinusoidal_embedding_1d SUPPORTED_MODELS_COEFFICIENTS = { "flux": [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01], "flux-kontext": [-1.04655119e+03, 3.12563399e+02, -1.69500694e+01, 4.10995971e-01, 3.74537863e-02], "ltxv": [2.14700694e+01, -1.28016453e+01, 2.31279151e+00, 7.92487521e-01, 9.69274326e-03], "lumina_2": [-8.74643948e+02, 4.66059906e+02, -7.51559762e+01, 5.32836175e+00, -3.27258296e-02], "hunyuan_video": [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02], "hidream_i1_full": [-3.13605009e+04, -7.12425503e+02, 4.91363285e+01, 8.26515490e+00, 1.08053901e-01], "hidream_i1_dev": [1.39997273, -4.30130469, 5.01534416, -2.20504164, 0.93942874], "hidream_i1_fast": [2.26509623, -6.88864563, 7.61123826, -3.10849353, 0.99927602], "wan2.1_t2v_1.3B": [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01], "wan2.1_t2v_14B": [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404], "wan2.1_i2v_480p_14B": [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01], "wan2.1_i2v_720p_14B": [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683], "wan2.1_t2v_1.3B_ret_mode": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02], "wan2.1_t2v_14B_ret_mode": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01], "wan2.1_i2v_480p_14B_ret_mode": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01], "wan2.1_i2v_720p_14B_ret_mode": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02], } def poly1d(coefficients, x): result = torch.zeros_like(x) for i, coeff in enumerate(coefficients): result += coeff * (x ** (len(coefficients) - 1 - i)) return result def teacache_flux_forward( self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, y: Tensor, guidance: Tensor = None, control = None, transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") if y is None: y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype)) if self.params.guidance_embed: if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) txt = self.txt_in(txt) if img_ids is not None: ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) else: pe = None blocks_replace = patches_replace.get("dit", {}) # enable teacache img_mod1, _ = self.double_blocks[0].img_mod(vec) modulated_inp = self.double_blocks[0].img_norm1(img) modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift).to(cache_device) ca_idx = 0 if not hasattr(self, 'accumulated_rel_l1_distance'): should_calc = True self.accumulated_rel_l1_distance = 0 else: try: self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())).abs() if self.accumulated_rel_l1_distance < rel_l1_thresh: should_calc = False else: should_calc = True self.accumulated_rel_l1_distance = 0 except: should_calc = True self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp if not enable_teacache: should_calc = True if not should_calc: img += self.previous_residual.to(img.device) else: ori_img = img.to(cache_device) for i, block in enumerate(self.double_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask")) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attn_mask": attn_mask}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] if add is not None: img += add # PuLID attention if getattr(self, "pulid_data", {}): if i % self.pulid_double_interval == 0: # Will calculate influence of all pulid nodes at once for _, node_data in self.pulid_data.items(): if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])): img = img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], img) ca_idx += 1 if img.dtype == torch.float16: img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504) img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask")) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attn_mask": attn_mask}, {"original_block": block_wrap}) img = out["img"] else: img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: img[:, txt.shape[1] :, ...] += add # PuLID attention if getattr(self, "pulid_data", {}): real_img, txt = img[:, txt.shape[1]:, ...], img[:, :txt.shape[1], ...] if i % self.pulid_single_interval == 0: # Will calculate influence of all nodes at once for _, node_data in self.pulid_data.items(): if torch.any((node_data['sigma_start'] >= timesteps) & (timesteps >= node_data['sigma_end'])): real_img = real_img + node_data['weight'] * self.pulid_ca[ca_idx](node_data['embedding'], real_img) ca_idx += 1 img = torch.cat((txt, real_img), 1) img = img[:, txt.shape[1] :, ...] self.previous_residual = img.to(cache_device) - ori_img img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img def teacache_hidream_forward( self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None, encoder_hidden_states_llama3=None, image_cond=None, control = None, transformer_options = {}, ) -> torch.Tensor: rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") cond_or_uncond = transformer_options.get("cond_or_uncond") model_type = transformer_options.get("model_type") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") bs, c, h, w = x.shape if image_cond is not None: x = torch.cat([x, image_cond], dim=-1) hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) timesteps = t pooled_embeds = y T5_encoder_hidden_states = context img_sizes = None # spatial forward batch_size = hidden_states.shape[0] hidden_states_type = hidden_states.dtype # 0. time timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device) timesteps = self.t_embedder(timesteps, hidden_states_type) p_embedder = self.p_embedder(pooled_embeds) adaln_input = timesteps + p_embedder hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes) if image_tokens_masks is None: pH, pW = img_sizes[0] img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device) img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size) hidden_states = self.x_embedder(hidden_states) # T5_encoder_hidden_states = encoder_hidden_states[0] encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0) encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers] if self.caption_projection is not None: new_encoder_hidden_states = [] for i, enc_hidden_state in enumerate(encoder_hidden_states): enc_hidden_state = self.caption_projection[i](enc_hidden_state) enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1]) new_encoder_hidden_states.append(enc_hidden_state) encoder_hidden_states = new_encoder_hidden_states T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states) T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) encoder_hidden_states.append(T5_encoder_hidden_states) txt_ids = torch.zeros( batch_size, encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1], 3, device=img_ids.device, dtype=img_ids.dtype ) ids = torch.cat((img_ids, txt_ids), dim=1) rope = self.pe_embedder(ids) # enable teacache modulated_inp = timesteps.to(cache_device) if "full" in model_type else hidden_states.to(cache_device) if not hasattr(self, 'teacache_state'): self.teacache_state = { 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} } def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 except: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp b = int(len(hidden_states) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: should_calc = False for k in cond_or_uncond: should_calc = (should_calc or self.teacache_state[k]['should_calc']) else: should_calc = True if not should_calc: for i, k in enumerate(cond_or_uncond): hidden_states[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(hidden_states.device) else: # 2. Blocks ori_hidden_states = hidden_states.to(cache_device) block_id = 0 initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1) initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1] for bid, block in enumerate(self.double_stream_blocks): cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1) hidden_states, initial_encoder_hidden_states = block( image_tokens = hidden_states, image_tokens_masks = image_tokens_masks, text_tokens = cur_encoder_hidden_states, adaln_input = adaln_input, rope = rope, ) initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len] block_id += 1 image_tokens_seq_len = hidden_states.shape[1] hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1) hidden_states_seq_len = hidden_states.shape[1] if image_tokens_masks is not None: encoder_attention_mask_ones = torch.ones( (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]), device=image_tokens_masks.device, dtype=image_tokens_masks.dtype ) image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1) for bid, block in enumerate(self.single_stream_blocks): cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id] hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1) hidden_states = block( image_tokens=hidden_states, image_tokens_masks=image_tokens_masks, text_tokens=None, adaln_input=adaln_input, rope=rope, ) hidden_states = hidden_states[:, :hidden_states_seq_len] block_id += 1 hidden_states = hidden_states[:, :image_tokens_seq_len, ...] for i, k in enumerate(cond_or_uncond): self.teacache_state[k]['previous_residual'] = (hidden_states.to(cache_device) - ori_hidden_states)[i*b:(i+1)*b] output = self.final_layer(hidden_states, adaln_input) output = self.unpatchify(output, img_sizes) return -output[:, :, :h, :w] def teacache_lumina_forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") cond_or_uncond = transformer_options.get("cond_or_uncond") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask bs, c, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) t = self.t_embedder(t, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute x_is_tensor = isinstance(x, torch.Tensor) x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens) freqs_cis = freqs_cis.to(x.device) # enable teacache modulated_inp = t.to(cache_device) if not hasattr(self, 'teacache_state'): self.teacache_state = { 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} } def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 except: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp b = int(len(x) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: should_calc = False for k in cond_or_uncond: should_calc = (should_calc or self.teacache_state[k]['should_calc']) else: should_calc = True if not should_calc: for i, k in enumerate(cond_or_uncond): x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) else: ori_x = x.to(cache_device) # 2. Blocks for layer in self.layers: x = layer(x, mask, freqs_cis, adaln_input) for i, k in enumerate(cond_or_uncond): self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] x = self.final_layer(x, adaln_input) x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] return -x def teacache_hunyuanvideo_forward( self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, txt_mask: Tensor, timesteps: Tensor, y: Tensor, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, ) -> Tensor: patches_replace = transformer_options.get("patches_replace", {}) rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") initial_shape = list(img.shape) # running on sequences img img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) ref_latent = self.img_in(ref_latent) img = torch.cat([ref_latent, img], dim=-2) ref_latent_ids[..., 0] = -1 ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1]) img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2) if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims_txt = [(0, None, 1)] else: vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) modulation_dims = None modulation_dims_txt = None if self.params.guidance_embed: if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) if txt_mask is not None and not torch.is_floating_point(txt_mask): txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max txt = self.txt_in(txt, timesteps, txt_mask) ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) img_len = img.shape[1] if txt_mask is not None: attn_mask_len = img_len + txt.shape[1] attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) attn_mask[:, 0, img_len:] = txt_mask else: attn_mask = None blocks_replace = patches_replace.get("dit", {}) # enable teacache img_mod1, _ = self.double_blocks[0].img_mod(vec) modulated_inp = self.double_blocks[0].img_norm1(img) modulated_inp = apply_mod(modulated_inp, (1 + img_mod1.scale), img_mod1.shift, modulation_dims).to(cache_device) if not hasattr(self, 'accumulated_rel_l1_distance'): should_calc = True self.accumulated_rel_l1_distance = 0 else: try: self.accumulated_rel_l1_distance += poly1d(coefficients, ((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())) if self.accumulated_rel_l1_distance < rel_l1_thresh: should_calc = False else: should_calc = True self.accumulated_rel_l1_distance = 0 except: should_calc = True self.accumulated_rel_l1_distance = 0 self.previous_modulated_input = modulated_inp if not enable_teacache: should_calc = True if not should_calc: img += self.previous_residual.to(img.device) else: ori_img = img.to(cache_device) for i, block in enumerate(self.double_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"]) return out out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) if control is not None: # Controlnet control_i = control.get("input") if i < len(control_i): add = control_i[i] if add is not None: img += add img = torch.cat((img, txt), 1) for i, block in enumerate(self.single_blocks): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) return out out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) img = out["img"] else: img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) if control is not None: # Controlnet control_o = control.get("output") if i < len(control_o): add = control_o[i] if add is not None: img[:, : img_len] += add img = img[:, : img_len] self.previous_residual = (img.to(cache_device) - ori_img) if ref_latent is not None: img = img[:, ref_latent.shape[1]:] img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) shape = initial_shape[-3:] for i in range(len(shape)): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img def teacache_ltxvmodel_forward( self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs ): patches_replace = transformer_options.get("patches_replace", {}) rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") cond_or_uncond = transformer_options.get("cond_or_uncond") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") orig_shape = list(x.shape) x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, scale_factors=self.vae_scale_factors, causal_fix=self.causal_temporal_positioning, ) if keyframe_idxs is not None: pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs fractional_coords = pixel_coords.to(torch.float32) fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) x = self.patchify_proj(x) timestep = timestep * 1000.0 if attention_mask is not None and not torch.is_floating_point(attention_mask): attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) batch_size = x.shape[0] timestep, embedded_timestep = self.adaln_single( timestep.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=x.dtype, ) # Second dimension is 1 or number of tokens (if timestep_per_token) timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view( batch_size, -1, embedded_timestep.shape[-1] ) # 2. Blocks if self.caption_projection is not None: batch_size = x.shape[0] context = self.caption_projection(context) context = context.view( batch_size, -1, x.shape[-1] ) blocks_replace = patches_replace.get("dit", {}) # enable teacache inp = x.to(cache_device) timestep_ = timestep.to(cache_device) num_ada_params = self.transformer_blocks[0].scale_shift_table.shape[0] ada_values = self.transformer_blocks[0].scale_shift_table[None, None].to(timestep_.device) + timestep_.reshape(batch_size, timestep_.size(1), num_ada_params, -1) shift_msa, scale_msa, _, _, _, _ = ada_values.unbind(dim=2) modulated_inp = comfy.ldm.common_dit.rms_norm(inp) modulated_inp = modulated_inp * (1 + scale_msa) + shift_msa if not hasattr(self, 'teacache_state'): self.teacache_state = { 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} } def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 except: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp b = int(len(x) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: should_calc = False for k in cond_or_uncond: should_calc = (should_calc or self.teacache_state[k]['should_calc']) else: should_calc = True if not should_calc: for i, k in enumerate(cond_or_uncond): x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) else: ori_x = x.to(cache_device) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"]) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap}) x = out["img"] else: x = block( x, context=context, attention_mask=attention_mask, timestep=timestep, pe=pe ) # 3. Output scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation x = x * (1 + scale) + shift for i, k in enumerate(cond_or_uncond): self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] x = self.proj_out(x) x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], output_width=orig_shape[4], output_num_frames=orig_shape[2], out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size), ) return x def teacache_wanmodel_forward( self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, **kwargs, ): patches_replace = transformer_options.get("patches_replace", {}) rel_l1_thresh = transformer_options.get("rel_l1_thresh") coefficients = transformer_options.get("coefficients") cond_or_uncond = transformer_options.get("cond_or_uncond") model_type = transformer_options.get("model_type") enable_teacache = transformer_options.get("enable_teacache", True) cache_device = transformer_options.get("cache_device") # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] x = x.flatten(2).transpose(1, 2) # time embeddings e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype)) e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # context context = self.text_embedding(context) context_img_len = None if clip_fea is not None: if self.img_emb is not None: context_clip = self.img_emb(clip_fea) # bs x 257 x dim context = torch.concat([context_clip, context], dim=1) context_img_len = clip_fea.shape[-2] blocks_replace = patches_replace.get("dit", {}) # enable teacache modulated_inp = e0.to(cache_device) if "ret_mode" in model_type else e.to(cache_device) if not hasattr(self, 'teacache_state'): self.teacache_state = { 0: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None}, 1: {'should_calc': True, 'accumulated_rel_l1_distance': 0, 'previous_modulated_input': None, 'previous_residual': None} } def update_cache_state(cache, modulated_inp): if cache['previous_modulated_input'] is not None: try: cache['accumulated_rel_l1_distance'] += poly1d(coefficients, ((modulated_inp-cache['previous_modulated_input']).abs().mean() / cache['previous_modulated_input'].abs().mean())) if cache['accumulated_rel_l1_distance'] < rel_l1_thresh: cache['should_calc'] = False else: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 except: cache['should_calc'] = True cache['accumulated_rel_l1_distance'] = 0 cache['previous_modulated_input'] = modulated_inp b = int(len(x) / len(cond_or_uncond)) for i, k in enumerate(cond_or_uncond): update_cache_state(self.teacache_state[k], modulated_inp[i*b:(i+1)*b]) if enable_teacache: should_calc = False for k in cond_or_uncond: should_calc = (should_calc or self.teacache_state[k]['should_calc']) else: should_calc = True if not should_calc: for i, k in enumerate(cond_or_uncond): x[i*b:(i+1)*b] += self.teacache_state[k]['previous_residual'].to(x.device) else: ori_x = x.to(cache_device) for i, block in enumerate(self.blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len) return out out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap, "transformer_options": transformer_options}) x = out["img"] else: x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len) for i, k in enumerate(cond_or_uncond): self.teacache_state[k]['previous_residual'] = (x.to(cache_device) - ori_x)[i*b:(i+1)*b] # head x = self.head(x, e) # unpatchify x = self.unpatchify(x, grid_sizes) return x class TeaCache: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL", {"tooltip": "The diffusion model the TeaCache will be applied to."}), "model_type": (["flux", "flux-kontext", "ltxv", "lumina_2", "hunyuan_video", "hidream_i1_full", "hidream_i1_dev", "hidream_i1_fast", "wan2.1_t2v_1.3B", "wan2.1_t2v_14B", "wan2.1_i2v_480p_14B", "wan2.1_i2v_720p_14B", "wan2.1_t2v_1.3B_ret_mode", "wan2.1_t2v_14B_ret_mode", "wan2.1_i2v_480p_14B_ret_mode", "wan2.1_i2v_720p_14B_ret_mode"], {"default": "flux", "tooltip": "Supported diffusion model."}), "rel_l1_thresh": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "How strongly to cache the output of diffusion model. This value must be non-negative."}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The start percentage of the steps that will apply TeaCache."}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The end percentage of the steps that will apply TeaCache."}), "cache_device": (["cuda", "cpu"], {"default": "cuda", "tooltip": "Device where the cache will reside."}), } } RETURN_TYPES = ("MODEL",) RETURN_NAMES = ("model",) FUNCTION = "apply_teacache" CATEGORY = "TeaCache" TITLE = "TeaCache" def apply_teacache(self, model, model_type: str, rel_l1_thresh: float, start_percent: float, end_percent: float, cache_device: str): if rel_l1_thresh == 0: return (model,) new_model = model.clone() if 'transformer_options' not in new_model.model_options: new_model.model_options['transformer_options'] = {} new_model.model_options["transformer_options"]["rel_l1_thresh"] = rel_l1_thresh new_model.model_options["transformer_options"]["coefficients"] = SUPPORTED_MODELS_COEFFICIENTS[model_type] new_model.model_options["transformer_options"]["model_type"] = model_type new_model.model_options["transformer_options"]["cache_device"] = mm.get_torch_device() if cache_device == "cuda" else torch.device("cpu") diffusion_model = new_model.get_model_object("diffusion_model") if "flux" in model_type: is_cfg = False context = patch.multiple( diffusion_model, forward_orig=teacache_flux_forward.__get__(diffusion_model, diffusion_model.__class__) ) elif "lumina_2" in model_type: is_cfg = True context = patch.multiple( diffusion_model, forward=teacache_lumina_forward.__get__(diffusion_model, diffusion_model.__class__) ) elif "hidream_i1" in model_type: is_cfg = True if "full" in model_type else False context = patch.multiple( diffusion_model, forward=teacache_hidream_forward.__get__(diffusion_model, diffusion_model.__class__) ) elif "ltxv" in model_type: is_cfg = True context = patch.multiple( diffusion_model, forward=teacache_ltxvmodel_forward.__get__(diffusion_model, diffusion_model.__class__) ) elif "hunyuan_video" in model_type: is_cfg = False context = patch.multiple( diffusion_model, forward_orig=teacache_hunyuanvideo_forward.__get__(diffusion_model, diffusion_model.__class__) ) elif "wan2.1" in model_type: is_cfg = True context = patch.multiple( diffusion_model, forward_orig=teacache_wanmodel_forward.__get__(diffusion_model, diffusion_model.__class__) ) else: raise ValueError(f"Unknown type {model_type}") def unet_wrapper_function(model_function, kwargs): input = kwargs["input"] timestep = kwargs["timestep"] c = kwargs["c"] # referenced from https://github.com/kijai/ComfyUI-KJNodes/blob/d126b62cebee81ea14ec06ea7cd7526999cb0554/nodes/model_optimization_nodes.py#L868 sigmas = c["transformer_options"]["sample_sigmas"] matched_step_index = (sigmas == timestep[0]).nonzero() if len(matched_step_index) > 0: current_step_index = matched_step_index.item() else: current_step_index = 0 for i in range(len(sigmas) - 1): # walk from beginning of steps until crossing the timestep if (sigmas[i] - timestep[0]) * (sigmas[i + 1] - timestep[0]) <= 0: current_step_index = i break if current_step_index == 0: if is_cfg: # uncond -> 1, cond -> 0 if hasattr(diffusion_model, 'teacache_state') and \ diffusion_model.teacache_state[0]['previous_modulated_input'] is not None and \ diffusion_model.teacache_state[1]['previous_modulated_input'] is not None: delattr(diffusion_model, 'teacache_state') else: if hasattr(diffusion_model, 'teacache_state'): delattr(diffusion_model, 'teacache_state') if hasattr(diffusion_model, 'accumulated_rel_l1_distance'): delattr(diffusion_model, 'accumulated_rel_l1_distance') current_percent = current_step_index / (len(sigmas) - 1) c["transformer_options"]["current_percent"] = current_percent if start_percent <= current_percent <= end_percent: c["transformer_options"]["enable_teacache"] = True else: c["transformer_options"]["enable_teacache"] = False with context: return model_function(input, timestep, **c) new_model.set_model_unet_function_wrapper(unet_wrapper_function) return (new_model,) def patch_optimized_module(): try: from torch._dynamo.eval_frame import OptimizedModule except ImportError: return if getattr(OptimizedModule, "_patched", False): return def __getattribute__(self, name): if name == "_orig_mod": return object.__getattribute__(self, "_modules")[name] if name in ( "__class__", "_modules", "state_dict", "load_state_dict", "parameters", "named_parameters", "buffers", "named_buffers", "children", "named_children", "modules", "named_modules", ): return getattr(object.__getattribute__(self, "_orig_mod"), name) return object.__getattribute__(self, name) def __delattr__(self, name): return delattr(self._orig_mod, name) @classmethod def __instancecheck__(cls, instance): return isinstance(instance, OptimizedModule) or issubclass( object.__getattribute__(instance, "__class__"), cls ) OptimizedModule.__getattribute__ = __getattribute__ OptimizedModule.__delattr__ = __delattr__ OptimizedModule.__instancecheck__ = __instancecheck__ OptimizedModule._patched = True def patch_same_meta(): try: from torch._inductor.fx_passes import post_grad except ImportError: return same_meta = getattr(post_grad, "same_meta", None) if same_meta is None: return if getattr(same_meta, "_patched", False): return def new_same_meta(a, b): try: return same_meta(a, b) except Exception: return False post_grad.same_meta = new_same_meta new_same_meta._patched = True class CompileModel: @classmethod def INPUT_TYPES(s): return { "required": { "model": ("MODEL", {"tooltip": "The diffusion model the torch.compile will be applied to."}), "mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}), "backend": (["inductor","cudagraphs", "eager", "aot_eager"], {"default": "inductor"}), "fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}), "dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}), } } RETURN_TYPES = ("MODEL",) RETURN_NAMES = ("model",) FUNCTION = "apply_compile" CATEGORY = "TeaCache" TITLE = "Compile Model" def apply_compile(self, model, mode: str, backend: str, fullgraph: bool, dynamic: bool): patch_optimized_module() patch_same_meta() torch._dynamo.config.suppress_errors = True new_model = model.clone() new_model.add_object_patch( "diffusion_model", torch.compile( new_model.get_model_object("diffusion_model"), mode=mode, backend=backend, fullgraph=fullgraph, dynamic=dynamic ) ) return (new_model,) NODE_CLASS_MAPPINGS = { "TeaCache": TeaCache, "CompileModel": CompileModel } NODE_DISPLAY_NAME_MAPPINGS = {k: v.TITLE for k, v in NODE_CLASS_MAPPINGS.items()}