"vscode:/vscode.git/clone" did not exist on "7e2435e509fead8e5edf9657e310dd410128c19b"
Commit ad051778 authored by gushiqiao, committed by GitHub

Fix (#80)



* Fix

* Fix

* Fix

---------
Co-authored-by: gushiqiao <gushiqiao@sensetime.com>
parent fb69083e
@@ -24,7 +24,7 @@ class WanDistillModel(WanModel):
         ckpt_path = os.path.join(self.model_path, "distill_model.pt")
         if not os.path.exists(ckpt_path):
             # File does not exist; fall back to the parent class's _load_ckpt method
-            return super()._load_ckpt()
+            return super()._load_ckpt(use_bf16, skip_bf16)
         weight_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
         weight_dict = {key: (weight_dict[key].to(torch.bfloat16) if use_bf16 or all(s not in key for s in skip_bf16) else weight_dict[key]).pin_memory().to(self.device) for key in weight_dict.keys()}
...
@@ -8,6 +8,7 @@ class WanPostInfer:
     def __init__(self, config):
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
@@ -21,16 +22,21 @@ class WanPostInfer:
         e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
         e = [ei.squeeze(1) for ei in e]
-        norm_out = weights.norm.apply(x)
+        x = weights.norm.apply(x)
         if GET_DTYPE() != "BF16":
-            norm_out = norm_out.float()
-        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
+            x = x.float()
+        x.mul_(1 + e[1].squeeze(0)).add_(e[0].squeeze(0))
         if GET_DTYPE() != "BF16":
-            out = out.to(torch.bfloat16)
-        x = weights.head.apply(out)
+            x = x.to(torch.bfloat16)
+        x = weights.head.apply(x)
         x = self.unpatchify(x, grid_sizes)
+
+        if self.clean_cuda_cache:
+            del e, grid_sizes
+            torch.cuda.empty_cache()
+
         return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
...
@@ -7,7 +7,7 @@ class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         d = config["dim"] // config["num_heads"]
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.freqs = torch.cat(
             [
@@ -87,6 +87,9 @@ class WanPreInfer:
         out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)
+        if self.clean_cuda_cache:
+            del out, stacked
+            torch.cuda.empty_cache()

         if self.task == "i2v":
             context_clip = weights.proj_0.apply(clip_fea)
@@ -95,7 +98,9 @@ class WanPreInfer:
             context_clip = weights.proj_3.apply(context_clip)
             context_clip = weights.proj_4.apply(context_clip)
             context = torch.concat([context_clip, context], dim=0)
+            if self.clean_cuda_cache:
+                del context_clip, clip_fea
+                torch.cuda.empty_cache()

         return (
             embed,
             grid_sizes,
...
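Note: the `del` + `torch.cuda.empty_cache()` pattern added throughout this commit only helps if the intermediates are dropped first, since `empty_cache()` can only return cached blocks whose tensors have already been freed. A minimal standalone sketch of that ordering (not lightx2v code; the tensor and its size are made up):

import torch

# Free the last reference first, then ask the caching allocator to release
# the now-unused block back to the driver.
if torch.cuda.is_available():
    t = torch.empty(1024, 1024, device="cuda")
    del t
    torch.cuda.empty_cache()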
 import torch
-from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
+from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb, apply_rotary_emb_chunk
 from lightx2v.common.offload.manager import (
     WeightAsyncStreamManager,
     LazyWeightAsyncStreamManager,
@@ -14,11 +14,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config["num_layers"]
-        self.phases_num = 3
+        self.phases_num = 4
         self.num_heads = config["num_heads"]
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
+        self.apply_rotary_emb_func = apply_rotary_emb_chunk if config.get("rotary_chunk", False) else apply_rotary_emb
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         if self.config["cpu_offload"]:
             if "offload_ratio" in self.config:
                 offload_ratio = self.config["offload_ratio"]
@@ -92,10 +94,6 @@ class WanTransformerInfer(BaseTransformerInfer):
     def _infer_with_phases_offload(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         for block_idx in range(weights.blocks_num):
-            weights.blocks[block_idx].modulation.to_cuda()
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             for phase_idx in range(self.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     phase = weights.blocks[block_idx].compute_phases[phase_idx]
@@ -105,12 +103,23 @@ class WanTransformerInfer(BaseTransformerInfer):
                 with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
                     cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
                     if cur_phase_idx == 0:
-                        y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, embed0)
                     elif cur_phase_idx == 1:
-                        attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
+                        y_out = self.infer_self_attn(
+                            cur_phase,
+                            grid_sizes,
+                            x,
+                            seq_lens,
+                            freqs,
+                            shift_msa,
+                            scale_msa,
+                        )
                     elif cur_phase_idx == 2:
-                        y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.infer_phase_5(x, y, c_gate_msa)
+                        attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
+                    elif cur_phase_idx == 3:
+                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                        x = self.post_process(x, y, c_gate_msa)

                 is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
                 if not is_last_phase:
@@ -120,8 +129,6 @@ class WanTransformerInfer(BaseTransformerInfer):
                     self.weights_stream_mgr.swap_phases()

-            weights.blocks[block_idx].modulation.to_cpu()

         torch.cuda.empty_cache()
         return x
@@ -130,11 +137,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.weights_stream_mgr.prefetch_weights_from_disk(weights)
         for block_idx in range(weights.blocks_num):
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                weights.blocks[block_idx].modulation.to_cuda()
-                shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights.blocks[block_idx], grid_sizes, embed, x, embed0, seq_lens, freqs, context)
             for phase_idx in range(self.weights_stream_mgr.phases_num):
                 if block_idx == 0 and phase_idx == 0:
                     obj_key = (block_idx, phase_idx)
@@ -152,12 +154,25 @@ class WanTransformerInfer(BaseTransformerInfer):
                     ) = self.weights_stream_mgr.active_weights[0]

                     if cur_phase_idx == 0:
-                        y_out = self.infer_phase_2(cur_phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
+                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
+                            cur_phase,
+                            embed0,
+                        )
                     elif cur_phase_idx == 1:
-                        attn_out = self.infer_phase_3(cur_phase, x, context, y_out, gate_msa)
+                        y_out = self.infer_self_attn(
+                            cur_phase,
+                            grid_sizes,
+                            x,
+                            seq_lens,
+                            freqs,
+                            shift_msa,
+                            scale_msa,
+                        )
                     elif cur_phase_idx == 2:
-                        y = self.infer_phase_4(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.infer_phase_5(x, y, c_gate_msa)
+                        attn_out = self.infer_cross_attn(cur_phase, x, context, y_out, gate_msa)
+                    elif cur_phase_idx == 3:
+                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
+                        x = self.post_process(x, y, c_gate_msa)

                 if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
                     next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
@@ -166,10 +181,16 @@ class WanTransformerInfer(BaseTransformerInfer):
                     self.weights_stream_mgr.swap_phases()

-            weights.blocks[block_idx].modulation.to_cpu()
             self.weights_stream_mgr._async_prefetch_block(weights)
-        torch.cuda.empty_cache()
+            if self.clean_cuda_cache:
+                del attn_out, y_out, y
+                torch.cuda.empty_cache()
+
+        if self.clean_cuda_cache:
+            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
+            del grid_sizes, embed, embed0, seq_lens, freqs, context
+            torch.cuda.empty_cache()
+
         return x
@@ -188,36 +209,51 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_phase_1(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
-        y_out = self.infer_phase_2(weights.compute_phases[0], grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa)
-        attn_out = self.infer_phase_3(weights.compute_phases[1], x, context, y_out, gate_msa)
-        y = self.infer_phase_4(weights.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
-        x = self.infer_phase_5(x, y, c_gate_msa)
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
+            weights.compute_phases[0],
+            embed0,
+        )
+        y_out = self.infer_self_attn(
+            weights.compute_phases[1],
+            grid_sizes,
+            x,
+            seq_lens,
+            freqs,
+            shift_msa,
+            scale_msa,
+        )
+        attn_out = self.infer_cross_attn(weights.compute_phases[2], x, context, y_out, gate_msa)
+        y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
+        x = self.post_process(x, y, c_gate_msa)
         return x

-    def infer_phase_1(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
+    def infer_modulation(self, weights, embed0):
         if embed0.dim() == 3:
             modulation = weights.modulation.tensor.unsqueeze(2)
             embed0 = (modulation + embed0).chunk(6, dim=1)
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
         elif embed0.dim() == 2:
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+
+        if self.clean_cuda_cache:
+            del embed0
+            torch.cuda.empty_cache()
+
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_phase_2(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
         if hasattr(weights, "smooth_norm1_weight"):
-            norm1_weight = (1 + scale_msa) * weights.smooth_norm1_weight.tensor
-            norm1_bias = shift_msa * weights.smooth_norm1_bias.tensor
+            norm1_weight = (1 + scale_msa.squeeze(0)) * weights.smooth_norm1_weight.tensor
+            norm1_bias = shift_msa.squeeze(0) * weights.smooth_norm1_bias.tensor
         else:
-            norm1_weight = 1 + scale_msa
-            norm1_bias = shift_msa
+            norm1_weight = 1 + scale_msa.squeeze(0)
+            norm1_bias = shift_msa.squeeze(0)

         norm1_out = weights.norm1.apply(x)
         if GET_DTYPE() != "BF16":
             norm1_out = norm1_out.float()
-        norm1_out = (norm1_out * norm1_weight + norm1_bias).squeeze(0)
+        norm1_out.mul_(norm1_weight).add_(norm1_bias)
         if GET_DTYPE() != "BF16":
             norm1_out = norm1_out.to(torch.bfloat16)
@@ -233,8 +269,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)

-        q = apply_rotary_emb(q, freqs_i)
-        k = apply_rotary_emb(k, freqs_i)
+        q = self.apply_rotary_emb_func(q, freqs_i)
+        k = self.apply_rotary_emb_func(k, freqs_i)

         cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=seq_lens)
@@ -260,9 +296,14 @@ class WanTransformerInfer(BaseTransformerInfer):
             )

         y = weights.self_attn_o.apply(attn_out)
+
+        if self.clean_cuda_cache:
+            del q, k, v, attn_out, freqs_i, norm1_out, norm1_weight, norm1_bias
+            torch.cuda.empty_cache()
+
         return y

-    def infer_phase_3(self, weights, x, context, y_out, gate_msa):
+    def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y_out.float() * gate_msa.squeeze(0)
         else:
@@ -319,13 +360,26 @@ class WanTransformerInfer(BaseTransformerInfer):
                 max_seqlen_kv=k_img.size(0),
                 model_cls=self.config["model_cls"],
             )
-            attn_out = attn_out + img_attn_out
+            attn_out.add_(img_attn_out)
+
+            if self.clean_cuda_cache:
+                del k_img, v_img, img_attn_out
+                torch.cuda.empty_cache()

         attn_out = weights.cross_attn_o.apply(attn_out)
+
+        if self.clean_cuda_cache:
+            del q, k, v, norm3_out, context, context_img
+            torch.cuda.empty_cache()
+
         return attn_out

-    def infer_phase_4(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+    def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
         x.add_(attn_out)
+
+        if self.clean_cuda_cache:
+            del attn_out
+            torch.cuda.empty_cache()
+
         if hasattr(weights, "smooth_norm2_weight"):
             norm2_weight = (1 + c_scale_msa.squeeze(0)) * weights.smooth_norm2_weight.tensor
             norm2_bias = c_shift_msa.squeeze(0) * weights.smooth_norm2_bias.tensor
@@ -333,21 +387,30 @@ class WanTransformerInfer(BaseTransformerInfer):
             norm2_weight = 1 + c_scale_msa.squeeze(0)
             norm2_bias = c_shift_msa.squeeze(0)

-        norm2_out = weights.norm2.apply(x)
+        x = weights.norm2.apply(x)
         if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.float()
-        norm2_out = norm2_out * norm2_weight + norm2_bias
+            x = x.float()
+        x.mul_(norm2_weight).add_(norm2_bias)
         if GET_DTYPE() != "BF16":
-            norm2_out = norm2_out.to(torch.bfloat16)
+            x = x.to(torch.bfloat16)

-        y = weights.ffn_0.apply(norm2_out)
-        y = torch.nn.functional.gelu(y, approximate="tanh")
-        y = weights.ffn_2.apply(y)
-        return y
+        x = weights.ffn_0.apply(x)
+        if self.clean_cuda_cache:
+            torch.cuda.empty_cache()
+        x = torch.nn.functional.gelu(x, approximate="tanh")
+        if self.clean_cuda_cache:
+            torch.cuda.empty_cache()
+        x = weights.ffn_2.apply(x)
+        return x

-    def infer_phase_5(self, x, y, c_gate_msa):
+    def post_process(self, x, y, c_gate_msa):
         if GET_DTYPE() != "BF16":
             x = x.float() + y.float() * c_gate_msa.squeeze(0)
         else:
             x.add_(y * c_gate_msa.squeeze(0))
+
+        if self.clean_cuda_cache:
+            del y, c_gate_msa
+            torch.cuda.empty_cache()
+
         return x
@@ -75,6 +75,39 @@ def apply_rotary_emb(x, freqs_i):
     return x_i.to(torch.bfloat16)


+def apply_rotary_emb_chunk(x, freqs_i, chunk_size=100, remaining_chunk_size=100):
+    n = x.size(1)
+    seq_len = freqs_i.size(0)
+    output_chunks = []
+
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        x_chunk = x[start:end]
+        freqs_chunk = freqs_i[start:end]
+        x_chunk_complex = torch.view_as_complex(x_chunk.to(torch.float32).reshape(end - start, n, -1, 2))
+        x_chunk_embedded = torch.view_as_real(x_chunk_complex * freqs_chunk).flatten(2).to(torch.bfloat16)
+        output_chunks.append(x_chunk_embedded)
+        del x_chunk_complex, x_chunk_embedded
+        torch.cuda.empty_cache()
+
+    result = []
+    for chunk in output_chunks:
+        result.append(chunk)
+    del output_chunks
+    torch.cuda.empty_cache()
+
+    for start in range(seq_len, x.size(0), remaining_chunk_size):
+        end = min(start + remaining_chunk_size, x.size(0))
+        result.append(x[start:end])
+
+    x_i = torch.cat(result, dim=0)
+    del result
+    torch.cuda.empty_cache()
+
+    return x_i.to(torch.bfloat16)
+
+
 def rope_params(max_seq_len, dim, theta=10000):
     assert dim % 2 == 0
     freqs = torch.outer(
...
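The new apply_rotary_emb_chunk above trades a little speed for lower peak memory by materializing the float32/complex intermediate for only one slice of the sequence at a time. Below is a self-contained sketch of the same chunking idea; the helper name rotary_in_chunks, the shapes, and the chunk size are illustrative assumptions, not lightx2v code:

import torch

def rotary_in_chunks(x, freqs, chunk_size=128):
    # x: (seq_len, heads, head_dim) in bf16; freqs: (seq_len, 1, head_dim // 2) complex64
    out = []
    for start in range(0, x.size(0), chunk_size):
        end = min(start + chunk_size, x.size(0))
        # Only this slice is promoted to float32/complex at any one time.
        xc = torch.view_as_complex(x[start:end].float().reshape(end - start, x.size(1), -1, 2))
        out.append(torch.view_as_real(xc * freqs[start:end]).flatten(2).to(torch.bfloat16))
    return torch.cat(out, dim=0)

seq_len, heads, dim = 512, 8, 64
x = torch.randn(seq_len, heads, dim, dtype=torch.bfloat16)
freqs = torch.polar(torch.ones(seq_len, 1, dim // 2), torch.randn(seq_len, 1, dim // 2))
assert rotary_in_chunks(x, freqs).shape == x.shape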
@@ -34,6 +34,7 @@ class WanModel:
     def __init__(self, model_path, config, device):
         self.model_path = model_path
         self.config = config
+        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
         self.dit_quantized_ckpt = self.config.get("dit_quantized_ckpt", None)
@@ -133,22 +134,7 @@ class WanModel:
                 else:
                     pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)

-        safetensors_pattern = os.path.join(lazy_load_model_path, "block_*.safetensors")
-        safetensors_files = glob.glob(safetensors_pattern)
-
-        if not safetensors_files:
-            raise FileNotFoundError(f"No .safetensors files found in directory: {lazy_load_model_path}")
-
-        for file_path in safetensors_files:
-            with safe_open(file_path, framework="pt") as f:
-                for k in f.keys():
-                    if "modulation" in k:
-                        if f.get_tensor(k).dtype == torch.float:
-                            if use_bf16 or all(s not in k for s in skip_bf16):
-                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(torch.bfloat16).to(self.device)
-                            else:
-                                transformer_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
-
-        return pre_post_weight_dict, transformer_weight_dict
+        return pre_post_weight_dict

     def _init_weights(self, weight_dict=None):
         use_bf16 = GET_DTYPE() == "BF16"
@@ -161,10 +147,7 @@ class WanModel:
             if not self.config.get("lazy_load", False):
                 self.original_weight_dict = self._load_quant_ckpt(use_bf16, skip_bf16)
             else:
-                (
-                    self.original_weight_dict,
-                    self.transformer_weight_dict,
-                ) = self._load_quant_split_ckpt(use_bf16, skip_bf16)
+                self.original_weight_dict = self._load_quant_split_ckpt(use_bf16, skip_bf16)
         else:
             self.original_weight_dict = weight_dict
         # init weights
@@ -174,10 +157,7 @@ class WanModel:
         # load weights
         self.pre_weight.load(self.original_weight_dict)
         self.post_weight.load(self.original_weight_dict)
-        if hasattr(self, "transformer_weight_dict"):
-            self.transformer_weights.load(self.transformer_weight_dict)
-        else:
-            self.transformer_weights.load(self.original_weight_dict)
+        self.transformer_weights.load(self.original_weight_dict)

     def _init_infer(self):
         self.pre_infer = self.pre_infer_class(self.config)
@@ -212,13 +192,21 @@ class WanModel:
         self.scheduler.noise_pred = noise_pred_cond

+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_cond, grid_sizes
+            torch.cuda.empty_cache()
+
         if self.config["enable_cfg"]:
             embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, positive=False)
             x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
             noise_pred_uncond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
-            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (noise_pred_cond - noise_pred_uncond)
+            self.scheduler.noise_pred = noise_pred_uncond + self.config.sample_guide_scale * (self.scheduler.noise_pred - noise_pred_uncond)

         if self.config["cpu_offload"]:
             self.pre_weight.to_cpu()
             self.post_weight.to_cpu()
+
+        if self.clean_cuda_cache:
+            del x, embed, pre_infer_out, noise_pred_uncond, grid_sizes
+            torch.cuda.empty_cache()
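The guidance line changed above because, with clean_cuda_cache enabled, noise_pred_cond is deleted right after the conditional pass, so the unconditional pass reads the stored conditional prediction back from self.scheduler.noise_pred. A toy illustration of the classifier-free guidance combination itself (the shapes and guidance scale are made up; this is not lightx2v code):

import torch

guide_scale = 5.0
noise_pred_cond = torch.randn(4, 16)    # what scheduler.noise_pred holds after the conditional pass
noise_pred_uncond = torch.randn(4, 16)  # result of the negative-prompt pass

# uncond + scale * (cond - uncond): scale > 1 pushes the result toward the conditional prediction.
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)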
@@ -34,13 +34,8 @@ class WanTransformerAttentionBlock(WeightModule):
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
         self.sparge = config.get("sparge", False)
-        self.register_parameter(
-            "modulation",
-            TENSOR_REGISTER["Default"](f"blocks.{self.block_index}.modulation"),
-        )
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
             lazy_load_path = os.path.join(self.config.dit_quantized_ckpt, f"block_{block_index}.safetensors")
             self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
@@ -49,6 +44,14 @@ class WanTransformerAttentionBlock(WeightModule):
         self.compute_phases = WeightModuleList(
             [
+                WanModulation(
+                    block_index,
+                    task,
+                    mm_type,
+                    config,
+                    self.lazy_load,
+                    self.lazy_load_file,
+                ),
                 WanSelfAttention(
                     block_index,
                     task,
@@ -79,6 +82,29 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("compute_phases", self.compute_phases)


+class WanModulation(WeightModule):
+    def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
+        super().__init__()
+        self.block_index = block_index
+        self.mm_type = mm_type
+        self.task = task
+        self.config = config
+        self.quant_method = config["mm_config"].get("quant_method", None)
+        self.sparge = config.get("sparge", False)
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
+
+        self.add_module(
+            "modulation",
+            TENSOR_REGISTER["Default"](
+                f"blocks.{self.block_index}.modulation",
+                self.lazy_load,
+                self.lazy_load_file,
+            ),
+        )
+
+
 class WanSelfAttention(WeightModule):
     def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
@@ -92,7 +118,7 @@ class WanSelfAttention(WeightModule):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file

-        self.register_parameter(
+        self.add_module(
             "norm1",
             LN_WEIGHT_REGISTER["Default"](),
         )
@@ -160,7 +186,7 @@ class WanSelfAttention(WeightModule):
         else:
             self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
         if self.quant_method in ["smoothquant", "awq"]:
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm1_weight",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm1.weight",
@@ -168,7 +194,7 @@ class WanSelfAttention(WeightModule):
                     self.lazy_load_file,
                 ),
             )
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm1_bias",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm1.bias",
@@ -292,7 +318,7 @@ class WanFFN(WeightModule):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file

-        self.register_parameter(
+        self.add_module(
             "norm2",
             LN_WEIGHT_REGISTER["Default"](),
         )
@@ -317,7 +343,7 @@ class WanFFN(WeightModule):
             )
         if self.quant_method in ["smoothquant", "awq"]:
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm2_weight",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm3.weight",
@@ -325,7 +351,7 @@ class WanFFN(WeightModule):
                     self.lazy_load_file,
                 ),
             )
-            self.register_parameter(
+            self.add_module(
                 "smooth_norm2_bias",
                 TENSOR_REGISTER["Default"](
                     f"blocks.{self.block_index}.affine_norm3.bias",
...
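For reference, the two new behaviours in this commit are both read straight off the model config via config.get, so enabling them should only require a fragment like the one below; the rest of the config is assumed and omitted:

config = {
    # aggressively del intermediates and call torch.cuda.empty_cache() between steps
    "clean_cuda_cache": True,
    # use apply_rotary_emb_chunk instead of apply_rotary_emb for q/k rotary embedding
    "rotary_chunk": True,
}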