Unverified Commit 74eeb429 authored by Gu Shiqiao, committed by GitHub

reconstruct disk offload and fix lightx2v_platform bugs (#558)


Co-authored-by: helloyongyang <yongyang1030@163.com>
parent f7cdbcb5
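
The refactor replaces the single `is_offload_buffer` flag on every weight module with an explicit `create_cuda_buffer` / `create_cpu_buffer` pair, so the same block class can be instantiated either as a GPU-side staging buffer or as a pinned-CPU staging buffer for disk offload, and the Qwen-Image pre/post weight containers gain `to_cpu` / `to_cuda` helpers that walk their registered sub-modules. A minimal, self-contained sketch of that idea, assuming a CUDA device (the class below is a toy stand-in, not the repo's `WeightModule`; only the flag and method names come from this diff):

```python
import torch

# Hedged sketch: one weight that can own a pinned-CPU staging copy (disk -> host)
# and a CUDA staging copy (host -> device), mirroring the two constructor flags
# introduced in this commit.
class OffloadableWeight:
    def __init__(self, shape, create_cuda_buffer=False, create_cpu_buffer=False):
        self.weight = torch.empty(shape)  # canonical copy, e.g. lazily read from safetensors
        self.cpu_buffer = torch.empty(shape, pin_memory=True) if create_cpu_buffer else None
        self.cuda_buffer = torch.empty(shape, device="cuda") if create_cuda_buffer else None

    def to_cpu(self, non_blocking=True):
        # Stage the weight into pinned host memory so later H2D copies can be asynchronous.
        if self.cpu_buffer is not None:
            self.cpu_buffer.copy_(self.weight, non_blocking=non_blocking)

    def to_cuda(self, non_blocking=True):
        # Push whichever host copy is available into the GPU staging buffer.
        src = self.cpu_buffer if self.cpu_buffer is not None else self.weight
        if self.cuda_buffer is not None:
            self.cuda_buffer.copy_(src, non_blocking=non_blocking)
```

The container-level `to_cpu` / `to_cuda` methods added to `QwenImagePreWeights` and `QwenImagePostWeights` below simply forward these calls to every registered sub-module.
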
@@ -37,3 +37,13 @@ class QwenImagePostWeights(WeightModule):
                 self.lazy_load_file,
             ),
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
@@ -28,3 +28,13 @@ class QwenImagePreWeights(WeightModule):
         self.add_module(
             "time_text_embed_timestep_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias")
         )
+
+    def to_cpu(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cpu"):
+                module.to_cpu(non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=True):
+        for module in self._modules.values():
+            if module is not None and hasattr(module, "to_cuda"):
+                module.to_cuda(non_blocking=non_blocking)
@@ -15,7 +15,7 @@ class QwenImageTransformerWeights(WeightModule):
         self.mm_type = config.get("dit_quant_scheme", "Default")
         if self.mm_type != "Default":
             assert config.get("dit_quantized") is True
-        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
+        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, False, "transformer_blocks") for i in range(self.blocks_num))
         self.register_offload_buffers(config)
         self.add_module("blocks", blocks)
@@ -23,17 +23,17 @@ class QwenImageTransformerWeights(WeightModule):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2
-                self.offload_block_buffers = WeightModuleList(
-                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)]
+                self.offload_block_cuda_buffers = WeightModuleList(
+                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, False, "transformer_blocks") for i in range(self.offload_blocks_num)]
                 )
-                self.add_module("offload_block_buffers", self.offload_block_buffers)
-                self.offload_phase_buffers = None
+                self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
+                self.offload_phase_cuda_buffers = None
             else:
                 raise NotImplementedError


 class QwenImageTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
+    def __init__(self, block_index, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="transformer_blocks"):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -55,14 +55,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.img_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "img_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         self.attn = QwenImageCrossAttention(
             block_index=block_index,
@@ -70,7 +71,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -78,7 +80,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
         self.add_module(
             "img_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         img_mlp = QwenImageFFN(
             block_index=block_index,
@@ -87,7 +89,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -99,20 +102,21 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "txt_norm1",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         # Text doesn't need separate attention - it's handled by img_attn joint computation
         self.add_module(
             "txt_norm2",
-            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
         )
         txt_mlp = QwenImageFFN(
             block_index=block_index,
@@ -121,7 +125,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
-            is_offload_buffer=is_offload_buffer,
+            create_cuda_buffer=create_cuda_buffer,
+            create_cpu_buffer=create_cpu_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -129,7 +134,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
 class QwenImageCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -146,12 +151,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_q
         self.add_module(
             "norm_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_k
         self.add_module(
             "norm_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # to_q
         self.add_module(
@@ -159,7 +164,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_q.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -170,7 +176,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_k.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -181,7 +188,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_v.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -192,7 +200,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -203,7 +212,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -214,7 +224,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -225,7 +236,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -236,7 +248,8 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -244,12 +257,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_added_q
         self.add_module(
             "norm_added_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # norm_added_k
         self.add_module(
             "norm_added_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
         )
         # attn
         self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())
@@ -266,7 +279,7 @@ class QwenImageCrossAttention(WeightModule):
 class QwenImageFFN(WeightModule):
-    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -281,7 +294,8 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -291,7 +305,8 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
......
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE


 class WanTransformerInferCaching(WanOffloadTransformerInfer):
@@ -56,7 +57,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_even = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even.cuda()).abs().mean() / self.previous_e0_even.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_even += rescale_func(
+                ((modulated_inp - self.previous_e0_even.to(AI_DEVICE)).abs().mean() / self.previous_e0_even.to(AI_DEVICE).abs().mean()).cpu().item()
+            )
             if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                 should_calc = False
             else:
@@ -72,7 +75,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
             self.accumulated_rel_l1_distance_odd = 0
         else:
             rescale_func = np.poly1d(self.coefficients)
-            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.cuda()).abs().mean() / self.previous_e0_odd.cuda().abs().mean()).cpu().item())
+            self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.to(AI_DEVICE)).abs().mean() / self.previous_e0_odd.to(AI_DEVICE).abs().mean()).cpu().item())
             if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
                 should_calc = False
             else:
@@ -149,9 +152,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         if self.scheduler.infer_condition:
-            x.add_(self.previous_residual_even.cuda())
+            x.add_(self.previous_residual_even.to(AI_DEVICE))
         else:
-            x.add_(self.previous_residual_odd.cuda())
+            x.add_(self.previous_residual_odd.to(AI_DEVICE))
         return x

     def clear(self):
@@ -1075,7 +1078,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def infer_using_cache(self, x):
         residual_x = self.residual_cache[self.scheduler.infer_condition]
-        x.add_(residual_x.cuda())
+        x.add_(residual_x.to(AI_DEVICE))
         return x

     def clear(self):
......
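
All of the hard-coded `.cuda()` calls in the caching path now go through `AI_DEVICE`, the device identifier exported by `lightx2v_platform.base.global_var`, so the same code can target non-CUDA backends. A hedged illustration of the substitution, assuming `AI_DEVICE` is a device string such as `"cuda"` (the tensors and polynomial coefficients below are made up):

```python
import numpy as np
import torch

from lightx2v_platform.base.global_var import AI_DEVICE  # assumed to be a device string, e.g. "cuda"

rescale_func = np.poly1d([1.0, 0.0])               # placeholder TeaCache rescale polynomial
modulated_inp = torch.randn(16, device=AI_DEVICE)  # current modulated input on the accelerator
previous_e0 = torch.randn(16)                      # residual cached on CPU between steps

# Old style was previous_e0.cuda(); the device-agnostic form moves it with .to(AI_DEVICE).
rel_l1 = ((modulated_inp - previous_e0.to(AI_DEVICE)).abs().mean() / previous_e0.to(AI_DEVICE).abs().mean()).cpu().item()
accumulated = rescale_func(rel_l1)
```
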
 import torch

-from lightx2v.common.offload.manager import (
-    LazyWeightAsyncStreamManager,
-    WeightAsyncStreamManager,
-)
+from lightx2v.common.offload.manager import WeightAsyncStreamManager
 from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 class WanOffloadTransformerInfer(WanTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
         if self.config.get("cpu_offload", False):
-            if "offload_ratio" in self.config:
-                self.offload_ratio = self.config["offload_ratio"]
-            else:
-                self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_blocks_offload
-                else:
-                    self.infer_func = self.infer_with_blocks_lazy_offload
+                self.infer_func = self.infer_with_blocks_offload
             elif offload_granularity == "phase":
-                if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_phases_offload
-                else:
-                    self.infer_func = self.infer_with_phases_lazy_offload
+                self.infer_func = self.infer_with_phases_offload
             self.phase_params = {
                 "shift_msa": None,
                 "scale_msa": None,
@@ -41,121 +31,54 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             self.infer_func = self.infer_without_offload
         if offload_granularity != "model":
-            if not self.config.get("lazy_load", False):
-                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
-            else:
-                self.offload_manager = LazyWeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=self.offload_ratio,
-                    phases_num=self.phases_num,
-                    num_disk_workers=self.config.get("num_disk_workers", 2),
-                    max_memory=self.config.get("max_memory", 2),
-                    offload_gra=offload_granularity,
-                )
+            self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)

     def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            if block_idx == 0:
+            if self.offload_manager.need_init_first_buffer:
                 self.offload_manager.init_first_buffer(blocks)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            self.offload_manager.prefetch_weights((block_idx + 1) % len(blocks), blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
             self.offload_manager.swap_blocks()
-        return x
-
-    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
-            if self.clean_cuda_cache:
-                del (
-                    self.phase_params["attn_out"],
-                    self.phase_params["y_out"],
-                    self.phase_params["y"],
-                )
-                torch.cuda.empty_cache()
-        if self.clean_cuda_cache:
-            self.clear_offload_params(pre_infer_out)
-        return x
-
-    def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            if block_idx == 0:
-                block = self.offload_manager.pin_memory_buffer.get(block_idx)
-                block.to_cuda()
-                self.offload_manager.cuda_buffers[0] = (block_idx, block)
-            if block_idx < len(blocks) - 1:
-                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
-                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.offload_manager.swap_blocks()
-            if block_idx == len(blocks) - 1:
-                self.offload_manager.pin_memory_buffer.pop_front()
-            self.offload_manager._async_prefetch_block(blocks)
         if self.clean_cuda_cache:
             del (
                 pre_infer_out.embed0,
-                pre_infer_out.freqs,
                 pre_infer_out.context,
             )
-            torch.cuda.empty_cache()
+            torch_device_module.empty_cache()
         return x
-    def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
-        self.offload_manager.prefetch_weights_from_disk(blocks)
+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
-            self.offload_manager._async_prefetch_block(blocks)
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
             if self.clean_cuda_cache:
                 del (
                     self.phase_params["attn_out"],
                     self.phase_params["y_out"],
                     self.phase_params["y"],
                 )
-                torch.cuda.empty_cache()
+                torch_device_module.empty_cache()
         if self.clean_cuda_cache:
             self.clear_offload_params(pre_infer_out)
         return x

-    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out):
         for phase_idx in range(self.phases_num):
-            if block_idx == 0 and phase_idx == 0:
-                if lazy:
-                    obj_key = (block_idx, phase_idx)
-                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
-                    phase.to_cuda()
-                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
-                else:
-                    self.offload_manager.init_first_buffer(blocks)
-            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
-            if not is_last_phase:
-                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                next_phase_idx = (phase_idx + 1) % self.phases_num
-                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            if self.offload_manager.need_init_first_buffer:
+                self.offload_manager.init_first_buffer(blocks)
+            next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
+            next_phase_idx = (phase_idx + 1) % self.phases_num
+            self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
             self.offload_manager.swap_phases()
@@ -176,10 +99,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
         self.phase_params["y_out"] = self.infer_self_attn(
             cur_phase,
-            pre_infer_out.grid_sizes.tuple,
             x,
-            pre_infer_out.seq_lens,
-            pre_infer_out.freqs,
             self.phase_params["shift_msa"],
             self.phase_params["scale_msa"],
         )
@@ -219,7 +139,6 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         )
         del (
             pre_infer_out.embed0,
-            pre_infer_out.freqs,
             pre_infer_out.context,
         )
-        torch.cuda.empty_cache()
+        torch_device_module.empty_cache()
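
With the lazy-load branches folded away, block offload becomes one uniform double-buffered loop: prefetch block `i+1` on a copy stream while block `i` runs on the compute stream, then swap buffers. A self-contained sketch of that pattern, assuming a CUDA device (the manager below is a toy, not `WeightAsyncStreamManager`; only the loop shape and the `getattr(torch, AI_DEVICE)` trick mirror the diff):

```python
import torch

AI_DEVICE = "cuda"                                  # lightx2v_platform resolves this per platform
torch_device_module = getattr(torch, AI_DEVICE)     # e.g. torch.cuda

class ToyStreamManager:
    """Two device buffers plus a copy stream and a compute stream (heavily simplified)."""
    def __init__(self, blocks):
        self.compute_stream = torch_device_module.Stream()
        self.copy_stream = torch_device_module.Stream()
        self.cuda_buffers = [blocks[0].to(AI_DEVICE, non_blocking=True), None]

    def prefetch_weights(self, idx, blocks):
        # Copy the next block to the device on the copy stream while compute runs.
        with torch_device_module.stream(self.copy_stream):
            self.cuda_buffers[1] = blocks[idx].to(AI_DEVICE, non_blocking=True)

    def swap_blocks(self):
        # Wait for the prefetch, then make the prefetched buffer the active one.
        self.copy_stream.synchronize()
        self.cuda_buffers[0], self.cuda_buffers[1] = self.cuda_buffers[1], self.cuda_buffers[0]

blocks = [torch.randn(4, 4).pin_memory() for _ in range(3)]   # stand-ins for transformer blocks
mgr = ToyStreamManager(blocks)
x = torch.randn(4, device=AI_DEVICE)
for block_idx in range(len(blocks)):
    mgr.prefetch_weights((block_idx + 1) % len(blocks), blocks)
    with torch_device_module.stream(mgr.compute_stream):
        x = x @ mgr.cuda_buffers[0]                            # stand-in for infer_block(...)
    mgr.swap_blocks()
```

The `(block_idx + 1) % len(blocks)` wrap-around matches the new loop above, so the first block of the next denoising step is already being staged while the last block of the current step computes.
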
@@ -6,6 +6,7 @@ import torch
 from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
+from lightx2v_platform.base.global_var import AI_DEVICE


 def sinusoidal_embedding_1d(dim, position):
@@ -50,7 +51,7 @@ class WanSFPreInfer(WanPreInfer):
                 rope_params(1024, 2 * (d // 6)),
             ],
             dim=1,
-        ).cuda()
+        ).to(AI_DEVICE)

     def time_embedding(self, weights, embed):
         embed = weights.time_embedding_0.apply(embed)
......
@@ -9,6 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}

     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
         self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
@@ -23,11 +24,11 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
         pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_cuda_buffers, self.vace_offload_phase_cuda_buffers)
         self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
         if hasattr(self, "offload_manager"):
-            self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)
+            self.offload_manager.init_cuda_buffer(self.offload_block_cuda_buffers, self.offload_phase_cuda_buffers)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
......
@@ -47,7 +47,10 @@ class WanModel(CompiledMethodsMixin):
         self.cpu_offload = self.config.get("cpu_offload", False)
         self.offload_granularity = self.config.get("offload_granularity", "block")
         self.model_type = model_type
+        self.remove_keys = []
+        self.lazy_load = self.config.get("lazy_load", False)
+        if self.lazy_load:
+            self.remove_keys.extend(["blocks."])
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
         else:
@@ -146,7 +149,7 @@ class WanModel(CompiledMethodsMixin):
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.config["parallel"]:
+        if self.device.type != "cpu" and dist.is_initialized():
             device = dist.get_rank()
         else:
             device = str(self.device)
@@ -169,6 +172,10 @@ class WanModel(CompiledMethodsMixin):
         else:
             safetensors_files = [safetensors_path]

+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
+
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
@@ -205,6 +212,10 @@ class WanModel(CompiledMethodsMixin):
             safetensors_files = [safetensors_path]
             safetensors_path = os.path.dirname(safetensors_path)

+        if self.lazy_load:
+            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
+            self.lazy_load_path = safetensors_files[0]
+
         weight_dict = {}
         for safetensor_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
@@ -237,28 +248,6 @@ class WanModel(CompiledMethodsMixin):
         return weight_dict

-    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite
-        lazy_load_model_path = self.dit_quantized_ckpt
-        logger.info(f"Loading splited quant model from {lazy_load_model_path}")
-        pre_post_weight_dict = {}
-        safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
-        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
-            for k in f.keys():
-                if f.get_tensor(k).dtype in [
-                    torch.float16,
-                    torch.bfloat16,
-                    torch.float,
-                ]:
-                    if unified_dtype or all(s not in k for s in sensitive_layer):
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
-                    else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
-                else:
-                    pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
-        return pre_post_weight_dict
-
     def _load_gguf_ckpt(self, gguf_path):
         state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device)
         return state_dict
@@ -285,10 +274,7 @@ class WanModel(CompiledMethodsMixin):
             weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
         else:
             # Load quantized weights
-            if not self.config.get("lazy_load", False):
-                weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
-            else:
-                weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
+            weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

         if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
             weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
@@ -302,7 +288,10 @@ class WanModel(CompiledMethodsMixin):
         # Initialize weight containers
         self.pre_weight = self.pre_weight_class(self.config)
-        self.transformer_weights = self.transformer_weight_class(self.config)
+        if self.lazy_load:
+            self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
+        else:
+            self.transformer_weights = self.transformer_weight_class(self.config)

         if not self._should_init_empty_model():
             self._apply_weights()
@@ -383,7 +372,9 @@ class WanModel(CompiledMethodsMixin):
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+            if self.lazy_load:
+                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
......
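
On the model side, lazy loading now requires a single consolidated safetensors file, threads its path into the transformer weight container, and adds a CPU-buffer initialisation pass next to the CUDA one. A hedged configuration sketch (the key names are taken from the diff; the values are illustrative only):

```python
# Illustrative config for the reconstructed disk-offload path (values are examples, not defaults).
config = {
    "cpu_offload": True,
    "offload_granularity": "block",   # or "phase"
    "lazy_load": True,                # stream weights from one consolidated .safetensors file
    "dit_quantized": True,
    "dit_quant_scheme": "Default",
}
# With lazy_load enabled, the loader asserts a single file and remembers its path
# (self.lazy_load_path = safetensors_files[0]), passes it to the weight container
# (transformer_weight_class(config, lazy_load_path)), and initialises both buffer sets:
# init_cuda_buffer(offload_block_cuda_buffers, offload_phase_cuda_buffers) and
# init_cpu_buffer(offload_block_cpu_buffers, offload_phase_cpu_buffers).
```
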
@@ -22,10 +22,15 @@ class WanVaceModel(WanModel):
     def _init_infer(self):
         super()._init_infer()
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
-            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
-            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
-            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
+            self.transformer_infer.offload_block_cuda_buffers = self.transformer_weights.offload_block_cuda_buffers
+            self.transformer_infer.offload_phase_cuda_buffers = self.transformer_weights.offload_phase_cuda_buffers
+            self.transformer_infer.vace_offload_block_cuda_buffers = self.transformer_weights.vace_offload_block_cuda_buffers
+            self.transformer_infer.vace_offload_phase_cuda_buffers = self.transformer_weights.vace_offload_phase_cuda_buffers
+            if self.lazy_load:
+                self.transformer_infer.offload_block_cpu_buffers = self.transformer_weights.offload_block_cpu_buffers
+                self.transformer_infer.offload_phase_cpu_buffers = self.transformer_weights.offload_phase_cpu_buffers
+                self.transformer_infer.vace_offload_block_cpu_buffers = self.transformer_weights.vace_offload_block_cpu_buffers
+                self.transformer_infer.vace_offload_phase_cpu_buffers = self.transformer_weights.vace_offload_phase_cpu_buffers

     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
......
@@ -26,15 +26,19 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
         self._add_animate_fuserblock_to_offload_buffers()

     def _add_animate_fuserblock_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
-            self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
+                self.offload_block_cuda_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+                if self.lazy_load:
+                    self.offload_block_cpu_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
+            self.offload_phase_cuda_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
+            if self.lazy_load:
+                self.offload_phase_cpu_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))


 class WanAnimateFuserBlock(WeightModule):
-    def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
+    def __init__(self, config, block_index, block_prefix, mm_type, create_cuda_buffer=False, create_cpu_buffer=False):
         super().__init__()
         self.config = config
         self.is_post_adapter = True
@@ -53,7 +57,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear1_kv.weight",
                 f"{block_prefix}.{block_index}.linear1_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -65,7 +70,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear1_q.weight",
                 f"{block_prefix}.{block_index}.linear1_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -76,7 +82,8 @@ class WanAnimateFuserBlock(WeightModule):
             MM_WEIGHT_REGISTER[mm_type](
                 f"{block_prefix}.{block_index}.linear2.weight",
                 f"{block_prefix}.{block_index}.linear2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -87,7 +94,8 @@ class WanAnimateFuserBlock(WeightModule):
             "q_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.q_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
@@ -98,7 +106,8 @@ class WanAnimateFuserBlock(WeightModule):
             "k_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.k_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 lazy_load,
                 lazy_load_file,
                 self.is_post_adapter,
......
@@ -19,6 +19,7 @@ class WanAudioTransformerWeights(WanTransformerWeights):
                 self.mm_type,
                 self.config,
                 False,
+                False,
                 self.blocks[i].lazy_load,
                 self.blocks[i].lazy_load_file,
             )
@@ -27,37 +28,66 @@ class WanAudioTransformerWeights(WanTransformerWeights):
             self._add_audio_adapter_ca_to_offload_buffers()

     def _add_audio_adapter_ca_to_offload_buffers(self):
-        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+        if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
             for i in range(self.offload_blocks_num):
-                offload_buffer = self.offload_block_buffers[i]
+                offload_buffer = self.offload_block_cuda_buffers[i]
                 adapter_ca = WanAudioAdapterCA(
                     block_index=i,
                     block_prefix=f"ca",
                     task=self.task,
                     mm_type=self.mm_type,
                     config=self.config,
-                    is_offload_buffer=True,
+                    create_cuda_buffer=True,
+                    create_cpu_buffer=False,
                     lazy_load=offload_buffer.lazy_load,
                     lazy_load_file=offload_buffer.lazy_load_file,
                 )
                 offload_buffer.compute_phases.append(adapter_ca)
+                if self.lazy_load:
+                    offload_buffer = self.offload_block_cpu_buffers[i]
+                    adapter_ca = WanAudioAdapterCA(
+                        block_index=i,
+                        block_prefix=f"ca",
+                        task=self.task,
+                        mm_type=self.mm_type,
+                        config=self.config,
+                        create_cuda_buffer=False,
+                        create_cpu_buffer=True,
+                        lazy_load=offload_buffer.lazy_load,
+                        lazy_load_file=offload_buffer.lazy_load_file,
+                    )
+                    offload_buffer.compute_phases.append(adapter_ca)
-        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+        elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
             adapter_ca = WanAudioAdapterCA(
                 block_index=0,
                 block_prefix=f"ca",
                 task=self.task,
                 mm_type=self.mm_type,
                 config=self.config,
-                is_offload_buffer=True,
+                create_cuda_buffer=True,
+                create_cpu_buffer=False,
                 lazy_load=self.blocks[0].lazy_load,
                 lazy_load_file=self.blocks[0].lazy_load_file,
             )
-            self.offload_phase_buffers.append(adapter_ca)
+            self.offload_phase_cuda_buffers.append(adapter_ca)
+            if self.lazy_load:
+                adapter_ca = WanAudioAdapterCA(
+                    block_index=0,
+                    block_prefix=f"ca",
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    create_cuda_buffer=False,
+                    create_cpu_buffer=True,
+                    lazy_load=self.blocks[0].lazy_load,
+                    lazy_load_file=self.blocks[0].lazy_load_file,
+                )
+                self.offload_phase_cpu_buffers.append(adapter_ca)


 class WanAudioAdapterCA(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -71,7 +101,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_q.weight",
                 f"{block_prefix}.{block_index}.to_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -82,7 +113,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_kv.weight",
                 f"{block_prefix}.{block_index}.to_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -93,7 +125,8 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_out.weight",
                 f"{block_prefix}.{block_index}.to_out.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -104,7 +137,8 @@ class WanAudioAdapterCA(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.norm_kv.weight",
                 f"{block_prefix}.{block_index}.norm_kv.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -119,7 +153,8 @@ class WanAudioAdapterCA(WeightModule):
             "shift_scale_gate",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.shift_scale_gate",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
......
-import os
-from safetensors import safe_open
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanFFN,
@@ -31,9 +27,9 @@ class WanActionTransformerWeights(WeightModule):
         block_list = []
         for i in range(self.blocks_num):
             if i in action_blocks:
-                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
+                block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config))
             else:
-                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
+                block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config))
         self.blocks = WeightModuleList(block_list)
         self.add_module("blocks", self.blocks)
@@ -42,11 +38,6 @@ class WanActionTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def clear(self):
-        for block in self.blocks:
-            for phase in block.compute_phases:
-                phase.clear()
-
     def non_block_weights_to_cuda(self):
         self.norm.to_cuda()
         self.head.to_cuda()
@@ -66,34 +57,16 @@ class WanTransformerActionBlock(WeightModule):
         self.task = task
         self.config = config
         self.quant_method = config.get("quant_method", None)
-        self.lazy_load = self.config.get("lazy_load", False)
-        if self.lazy_load:
-            lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
-        else:
-            self.lazy_load_file = None
+        assert not self.config.get("lazy_load", False)
         self.compute_phases = WeightModuleList(
             [
-                WanSelfAttention(
-                    block_index,
-                    block_prefix,
-                    task,
-                    mm_type,
-                    config,
-                    False,
-                    self.lazy_load,
-                    self.lazy_load_file,
-                ),
+                WanSelfAttention(block_index, block_prefix, task, mm_type, config),
                 WanActionCrossAttention(
                     block_index,
                     block_prefix,
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanActionModule(
                     block_index,
@@ -101,8 +74,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
                 WanFFN(
                     block_index,
@@ -110,9 +81,6 @@ class WanTransformerActionBlock(WeightModule):
                     task,
                     mm_type,
                     config,
-                    False,
-                    self.lazy_load,
-                    self.lazy_load_file,
                 ),
             ]
         )
@@ -121,7 +89,7 @@ class WanTransformerActionBlock(WeightModule):
 class WanActionModule(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -129,9 +97,6 @@ class WanActionModule(WeightModule):
         self.config = config
         self.quant_method = config.get("quant_method", None)

-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
-
         self.attn_rms_type = "self_forcing"

         self.add_module(
@@ -139,8 +104,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
         self.add_module(
@@ -148,8 +111,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight",
                 f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias",
-                self.lazy_load,
-                self.lazy_load_file,
             ),
         )
@@ -158,8 +119,6 @@ class WanActionModule(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight",
                 bias_name=None,
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
             ),
         )
...@@ -168,8 +127,6 @@ class WanActionModule(WeightModule): ...@@ -168,8 +127,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight", f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight",
bias_name=None, bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
), ),
) )
...@@ -180,8 +137,6 @@ class WanActionModule(WeightModule): ...@@ -180,8 +137,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight", f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight",
bias_name=None, bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
), ),
) )
...@@ -191,8 +146,6 @@ class WanActionModule(WeightModule): ...@@ -191,8 +146,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight", f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight",
bias_name=None, bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
), ),
) )
...@@ -201,8 +154,6 @@ class WanActionModule(WeightModule): ...@@ -201,8 +154,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight", f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight",
bias_name=None, bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
), ),
) )
...@@ -211,8 +162,6 @@ class WanActionModule(WeightModule): ...@@ -211,8 +162,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -220,8 +169,6 @@ class WanActionModule(WeightModule): ...@@ -220,8 +169,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -229,22 +176,18 @@ class WanActionModule(WeightModule): ...@@ -229,22 +176,18 @@ class WanActionModule(WeightModule):
LN_WEIGHT_REGISTER[self.mm_type]( LN_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias", f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias",
self.lazy_load,
self.lazy_load_file,
eps=1e-6, eps=1e-6,
), ),
) )
class WanActionCrossAttention(WeightModule): class WanActionCrossAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file): def __init__(self, block_index, block_prefix, task, mm_type, config):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
self.mm_type = mm_type self.mm_type = mm_type
self.task = task self.task = task
self.config = config self.config = config
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.config.get("sf_config", False): if self.config.get("sf_config", False):
self.attn_rms_type = "self_forcing" self.attn_rms_type = "self_forcing"
...@@ -256,8 +199,6 @@ class WanActionCrossAttention(WeightModule): ...@@ -256,8 +199,6 @@ class WanActionCrossAttention(WeightModule):
LN_WEIGHT_REGISTER["Default"]( LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.norm3.weight", f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias", f"{block_prefix}.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -265,8 +206,6 @@ class WanActionCrossAttention(WeightModule): ...@@ -265,8 +206,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.q.weight", f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias", f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -274,8 +213,6 @@ class WanActionCrossAttention(WeightModule): ...@@ -274,8 +213,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k.weight", f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias", f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -283,8 +220,6 @@ class WanActionCrossAttention(WeightModule): ...@@ -283,8 +220,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v.weight", f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias", f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
...@@ -292,24 +227,18 @@ class WanActionCrossAttention(WeightModule): ...@@ -292,24 +227,18 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.o.weight", f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias", f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
"cross_attn_norm_q", "cross_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight", f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module( self.add_module(
"cross_attn_norm_k", "cross_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight", f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
), ),
) )
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]()) self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
import os
from safetensors import safe_open from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
...@@ -13,7 +11,7 @@ from lightx2v.utils.registry_factory import ( ...@@ -13,7 +11,7 @@ from lightx2v.utils.registry_factory import (
class WanTransformerWeights(WeightModule): class WanTransformerWeights(WeightModule):
def __init__(self, config): def __init__(self, config, lazy_load_path=None):
super().__init__() super().__init__()
self.blocks_num = config["num_layers"] self.blocks_num = config["num_layers"]
self.task = config["task"] self.task = config["task"]
...@@ -23,7 +21,27 @@ class WanTransformerWeights(WeightModule): ...@@ -23,7 +21,27 @@ class WanTransformerWeights(WeightModule):
assert config.get("dit_quantized") is True assert config.get("dit_quantized") is True
if config.get("do_mm_calib", False): if config.get("do_mm_calib", False):
self.mm_type = "Calib" self.mm_type = "Calib"
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]) self.lazy_load = self.config.get("lazy_load", False)
if not self.lazy_load:
self.lazy_load_file = None
else:
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
self.blocks = WeightModuleList(
[
WanTransformerAttentionBlock(
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
for i in range(self.blocks_num)
]
)
self.register_offload_buffers(config) self.register_offload_buffers(config)
self.add_module("blocks", self.blocks) self.add_module("blocks", self.blocks)
...@@ -36,35 +54,74 @@ class WanTransformerWeights(WeightModule): ...@@ -36,35 +54,74 @@ class WanTransformerWeights(WeightModule):
if config["cpu_offload"]: if config["cpu_offload"]:
if config["offload_granularity"] == "block": if config["offload_granularity"] == "block":
self.offload_blocks_num = 2 self.offload_blocks_num = 2
self.offload_block_buffers = WeightModuleList( self.offload_block_cuda_buffers = WeightModuleList(
[ [
WanTransformerAttentionBlock( WanTransformerAttentionBlock(
i, block_index=i,
self.task, task=self.task,
self.mm_type, mm_type=self.mm_type,
self.config, config=self.config,
True, create_cuda_buffer=True,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
) )
for i in range(self.offload_blocks_num) for i in range(self.offload_blocks_num)
] ]
) )
self.add_module("offload_block_buffers", self.offload_block_buffers) self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
self.offload_phase_buffers = None self.offload_phase_cuda_buffers = None
if self.lazy_load:
self.offload_blocks_num = 2
self.offload_block_cpu_buffers = WeightModuleList(
[
WanTransformerAttentionBlock(
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
for i in range(self.offload_blocks_num)
]
)
self.add_module("offload_block_cpu_buffers", self.offload_block_cpu_buffers)
self.offload_phase_cpu_buffers = None
elif config["offload_granularity"] == "phase": elif config["offload_granularity"] == "phase":
self.offload_phase_buffers = WanTransformerAttentionBlock( self.offload_phase_cuda_buffers = WanTransformerAttentionBlock(
0, block_index=0,
self.task, task=self.task,
self.mm_type, mm_type=self.mm_type,
self.config, config=self.config,
True, create_cuda_buffer=True,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
).compute_phases ).compute_phases
self.add_module("offload_phase_buffers", self.offload_phase_buffers) self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
self.offload_block_buffers = None self.offload_block_cuda_buffers = None
if self.lazy_load:
def clear(self): self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
for block in self.blocks: block_index=0,
for phase in block.compute_phases: task=self.task,
phase.clear() mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
).compute_phases
self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
self.offload_block_cpu_buffers = None
def non_block_weights_to_cuda(self): def non_block_weights_to_cuda(self):
self.norm.to_cuda() self.norm.to_cuda()
...@@ -84,23 +141,23 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -84,23 +141,23 @@ class WanTransformerAttentionBlock(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer=False, create_cuda_buffer=False,
create_cpu_buffer=False,
block_prefix="blocks", block_prefix="blocks",
lazy_load=False,
lazy_load_file=None,
): ):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
self.mm_type = mm_type self.mm_type = mm_type
self.task = task self.task = task
self.config = config self.config = config
self.is_offload_buffer = is_offload_buffer self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.quant_method = config.get("quant_method", None) self.quant_method = config.get("quant_method", None)
self.lazy_load = self.config.get("lazy_load", False) self.lazy_load = lazy_load
if self.lazy_load: self.lazy_load_file = lazy_load_file
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
self.compute_phases = WeightModuleList( self.compute_phases = WeightModuleList(
[ [
...@@ -110,7 +167,8 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -110,7 +167,8 @@ class WanTransformerAttentionBlock(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -120,7 +178,8 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -120,7 +178,8 @@ class WanTransformerAttentionBlock(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -130,7 +189,8 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -130,7 +189,8 @@ class WanTransformerAttentionBlock(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -148,9 +208,10 @@ class WanSelfAttention(WeightModule): ...@@ -148,9 +208,10 @@ class WanSelfAttention(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer=False,
lazy_load, create_cpu_buffer=False,
lazy_load_file, lazy_load=False,
lazy_load_file=None,
): ):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
...@@ -171,7 +232,8 @@ class WanSelfAttention(WeightModule): ...@@ -171,7 +232,8 @@ class WanSelfAttention(WeightModule):
"modulation", "modulation",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.modulation", f"{block_prefix}.{self.block_index}.modulation",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -187,7 +249,8 @@ class WanSelfAttention(WeightModule): ...@@ -187,7 +249,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.q.weight", f"{block_prefix}.{self.block_index}.self_attn.q.weight",
f"{block_prefix}.{self.block_index}.self_attn.q.bias", f"{block_prefix}.{self.block_index}.self_attn.q.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -198,7 +261,8 @@ class WanSelfAttention(WeightModule): ...@@ -198,7 +261,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.k.weight", f"{block_prefix}.{self.block_index}.self_attn.k.weight",
f"{block_prefix}.{self.block_index}.self_attn.k.bias", f"{block_prefix}.{self.block_index}.self_attn.k.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -208,7 +272,8 @@ class WanSelfAttention(WeightModule): ...@@ -208,7 +272,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.v.weight", f"{block_prefix}.{self.block_index}.self_attn.v.weight",
f"{block_prefix}.{self.block_index}.self_attn.v.bias", f"{block_prefix}.{self.block_index}.self_attn.v.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -218,7 +283,8 @@ class WanSelfAttention(WeightModule): ...@@ -218,7 +283,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.o.weight", f"{block_prefix}.{self.block_index}.self_attn.o.weight",
f"{block_prefix}.{self.block_index}.self_attn.o.bias", f"{block_prefix}.{self.block_index}.self_attn.o.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -227,7 +293,8 @@ class WanSelfAttention(WeightModule): ...@@ -227,7 +293,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_q", "self_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight", f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -236,7 +303,8 @@ class WanSelfAttention(WeightModule): ...@@ -236,7 +303,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_k", "self_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight", f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -278,7 +346,8 @@ class WanSelfAttention(WeightModule): ...@@ -278,7 +346,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_weight", "smooth_norm1_weight",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm1.weight", f"{block_prefix}.{self.block_index}.affine_norm1.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -287,7 +356,8 @@ class WanSelfAttention(WeightModule): ...@@ -287,7 +356,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_bias", "smooth_norm1_bias",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm1.bias", f"{block_prefix}.{self.block_index}.affine_norm1.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -302,9 +372,10 @@ class WanCrossAttention(WeightModule): ...@@ -302,9 +372,10 @@ class WanCrossAttention(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer=False,
lazy_load, create_cpu_buffer=False,
lazy_load_file, lazy_load=False,
lazy_load_file=None,
): ):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
...@@ -324,7 +395,8 @@ class WanCrossAttention(WeightModule): ...@@ -324,7 +395,8 @@ class WanCrossAttention(WeightModule):
LN_WEIGHT_REGISTER["Default"]( LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.norm3.weight", f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias", f"{block_prefix}.{self.block_index}.norm3.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -334,7 +406,8 @@ class WanCrossAttention(WeightModule): ...@@ -334,7 +406,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.q.weight", f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias", f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -344,7 +417,8 @@ class WanCrossAttention(WeightModule): ...@@ -344,7 +417,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k.weight", f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias", f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -354,7 +428,8 @@ class WanCrossAttention(WeightModule): ...@@ -354,7 +428,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v.weight", f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias", f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -364,7 +439,8 @@ class WanCrossAttention(WeightModule): ...@@ -364,7 +439,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.o.weight", f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias", f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -373,7 +449,8 @@ class WanCrossAttention(WeightModule): ...@@ -373,7 +449,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_q", "cross_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight", f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -382,7 +459,8 @@ class WanCrossAttention(WeightModule): ...@@ -382,7 +459,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k", "cross_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight", f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -395,7 +473,8 @@ class WanCrossAttention(WeightModule): ...@@ -395,7 +473,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight", f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias", f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -405,7 +484,8 @@ class WanCrossAttention(WeightModule): ...@@ -405,7 +484,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight", f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias", f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -414,7 +494,8 @@ class WanCrossAttention(WeightModule): ...@@ -414,7 +494,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k_img", "cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER[self.attn_rms_type]( RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight", f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -430,9 +511,10 @@ class WanFFN(WeightModule): ...@@ -430,9 +511,10 @@ class WanFFN(WeightModule):
task, task,
mm_type, mm_type,
config, config,
is_offload_buffer, create_cuda_buffer=False,
lazy_load, create_cpu_buffer=False,
lazy_load_file, lazy_load=False,
lazy_load_file=None,
): ):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
...@@ -453,7 +535,8 @@ class WanFFN(WeightModule): ...@@ -453,7 +535,8 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.0.weight", f"{block_prefix}.{self.block_index}.ffn.0.weight",
f"{block_prefix}.{self.block_index}.ffn.0.bias", f"{block_prefix}.{self.block_index}.ffn.0.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -463,7 +546,8 @@ class WanFFN(WeightModule): ...@@ -463,7 +546,8 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.2.weight", f"{block_prefix}.{self.block_index}.ffn.2.weight",
f"{block_prefix}.{self.block_index}.ffn.2.bias", f"{block_prefix}.{self.block_index}.ffn.2.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -474,7 +558,8 @@ class WanFFN(WeightModule): ...@@ -474,7 +558,8 @@ class WanFFN(WeightModule):
"smooth_norm2_weight", "smooth_norm2_weight",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.weight", f"{block_prefix}.{self.block_index}.affine_norm3.weight",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -483,7 +568,8 @@ class WanFFN(WeightModule): ...@@ -483,7 +568,8 @@ class WanFFN(WeightModule):
"smooth_norm2_bias", "smooth_norm2_bias",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.bias", f"{block_prefix}.{self.block_index}.affine_norm3.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
......
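In the rewritten WanTransformerWeights above, the checkpoint is opened once with safe_open and the resulting handle is threaded into every block and phase, while the old single is_offload_buffer flag becomes two flags (create_cuda_buffer, create_cpu_buffer) so a weight object can be built as a reusable GPU buffer, a pinned CPU staging buffer, or an ordinary weight. A minimal sketch of a weight holder honoring these arguments follows; the class name, the zero-sized placeholder buffers, and the load_lazy helper are illustrative only, not code from this diff.

```python
import torch


class OffloadableWeight:
    """Sketch of a weight that can be a plain weight, a reusable GPU buffer,
    or a pinned CPU staging buffer, and that reads lazily from a shared
    safetensors handle (the lazy_load_file passed down by the parent module)."""

    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False,
                 lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file  # one safe_open handle shared by all blocks
        self.tensor = None
        if create_cuda_buffer:
            # Placeholder GPU buffer; the offload loop copies real block weights into it.
            self.tensor = torch.empty(0, device="cuda")
        elif create_cpu_buffer:
            # Pinned host buffer used as the disk -> CPU -> GPU staging area.
            self.tensor = torch.empty(0, pin_memory=True)

    def load_lazy(self):
        # Disk-offload path: pull the tensor out of the shared safetensors handle on demand.
        assert self.lazy_load and self.lazy_load_file is not None
        return self.lazy_load_file.get_tensor(self.weight_name)


# Mirroring the diff: the parent module receives lazy_load_path, calls
# safe_open(lazy_load_path, framework="pt", device="cpu") once, and hands the
# handle to every block instead of each block re-opening block_{i}.safetensors.
```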
...@@ -15,7 +15,7 @@ class WanVaceTransformerWeights(WanTransformerWeights): ...@@ -15,7 +15,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
self.patch_size = (1, 2, 2) self.patch_size = (1, 2, 2)
self.register_offload_buffers(config) self.register_offload_buffers(config)
self.vace_blocks = WeightModuleList( self.vace_blocks = WeightModuleList(
[WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))] [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
) )
self.add_module("vace_blocks", self.vace_blocks) self.add_module("vace_blocks", self.vace_blocks)
self.add_module( self.add_module(
...@@ -27,23 +27,17 @@ class WanVaceTransformerWeights(WanTransformerWeights): ...@@ -27,23 +27,17 @@ class WanVaceTransformerWeights(WanTransformerWeights):
super().register_offload_buffers(config) super().register_offload_buffers(config)
if config["cpu_offload"]: if config["cpu_offload"]:
if config["offload_granularity"] == "block": if config["offload_granularity"] == "block":
self.vace_offload_block_buffers = WeightModuleList( self.vace_offload_block_cuda_buffers = WeightModuleList(
[ [
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"), WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"), WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
] ]
) )
self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers) self.add_module("vace_offload_block_cuda_buffers", self.vace_offload_block_cuda_buffers)
self.vace_offload_phase_buffers = None self.vace_offload_phase_cuda_buffers = None
elif config["offload_granularity"] == "phase": elif config["offload_granularity"] == "phase":
raise NotImplementedError raise NotImplementedError
def clear(self):
super().clear()
for vace_block in self.vace_blocks:
for vace_phase in vace_block.compute_phases:
vace_phase.clear()
def non_block_weights_to_cuda(self): def non_block_weights_to_cuda(self):
super().non_block_weights_to_cuda() super().non_block_weights_to_cuda()
self.vace_patch_embedding.to_cuda() self.vace_patch_embedding.to_cuda()
...@@ -54,15 +48,16 @@ class WanVaceTransformerWeights(WanTransformerWeights): ...@@ -54,15 +48,16 @@ class WanVaceTransformerWeights(WanTransformerWeights):
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock): class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix): def __init__(self, base_block_idx, block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix):
super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix) super().__init__(block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix)
if base_block_idx == 0: if base_block_idx == 0:
self.compute_phases[0].add_module( self.compute_phases[0].add_module(
"before_proj", "before_proj",
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.before_proj.weight", f"{block_prefix}.{self.block_index}.before_proj.weight",
f"{block_prefix}.{self.block_index}.before_proj.bias", f"{block_prefix}.{self.block_index}.before_proj.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -73,7 +68,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock): ...@@ -73,7 +68,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.after_proj.weight", f"{block_prefix}.{self.block_index}.after_proj.weight",
f"{block_prefix}.{self.block_index}.after_proj.bias", f"{block_prefix}.{self.block_index}.after_proj.bias",
is_offload_buffer, create_cuda_buffer,
create_cpu_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
......
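Both the base and the VACE transformer weights now keep two pre-built blocks on the GPU (the *_cuda_buffers lists) and, when lazy_load is enabled, two more staging blocks on the CPU (the *_cpu_buffers lists). A plausible ping-pong loop over the real blocks is sketched below; the load_from_disk and copy_from helpers are hypothetical stand-ins for whatever the offload manager actually calls.

```python
def run_blocks_with_offload(blocks, cuda_buffers, cpu_buffers, lazy_load, forward_one):
    """Illustrative double-buffered block loop. Each block is staged into one of the
    two reusable buffer slots and executed from there; with a dedicated copy stream,
    staging of block i+1 could overlap the compute of block i (overlap not shown)."""
    num_slots = len(cuda_buffers)  # 2 in the diff above
    for i, block in enumerate(blocks):
        slot = i % num_slots
        if lazy_load and cpu_buffers is not None:
            # Disk-offload path: read block i's tensors from safetensors into the
            # CPU staging slot, then copy that slot into the GPU slot.
            cpu_buffers[slot].load_from_disk(block)                              # hypothetical helper
            cuda_buffers[slot].copy_from(cpu_buffers[slot], non_blocking=True)   # hypothetical helper
        else:
            # CPU-offload path: the block's weights already live in host RAM.
            cuda_buffers[slot].copy_from(block, non_blocking=True)               # hypothetical helper
        forward_one(cuda_buffers[slot])
```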
...@@ -41,7 +41,8 @@ class DefaultRunner(BaseRunner): ...@@ -41,7 +41,8 @@ class DefaultRunner(BaseRunner):
self.load_model() self.load_model()
elif self.config.get("lazy_load", False): elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False) assert self.config.get("cpu_offload", False)
self.model.set_scheduler(self.scheduler) # set scheduler to model if hasattr(self, "model"):
self.model.set_scheduler(self.scheduler) # set scheduler to model
if self.config["task"] == "i2v": if self.config["task"] == "i2v":
self.run_input_encoder = self._run_input_encoder_local_i2v self.run_input_encoder = self._run_input_encoder_local_i2v
elif self.config["task"] == "flf2v": elif self.config["task"] == "flf2v":
...@@ -184,11 +185,6 @@ class DefaultRunner(BaseRunner): ...@@ -184,11 +185,6 @@ class DefaultRunner(BaseRunner):
del self.inputs del self.inputs
self.input_info = None self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear()
if hasattr(self.model.transformer_weights, "clear"):
self.model.transformer_weights.clear()
self.model.pre_weight.clear()
del self.model del self.model
if self.config.get("do_mm_calib", False): if self.config.get("do_mm_calib", False):
calib_path = os.path.join(os.getcwd(), "calib.pt") calib_path = os.path.join(os.getcwd(), "calib.pt")
...@@ -279,6 +275,7 @@ class DefaultRunner(BaseRunner): ...@@ -279,6 +275,7 @@ class DefaultRunner(BaseRunner):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.model.set_scheduler(self.scheduler)
self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]: if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
......
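With lazy_load (or unload_modules) enabled, the runner now drops the transformer after each generation and rebuilds it, re-attaching the scheduler, at the start of the next request, rather than clearing weights_stream_mgr and per-phase buffers by hand. A rough sketch of that per-request flow, assuming a hypothetical generate_one driver and run_pipeline step not shown in the diff:

```python
def generate_one(runner, seed, latent_shape, image_encoder_output):
    """Illustrative per-request flow when lazy_load / unload_modules is enabled."""
    if runner.config.get("lazy_load", False) or runner.config.get("unload_modules", False):
        # The transformer was deleted at the end of the previous request, so rebuild
        # it and re-attach the scheduler before preparing this run.
        runner.model = runner.load_transformer()
        runner.model.set_scheduler(runner.scheduler)

    runner.model.scheduler.prepare(
        seed=seed,
        latent_shape=latent_shape,
        image_encoder_output=image_encoder_output,
    )
    video = runner.run_pipeline()  # hypothetical: whatever drives the denoising loop

    if runner.config.get("lazy_load", False) or runner.config.get("unload_modules", False):
        # No manual buffer clearing anymore: deleting the model is enough because
        # the buffers are owned by the module tree.
        del runner.model
    return video
```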
...@@ -24,6 +24,7 @@ from lightx2v.utils.envs import * ...@@ -24,6 +24,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
from lightx2v_platform.base.global_var import AI_DEVICE
@RUNNER_REGISTER("wan2.2_animate") @RUNNER_REGISTER("wan2.2_animate")
...@@ -182,7 +183,7 @@ class WanAnimateRunner(WanRunner): ...@@ -182,7 +183,7 @@ class WanAnimateRunner(WanRunner):
], ],
dim=1, dim=1,
) )
.cuda() .to(AI_DEVICE)
.unsqueeze(0) .unsqueeze(0)
) )
mask_pixel_values = 1 - mask_pixel_values mask_pixel_values = 1 - mask_pixel_values
...@@ -210,7 +211,7 @@ class WanAnimateRunner(WanRunner): ...@@ -210,7 +211,7 @@ class WanAnimateRunner(WanRunner):
], ],
dim=1, dim=1,
) )
.cuda() .to(AI_DEVICE)
.unsqueeze(0) .unsqueeze(0)
) )
msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len) msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
...@@ -330,7 +331,7 @@ class WanAnimateRunner(WanRunner): ...@@ -330,7 +331,7 @@ class WanAnimateRunner(WanRunner):
dtype=GET_DTYPE(), dtype=GET_DTYPE(),
) # c t h w ) # c t h w
else: else:
refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().cuda() # c t h w refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().to(AI_DEVICE) # c t h w
bg_pixel_values, mask_pixel_values = None, None bg_pixel_values, mask_pixel_values = None, None
if self.config["replace_flag"] if "replace_flag" in self.config else False: if self.config["replace_flag"] if "replace_flag" in self.config else False:
...@@ -408,8 +409,8 @@ class WanAnimateRunner(WanRunner): ...@@ -408,8 +409,8 @@ class WanAnimateRunner(WanRunner):
return model return model
def load_encoders(self): def load_encoders(self):
motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda() motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda() face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.") motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.") face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
motion_encoder.load_state_dict(motion_weight_dict) motion_encoder.load_state_dict(motion_weight_dict)
......
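The runner changes above replace hard-coded .cuda() calls with .to(AI_DEVICE), where AI_DEVICE is imported from lightx2v_platform.base.global_var. A minimal sketch of the idea, assuming AI_DEVICE is a process-wide torch.device resolved once at import time (this resolution logic is an assumption, not the platform module's actual code):

```python
import torch

# Assumed shape of lightx2v_platform.base.global_var.AI_DEVICE: a single device
# handle resolved once, so model and runner code never hard-code "cuda".
if torch.cuda.is_available():
    AI_DEVICE = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    AI_DEVICE = torch.device("mps")
else:
    AI_DEVICE = torch.device("cpu")

# Call sites then read identically on every backend:
mask = torch.zeros(1, 3, 64, 64).to(AI_DEVICE)
```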
...@@ -435,7 +435,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -435,7 +435,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def process_single_mask(self, mask_file): def process_single_mask(self, mask_file):
mask_img = load_image(mask_file) mask_img = load_image(mask_file)
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda() mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
if mask_img.shape[1] == 3: # If it is an RGB three-channel image if mask_img.shape[1] == 3: # If it is an RGB three-channel image
mask_img = mask_img[:, :1] # Only take the first channel mask_img = mask_img[:, :1] # Only take the first channel
......
...@@ -13,6 +13,7 @@ from lightx2v.server.metrics import monitor_cli ...@@ -13,6 +13,7 @@ from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import * from lightx2v.utils.envs import *
from lightx2v.utils.profiler import * from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class VAEWrapper: class VAEWrapper:
...@@ -90,8 +91,8 @@ def get_current_action(mode="universal"): ...@@ -90,8 +91,8 @@ def get_current_action(mode="universal"):
flag = 1 flag = 1
except Exception as e: except Exception as e:
pass pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda() mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).to(AI_DEVICE)
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda() keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
elif mode == "gta_drive": elif mode == "gta_drive":
print() print()
print("-" * 30) print("-" * 30)
...@@ -118,8 +119,8 @@ def get_current_action(mode="universal"): ...@@ -118,8 +119,8 @@ def get_current_action(mode="universal"):
flag = 1 flag = 1
except Exception as e: except Exception as e:
pass pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda() mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).to(AI_DEVICE)
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda() keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).to(AI_DEVICE)
elif mode == "templerun": elif mode == "templerun":
print() print()
print("-" * 30) print("-" * 30)
...@@ -142,7 +143,7 @@ def get_current_action(mode="universal"): ...@@ -142,7 +143,7 @@ def get_current_action(mode="universal"):
flag = 1 flag = 1
except Exception as e: except Exception as e:
pass pass
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda() keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
if mode != "templerun": if mode != "templerun":
return {"mouse": mouse_cond, "keyboard": keyboard_cond} return {"mouse": mouse_cond, "keyboard": keyboard_cond}
......
...@@ -164,7 +164,7 @@ class WanRunner(DefaultRunner): ...@@ -164,7 +164,7 @@ class WanRunner(DefaultRunner):
if vae_offload: if vae_offload:
vae_device = torch.device("cpu") vae_device = torch.device("cpu")
else: else:
vae_device = torch.device(self.init_device) vae_device = torch.device(AI_DEVICE)
vae_config = { vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name), "vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
...@@ -178,7 +178,7 @@ class WanRunner(DefaultRunner): ...@@ -178,7 +178,7 @@ class WanRunner(DefaultRunner):
} }
if self.config.get("use_tae", False): if self.config.get("use_tae", False):
tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name) tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda") vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE)
else: else:
vae_decoder = self.vae_cls(**vae_config) vae_decoder = self.vae_cls(**vae_config)
return vae_decoder return vae_decoder
......
...@@ -2,6 +2,8 @@ from typing import List, Tuple, Union ...@@ -2,6 +2,8 @@ from typing import List, Tuple, Union
import torch import torch
from lightx2v_platform.base.global_var import AI_DEVICE
def _to_tuple(x, dim=2): def _to_tuple(x, dim=2):
if isinstance(x, int): if isinstance(x, int):
......