Unverified Commit 47b3ce2f authored by Yang Yong (雍洋), committed by GitHub

Update wan infer rope (#518)

parent 5f277e80
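
In outline: this change lifts RoPE setup out of the pre-infer stage and into the scheduler. WanPreInfer and WanAudioPreInfer no longer build a freqs table or emit seq_lens/freqs through WanPreInferModuleOutput; WanScheduler now owns rope_params and precomputes a cos_sin cache in prepare_cos_sin; and WanTransformerInfer pulls that cache each step and applies it through a new rope_type dispatch (a flashinfer path using apply_rope_with_cos_sin_cache_inplace, a pure-torch complex-multiply path, and chunked variants of both). The self-forcing (SF) path keeps the old interface via a new WanSFPreInferModuleOutput.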
@@ -10,8 +10,6 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.no_grad()
     def infer(self, x, pre_infer_out):
-        x = x[: pre_infer_out.seq_lens[0]]
         t, h, w = pre_infer_out.grid_sizes.tuple
         grid_sizes = (t - 1, h, w)
...
@@ -4,27 +4,17 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import GridOutput, WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d
+from ..utils import sinusoidal_embedding_1d


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
-        d = config["dim"] // config["num_heads"]
         self.config = config
         self.task = config["task"]
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -65,14 +55,14 @@ class WanAudioPreInfer(WanPreInfer):
         x = weights.patch_embedding.apply(x.unsqueeze(0))
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
         x = torch.cat([x, y], dim=1).squeeze(0)

         ####for r2v # zero temporal component corresponding to ref embeddings
-        self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
+        # self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
         grid_sizes_t += 1

         person_mask_latens = inputs["person_mask_latens"]
@@ -126,8 +116,6 @@ class WanAudioPreInfer(WanPreInfer):
             grid_sizes=grid_sizes,
             x=x,
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
@@ -16,8 +16,6 @@ class WanPreInferModuleOutput:
     grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
-    seq_lens: torch.Tensor
-    freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
     conditional_dict: Dict[str, Any] = field(default_factory=dict)
@@ -3,26 +3,17 @@ import torch
 from lightx2v.utils.envs import *
 from .module_io import GridOutput, WanPreInferModuleOutput
-from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
+from .utils import guidance_scale_embedding, sinusoidal_embedding_1d


 class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        d = config["dim"] // config["num_heads"]
         self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.device = torch.device(self.config.get("run_device", "cuda"))
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
@@ -71,7 +62,7 @@ class WanPreInfer:
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
@@ -130,8 +121,6 @@ class WanPreInfer:
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"motion_vec": motion_vec},
         )
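
With seq_lens and freqs gone from the pre-infer output, length bookkeeping in the attention path is derived from the tensors themselves. A minimal standalone sketch of the new bookkeeping (illustrative sizes, not lightx2v code):

    import torch

    # With a single packed sequence, the cumulative-length tensor is fully
    # determined by the token count, so an explicit seq_lens tensor from the
    # pre-infer stage is redundant.
    q = torch.randn(32760, 12, 128)  # (tokens, heads, head_dim), illustrative
    img_qkv_len = q.shape[0]
    cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32)
    assert cu_seqlens_qkv.tolist() == [0, 32760]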
+from dataclasses import dataclass
+
 import torch

-from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
+from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
@@ -24,6 +26,17 @@ def rope_params(max_seq_len, dim, theta=10000):
     return freqs


+@dataclass
+class WanSFPreInferModuleOutput:
+    embed: torch.Tensor
+    grid_sizes: GridOutput
+    x: torch.Tensor
+    embed0: torch.Tensor
+    seq_lens: torch.Tensor
+    freqs: torch.Tensor
+    context: torch.Tensor
+
+
 class WanSFPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
@@ -87,7 +100,7 @@ class WanSFPreInfer(WanPreInfer):
         grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

-        return WanPreInferModuleOutput(
+        return WanSFPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
...
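
Note the asymmetry here: WanSFPreInferModuleOutput is introduced so that seq_lens and freqs can stay on the self-forcing output while they disappear from the shared WanPreInferModuleOutput; correspondingly, the SF transformer below stubs get_scheduler_values to a no-op, since its KV-cache path does not consume the scheduler's cos_sin cache.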
@@ -50,6 +50,9 @@ class WanSFTransformerInfer(WanTransformerInfer):
         self.infer_func = self.infer_with_kvcache

+    def get_scheduler_values(self):
+        pass
+
     def _initialize_kv_cache(self, dtype, device):
         """
         Initialize a Per-GPU KV cache for the Wan model.
...
@@ -5,7 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *

-from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_dist
+from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


 class WanTransformerInfer(BaseTransformerInfer):
@@ -20,11 +20,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
-        if config.get("rotary_chunk", False):
-            chunk_size = config.get("rotary_chunk_size", 100)
-            self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
-        else:
-            self.apply_rotary_emb_func = apply_rotary_emb
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_flashinfer
+        else:
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_torch)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_torch
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -35,21 +40,20 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.seq_p_group = None
         self.infer_func = self.infer_without_offload
+        self.cos_sin = None

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if self.config["seq_parallel"]:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
+    def get_scheduler_values(self):
+        self.cos_sin = self.scheduler.cos_sin

     @torch.no_grad()
     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)
@@ -97,10 +101,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         y_out = self.infer_self_attn(
             block.compute_phases[0],
-            pre_infer_out.grid_sizes.tuple,
             x,
-            pre_infer_out.seq_lens,
-            pre_infer_out.freqs,
             shift_msa,
             scale_msa,
         )
@@ -129,7 +130,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, phase, x, shift_msa, scale_msa):
+        cos_sin = self.cos_sin
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
@@ -153,16 +155,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = phase.self_attn_v.apply(norm1_out).view(s, n, d)

-        freqs_i = self.compute_freqs(q, grid_sizes, freqs)
-        q = self.apply_rotary_emb_func(q, freqs_i)
-        k = self.apply_rotary_emb_func(k, freqs_i)
-        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
+        q, k = self.apply_rope_func(q, k, cos_sin)
+        img_qkv_len = q.shape[0]
+        cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del freqs_i, norm1_out, norm1_weight, norm1_bias
+            del norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
@@ -170,8 +169,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                img_qkv_len=q.shape[0],
-                cu_seqlens_qkv=cu_seqlens_q,
+                img_qkv_len=img_qkv_len,
+                cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],
@@ -181,10 +180,10 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=q.size(0),
-                max_seqlen_kv=k.size(0),
+                cu_seqlens_q=cu_seqlens_qkv,
+                cu_seqlens_kv=cu_seqlens_qkv,
+                max_seqlen_q=img_qkv_len,
+                max_seqlen_kv=img_qkv_len,
                 model_cls=self.config["model_cls"],
             )
...
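
The four-way branch in the constructor above reduces to: pick a base rope kernel by rope_type, then optionally wrap it in the chunked driver. A condensed sketch with stub kernels (the stubs and select_rope_func are illustrative, not functions in this diff):

    from functools import partial

    def rope_torch_stub(q, k, cos_sin):
        return q, k

    def rope_flashinfer_stub(q, k, cos_sin):
        return q, k

    def rope_chunk_stub(q, k, cos_sin, chunk_size, rope_func):
        # the real chunked driver splits the sequence; see apply_wan_rope_with_chunk below
        return rope_func(q, k, cos_sin)

    def select_rope_func(config):
        base = rope_flashinfer_stub if config.get("rope_type", "flashinfer") == "flashinfer" else rope_torch_stub
        if config.get("rope_chunk", False):
            return partial(rope_chunk_stub, chunk_size=config.get("rope_chunk_size", 100), rope_func=base)
        return base

    assert select_rope_func({"rope_type": "torch"}) is rope_torch_stub
    assert select_rope_func({}) is rope_flashinfer_stub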
 import torch
 import torch.distributed as dist
+from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

 from lightx2v.utils.envs import *


+def apply_wan_rope_with_torch(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    n = xq.size(1)
+    seq_len = cos_sin_cache.size(0)
+    xq = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    xk = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    # Apply rotary embedding
+    xq = torch.view_as_real(xq * cos_sin_cache).flatten(2)
+    xk = torch.view_as_real(xk * cos_sin_cache).flatten(2)
+    xq = torch.cat([xq, xq[seq_len:]])
+    xk = torch.cat([xk, xk[seq_len:]])
+    return xq.to(GET_DTYPE()), xk.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_chunk(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    chunk_size: int,
+    rope_func,
+):
+    seq_len = cos_sin_cache.size(0)
+    xq_output_chunks = []
+    xk_output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        xq_chunk = xq[start:end]
+        xk_chunk = xk[start:end]
+        cos_sin_chunk = cos_sin_cache[start:end]
+        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        xq_output_chunks.append(xq_chunk)
+        xk_output_chunks.append(xk_chunk)
+        torch.cuda.empty_cache()
+    x_q = torch.cat(xq_output_chunks, dim=0)
+    del xq_output_chunks
+    torch.cuda.empty_cache()
+    x_k = torch.cat(xk_output_chunks, dim=0)
+    del xk_output_chunks
+    torch.cuda.empty_cache()
+    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_flashinfer(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    L, H, D = xq.shape
+    query = xq.reshape(L, H * D).contiguous()
+    key = xk.reshape(L, H * D).contiguous()
+    positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True)
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=D,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=False,
+    )
+    xq_out = query.view(L, H, D)
+    xk_out = key.view(L, H, D)
+    return xq_out, xk_out
+
+
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes
...
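
For readers unfamiliar with the complex-multiply formulation in apply_wan_rope_with_torch: each adjacent channel pair is viewed as one complex number and rotated by a unit phasor from the cache, so magnitudes are preserved. A self-contained sketch that builds a tiny cache in the non-flashinfer layout (seq_len, 1, head_dim // 2) and checks this (all sizes illustrative):

    import torch

    L, H, D = 8, 2, 16  # tokens, heads, head size
    inv = 1.0 / torch.pow(10000, torch.arange(0, D, 2).float() / D)
    ang = torch.outer(torch.arange(L).float(), inv)  # (L, D // 2)
    cos_sin = torch.polar(torch.ones_like(ang), ang).view(L, 1, D // 2)

    q = torch.randn(L, H, D)
    qc = torch.view_as_complex(q.float().reshape(L, H, -1, 2))
    q_rot = torch.view_as_real(qc * cos_sin).flatten(2)  # back to (L, H, D)

    # Unit-magnitude phasors only rotate each channel pair, so norms match.
    assert torch.allclose(q.norm(), q_rot.norm(), atol=1e-4)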
@@ -12,6 +12,8 @@ from lightx2v.utils.utils import masks_like
 class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
+        d = config["dim"] // config["num_heads"]
+        self.rope_t_dim = d // 2 - 2 * (d // 6)
         if self.config["parallel"]:
             self.sp_size = self.config["parallel"].get("seq_p_size", 1)
@@ -83,6 +85,9 @@ class EulerScheduler(WanScheduler):
         self.timesteps = self.sigmas * self.num_train_timesteps

+        self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
     def step_post(self):
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)
...
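
To make rope_t_dim concrete: the frequency table has d // 2 complex columns per head, split into one temporal and two spatial groups, and the zeroing above blanks only the temporal columns for token rows at or past the video frames (the "zero temporal component corresponding to ref embeddings" behavior noted earlier in the diff). A worked example, assuming a hypothetical head size of 128:

    d = 128  # e.g. dim=1536, num_heads=12
    t_cols = d // 2 - 2 * (d // 6)  # 22 temporal frequency pairs == rope_t_dim
    h_cols = w_cols = d // 6        # 21 pairs each for height and width
    assert (t_cols, h_cols, w_cols) == (22, 21, 21)
    assert t_cols + h_cols + w_cols == d // 2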
@@ -14,6 +14,8 @@ class WanScheduler(BaseScheduler):
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
+        self.run_device = self.config.get("run_device", "cuda")
+        self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
@@ -21,6 +23,24 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.sample_guide_scale = self.config["sample_guide_scale"]
         self.caching_records_2 = [True] * self.config["infer_steps"]
+        self.head_size = self.config["dim"] // self.config["num_heads"]
+        self.freqs = torch.cat(
+            [
+                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+            ],
+            dim=1,
+        ).to(torch.device(self.run_device))
+
+    def rope_params(self, max_seq_len, dim, theta=10000):
+        assert dim % 2 == 0
+        freqs = torch.outer(
+            torch.arange(max_seq_len),
+            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
+        )
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs

     def prepare(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
@@ -47,6 +67,31 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
+    def prepare_cos_sin(self, grid_sizes):
+        c = self.head_size // 2
+        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes
+        seq_len = f * h * w
+        cos_sin = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        )
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            cos_sin = cos_sin.reshape(seq_len, -1)
+            # Extract cos and sin parts separately and concatenate
+            cos_half = cos_sin.real.contiguous()
+            sin_half = cos_sin.imag.contiguous()
+            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
+        else:
+            cos_sin = cos_sin.reshape(seq_len, 1, -1)
+        return cos_sin

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
         self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
...
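
A standalone sketch of the two cache layouts prepare_cos_sin produces, on a tiny hypothetical grid (the local rope_params mirrors the method added above; all sizes illustrative). Note that the EulerScheduler variant sizes the cache with f + 1 frames to cover the appended reference tokens, matching grid_sizes_t += 1 in WanAudioPreInfer.

    import torch

    d, f, h, w = 16, 2, 3, 4  # per-head size and a tiny (f, h, w) grid
    c = d // 2

    def rope_params(n, dim, theta=10000):
        fr = torch.outer(torch.arange(n).float(), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float() / dim))
        return torch.polar(torch.ones_like(fr), fr)

    freqs = torch.cat([rope_params(64, d - 4 * (d // 6)), rope_params(64, 2 * (d // 6)), rope_params(64, 2 * (d // 6))], dim=1)
    parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    grid = torch.cat(
        [
            parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ],
        dim=-1,
    )
    seq_len = f * h * w
    flat = grid.reshape(seq_len, -1)
    flashinfer_cache = torch.cat([flat.real.contiguous(), flat.imag.contiguous()], dim=-1)
    torch_cache = grid.reshape(seq_len, 1, -1)
    print(flashinfer_cache.shape)  # torch.Size([24, 16]): cos half then sin half, real-valued
    print(torch_cache.shape)       # torch.Size([24, 1, 8]): complex, broadcast over heads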