Commit 8ddd33a5 authored by GoatWu's avatar GoatWu
Browse files

update dynamic-cfg settings

parent 91ef1bd1
import torch
from diffusers.models.embeddings import TimestepEmbedding
from .utils import rope_params, sinusoidal_embedding_1d, guidance_scale_embedding
from lightx2v.utils.envs import *
......@@ -64,8 +65,10 @@ class WanPreInfer:
embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
if self.enable_dynamic_cfg:
s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 8.0), target_range=1000.0, dtype=torch.float32).type_as(x)
cfg_embed = weights.cfg_cond_proj.apply(cfg_embed)
cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
cfg_embed = torch.nn.functional.silu(cfg_embed)
cfg_embed = weights.cfg_cond_proj_2.apply(cfg_embed)
embed = embed + cfg_embed
if GET_DTYPE() != "BF16":
embed = weights.time_embedding_0.apply(embed.float())
......
......@@ -172,7 +172,7 @@ def sinusoidal_embedding_1d(dim, position):
return x
def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(0.0, 8.0), target_range=1000.0, dtype=torch.float32):
def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32):
"""
Args:
timesteps: torch.Tensor: generate embedding vectors at these timesteps
......@@ -184,6 +184,8 @@ def guidance_scale_embedding(w, embedding_dim=256, cfg_range=(0.0, 8.0), target_
"""
assert len(w.shape) == 1
cfg_min, cfg_max = cfg_range
w = torch.round(w)
w = torch.clamp(w, min=cfg_min, max=cfg_max)
w = (w - cfg_min) / (cfg_max - cfg_min) # [0, 1]
w = w * target_range
half_dim = embedding_dim // 2
......
......@@ -59,6 +59,10 @@ class WanPreWeights(WeightModule):
if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
self.add_module(
"cfg_cond_proj",
MM_WEIGHT_REGISTER["Default"]("cfg_cond_proj.weight", "cfg_cond_proj.bias"),
"cfg_cond_proj_1",
MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),
)
self.add_module(
"cfg_cond_proj_2",
MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
)
#!/bin/bash
# set paths first
lightx2v_path="/data/lightx2v-dev/"
model_path="/data/lightx2v-dev/Wan2.1-T2V-14B/"
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment