Unverified commit 51be3ad2, authored by gushiqiao and committed by GitHub

[Fix] remove d2h of cpu-offload infer (#476)

parent 2559b3e7
@@ -10,11 +10,6 @@ class QwenImagePostWeights(WeightModule):
         super().__init__()
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
         self.lazy_load = self.config.get("lazy_load", False)
         if self.lazy_load:
             assert NotImplementedError
@@ -23,7 +18,7 @@ class QwenImagePostWeights(WeightModule):
         # norm_out
         self.add_module(
             "norm_out_linear",
-            MM_WEIGHT_REGISTER[self.mm_type](
+            MM_WEIGHT_REGISTER["Default"](
                 "norm_out.linear.weight",
                 "norm_out.linear.bias",
                 self.lazy_load,
@@ -35,7 +30,7 @@ class QwenImagePostWeights(WeightModule):
         # proj_out
         self.add_module(
             "proj_out_linear",
-            MM_WEIGHT_REGISTER[self.mm_type](
+            MM_WEIGHT_REGISTER["Default"](
                 "proj_out.weight",
                 "proj_out.bias",
                 self.lazy_load,
...
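For readers unfamiliar with the pattern used throughout this commit: MM_WEIGHT_REGISTER and its siblings are string-keyed registries, and the post-processing weights above are now pinned to the "Default" entry instead of a config-derived mm_type. A minimal, self-contained sketch of such a registry is below; the class names, the decorator, and the quantized key are illustrative stand-ins, not the actual lightx2v API.

```python
# Minimal sketch of a string-keyed weight registry (illustrative only;
# the real MM_WEIGHT_REGISTER in lightx2v may differ in detail).
class Register(dict):
    def register(self, key):
        def _wrap(cls):
            self[key] = cls
            return cls
        return _wrap


MM_WEIGHT_REGISTER = Register()


@MM_WEIGHT_REGISTER.register("Default")
class DefaultMMWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name


@MM_WEIGHT_REGISTER.register("fp8-per-channel")  # hypothetical quantized entry
class Fp8MMWeight(DefaultMMWeight):
    pass


# The post weights always take the unquantized path after this commit:
norm_out = MM_WEIGHT_REGISTER["Default"]("norm_out.linear.weight", "norm_out.linear.bias")
```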
@@ -12,17 +12,28 @@ class QwenImageTransformerWeights(WeightModule):
         self.blocks_num = config["num_layers"]
         self.task = config["task"]
         self.config = config
-        if config["do_mm_calib"]:
-            self.mm_type = "Calib"
-        else:
-            self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
-        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "transformer_blocks") for i in range(self.blocks_num))
+        self.mm_type = config.get("dit_quant_scheme", "Default")
+        if self.mm_type != "Default":
+            assert config.get("dit_quantized") is True
+        blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
+        self.register_offload_buffers(config)
         self.add_module("blocks", blocks)

+    def register_offload_buffers(self, config):
+        if config["cpu_offload"]:
+            if config["offload_granularity"] == "block":
+                self.offload_blocks_num = 2
+                self.offload_block_buffers = WeightModuleList(
+                    [QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)]
+                )
+                self.add_module("offload_block_buffers", self.offload_block_buffers)
+                self.offload_phase_buffers = None
+            else:
+                raise NotImplementedError
+

 class QwenImageTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, block_prefix="transformer_blocks"):
+    def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
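The new register_offload_buffers pre-builds a fixed pool of two extra transformer blocks whose tensors stay resident on the GPU and get reused for every real block during CPU-offload inference, which is what lets the commit drop the device-to-host copies. A toy sketch of that double-buffer idea follows; Block and Buffer are made-up stand-ins, not lightx2v classes.

```python
# Toy sketch of the double-buffer idea behind offload_block_buffers
# (Block/Buffer are stand-ins; real buffers hold CUDA tensors).
from dataclasses import dataclass, field


@dataclass
class Block:
    idx: int
    cpu_weights: list = field(default_factory=lambda: [0.0] * 4)


@dataclass
class Buffer:
    gpu_weights: list = field(default_factory=lambda: [0.0] * 4)

    def load(self, block: Block):
        # In the real code this is a host-to-device copy into a pre-allocated
        # CUDA tensor, so no new device memory is allocated per block and the
        # weights never need to be copied back to the host.
        self.gpu_weights[:] = block.cpu_weights


blocks = [Block(i, [float(i)] * 4) for i in range(6)]
buffers = [Buffer(), Buffer()]  # offload_blocks_num == 2

for i, block in enumerate(blocks):
    buf = buffers[i % 2]  # compute on one buffer while the other one is being refilled
    buf.load(block)
    # ... run the block's forward pass against buf.gpu_weights ...
```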
@@ -44,22 +55,30 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.img_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.img_mod.1.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "img_norm1",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         self.attn = QwenImageCrossAttention(
-            block_index=block_index, block_prefix="transformer_blocks", task=config["task"], mm_type=mm_type, config=config, lazy_load=self.lazy_load, lazy_load_file=self.lazy_load_file
+            block_index=block_index,
+            block_prefix="transformer_blocks",
+            task=config["task"],
+            mm_type=mm_type,
+            config=config,
+            is_offload_buffer=is_offload_buffer,
+            lazy_load=self.lazy_load,
+            lazy_load_file=self.lazy_load_file,
         )
         self.add_module("attn", self.attn)
         self.add_module(
             "img_norm2",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         img_mlp = QwenImageFFN(
             block_index=block_index,
@@ -68,6 +87,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
+            is_offload_buffer=is_offload_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
@@ -79,19 +99,20 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
                 f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
         )
         self.add_module(
             "txt_norm1",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         # Text doesn't need separate attention - it's handled by img_attn joint computation
         self.add_module(
             "txt_norm2",
-            LN_WEIGHT_REGISTER["Default"](eps=1e-6),
+            LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
         )
         txt_mlp = QwenImageFFN(
             block_index=block_index,
@@ -100,39 +121,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
             task=config["task"],
             mm_type=mm_type,
             config=config,
+            is_offload_buffer=is_offload_buffer,
             lazy_load=self.lazy_load,
             lazy_load_file=self.lazy_load_file,
         )
         self.add_module("txt_mlp", txt_mlp)
-        self.cpu_offload = config["cpu_offload"]
-        if self.cpu_offload:
-            offload_granularity = config.get("offload_granularity", "block")
-            if offload_granularity == "phase":
-                phase1_dict = {
-                    "img_mod": self.img_mod,
-                    "txt_mod": self.txt_mod,
-                    "img_norm1": self.img_norm1,
-                    "txt_norm1": self.txt_norm1,
-                }
-                phase2_dict = {"attn": self.attn}
-                phase3_dict = {
-                    "img_norm2": self.img_norm2,
-                    "img_mlp": self.img_mlp,
-                    "txt_norm2": self.txt_norm2,
-                    "txt_mlp": self.txt_mlp,
-                }
-                compute_phases = [
-                    ComputePhase(phase1_dict),
-                    ComputePhase(phase2_dict),
-                    ComputePhase(phase3_dict),
-                ]
-                self.add_module("compute_phases", compute_phases)


 class QwenImageCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -149,12 +146,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_q
         self.add_module(
             "norm_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
         )
         # norm_k
         self.add_module(
             "norm_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
         )
         # to_q
         self.add_module(
@@ -162,6 +159,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_q.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_q.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -172,6 +170,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_k.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_k.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -182,6 +181,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_v.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_v.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -192,6 +192,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -202,6 +203,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -212,6 +214,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.weight",
                 f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -222,6 +225,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_out.0.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -232,6 +236,7 @@ class QwenImageCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.weight",
                 f"{block_prefix}.{self.block_index}.attn.to_add_out.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -239,12 +244,12 @@ class QwenImageCrossAttention(WeightModule):
         # norm_added_q
         self.add_module(
             "norm_added_q",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
         )
         # norm_added_k
         self.add_module(
             "norm_added_k",
-            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight"),
+            RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
         )
         # attn
         self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())
@@ -261,7 +266,7 @@ class QwenImageCrossAttention(WeightModule):
 class QwenImageFFN(WeightModule):
-    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -276,6 +281,7 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -285,6 +291,7 @@ class QwenImageFFN(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.weight",
                 f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -299,20 +306,3 @@ class QwenImageFFN(WeightModule):
         for module in self._modules.values():
             if module is not None and hasattr(module, "to_cuda"):
                 module.to_cuda(non_blocking=non_blocking)
-
-
-class ComputePhase(WeightModule):
-    def __init__(self, sub_module_dict):
-        super().__init__()
-        for k, v in sub_module_dict.items():
-            self.add_module(k, v)
-
-    def to_cpu(self, non_blocking=True):
-        for module in self._modules.values():
-            if module is not None and hasattr(module, "to_cpu"):
-                module.to_cpu(non_blocking=non_blocking)
-
-    def to_cuda(self, non_blocking=True):
-        for module in self._modules.values():
-            if module is not None and hasattr(module, "to_cuda"):
-                module.to_cuda(non_blocking=non_blocking)
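The ComputePhase class removed above was a thin container that only forwarded to_cpu()/to_cuda() to a dict of submodules; with phases now pre-built as persistent offload buffers, the dedicated wrapper in this file presumably became redundant. A minimal stand-in of that container pattern is kept here for reference; Container is a made-up name, not the lightx2v WeightModule.

```python
# Sketch of the container pattern the removed ComputePhase implemented:
# recursively forward to_cpu()/to_cuda() to every registered child.
class Container:
    def __init__(self):
        self._modules = {}

    def add_module(self, name, module):
        self._modules[name] = module
        setattr(self, name, module)

    def to_cuda(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cuda"):
                module.to_cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=True):
        for module in self._modules.values():
            if module is not None and hasattr(module, "to_cpu"):
                module.to_cpu(non_blocking=non_blocking)
```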
@@ -10,9 +10,45 @@ class WanAnimateTransformerInfer(WanOffloadTransformerInfer):
         self.has_post_adapter = True
         self.phases_num = 4

+    def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
+            self.block_idx = block_idx
+            if block_idx == 0:
+                self.offload_manager.init_first_buffer(blocks, block_idx // 5)
+            if block_idx < len(blocks) - 1:
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks, (block_idx + 1) // 5)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
+            self.offload_manager.swap_blocks()
+        return x
+
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+        for phase_idx in range(self.phases_num):
+            if block_idx == 0 and phase_idx == 0:
+                if lazy:
+                    obj_key = (block_idx, phase_idx)
+                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
+                    phase.to_cuda()
+                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
+                else:
+                    self.offload_manager.init_first_buffer(blocks, block_idx // 5)
+            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
+            if not is_last_phase:
+                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                next_phase_idx = (phase_idx + 1) % self.phases_num
+                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks, (block_idx + 1) // 5)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
+            self.offload_manager.swap_phases()
+        return x
+
     @torch.no_grad()
     def infer_post_adapter(self, phase, x, pre_infer_out):
-        if phase.is_empty():
+        if phase.is_empty() or phase.linear1_kv.weight is None:
             return x
         T = pre_infer_out.adapter_args["motion_vec"].shape[0]
         x_motion = phase.pre_norm_motion.apply(pre_infer_out.adapter_args["motion_vec"])
...
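infer_with_blocks_offload above overlaps weight transfer with compute: while block i runs on a dedicated compute stream, the weights for block i+1 are copied into the spare GPU buffer. A self-contained PyTorch sketch of that two-stream, double-buffered pattern is below; run_blocks, its buffers, and the forward callback are simplified stand-ins, not the lightx2v offload manager.

```python
import torch


def run_blocks(block_weights_cpu, forward):
    """Double-buffered block execution: copy block i+1's weights on a side
    stream while block i computes (stand-in for the real manager)."""
    if not torch.cuda.is_available():  # CPU-only fallback so the sketch still runs
        x = torch.zeros(1)
        for w in block_weights_cpu:
            x = forward(w, x)
        return x

    device = torch.device("cuda")
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    # Two persistent GPU buffers, reused for every block (no D2H on the way back).
    buffers = [torch.empty_like(block_weights_cpu[0], device=device) for _ in range(2)]
    ready = [torch.cuda.Event(), torch.cuda.Event()]
    x = torch.zeros(1, device=device)

    def prefetch(i):
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_stream(compute_stream)  # buffer no longer being read
            buffers[i % 2].copy_(block_weights_cpu[i], non_blocking=True)
            ready[i % 2].record(copy_stream)

    prefetch(0)
    for i in range(len(block_weights_cpu)):
        if i + 1 < len(block_weights_cpu):
            prefetch(i + 1)                       # overlaps with the compute below
        compute_stream.wait_event(ready[i % 2])   # wait only for *this* buffer
        with torch.cuda.stream(compute_stream):
            x = forward(buffers[i % 2], x)
    torch.cuda.current_stream().wait_stream(compute_stream)
    return x


# Toy usage: each "block" is just a weight tensor, forward adds its sum.
out = run_blocks([torch.full((2, 2), float(i)) for i in range(6)], lambda w, x: x + w.sum())
```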
@@ -42,13 +42,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         if offload_granularity != "model":
             if not self.config.get("lazy_load", False):
-                self.weights_stream_mgr = WeightAsyncStreamManager(
-                    blocks_num=self.blocks_num,
-                    offload_ratio=self.offload_ratio,
-                    phases_num=self.phases_num,
-                )
+                self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
             else:
-                self.weights_stream_mgr = LazyWeightAsyncStreamManager(
+                self.offload_manager = LazyWeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
                     offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
@@ -61,40 +57,57 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = blocks[0]
-                self.weights_stream_mgr.active_weights[0].to_cuda()
+                self.offload_manager.init_first_buffer(blocks)
             if block_idx < len(blocks) - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.weights_stream_mgr.swap_weights()
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
+            self.offload_manager.swap_blocks()
         return x

+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
+            self.block_idx = block_idx
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
+
+            if self.clean_cuda_cache:
+                del (
+                    self.phase_params["attn_out"],
+                    self.phase_params["y_out"],
+                    self.phase_params["y"],
+                )
+                torch.cuda.empty_cache()
+
+        if self.clean_cuda_cache:
+            self.clear_offload_params(pre_infer_out)
+        return x
+
     def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        self.offload_manager.prefetch_weights_from_disk(blocks)
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
+                block = self.offload_manager.pin_memory_buffer.get(block_idx)
                 block.to_cuda()
-                self.weights_stream_mgr.active_weights[0] = (block_idx, block)
+                self.offload_manager.cuda_buffers[0] = (block_idx, block)
             if block_idx < len(blocks) - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                self.offload_manager.prefetch_weights(block_idx + 1, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
                 x = self.infer_block(blocks[block_idx], x, pre_infer_out)
-            self.weights_stream_mgr.swap_weights()
+            self.offload_manager.swap_blocks()
             if block_idx == len(blocks) - 1:
-                self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(blocks)
+                self.offload_manager.pin_memory_buffer.pop_front()
+            self.offload_manager._async_prefetch_block(blocks)

             if self.clean_cuda_cache:
                 del (
@@ -106,31 +119,14 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         return x

-    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
-        for block_idx in range(len(blocks)):
-            self.block_idx = block_idx
-            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
-
-            if self.clean_cuda_cache:
-                del (
-                    self.phase_params["attn_out"],
-                    self.phase_params["y_out"],
-                    self.phase_params["y"],
-                )
-                torch.cuda.empty_cache()
-
-        if self.clean_cuda_cache:
-            self.clear_offload_params(pre_infer_out)
-        return x
-
     def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        self.offload_manager.prefetch_weights_from_disk(blocks)
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
-            self.weights_stream_mgr._async_prefetch_block(blocks)
+            self.offload_manager._async_prefetch_block(blocks)

             if self.clean_cuda_cache:
                 del (
@@ -148,35 +144,27 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             if block_idx == 0 and phase_idx == 0:
                 if lazy:
                     obj_key = (block_idx, phase_idx)
-                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
+                    phase = self.offload_manager.pin_memory_buffer.get(obj_key)
                     phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
+                    self.offload_manager.cuda_buffers[0] = (obj_key, phase)
                 else:
-                    phase = blocks[block_idx].compute_phases[phase_idx]
-                    phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
-
-            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_phase(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
+                    self.offload_manager.init_first_buffer(blocks)

             is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
             if not is_last_phase:
                 next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
                 next_phase_idx = (phase_idx + 1) % self.phases_num
-                self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, blocks)
-            self.weights_stream_mgr.swap_phases()
+                self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            with torch.cuda.stream(self.offload_manager.compute_stream):
+                x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
+            self.offload_manager.swap_phases()

         return x

-    def infer_phase(self, active_weight, x, pre_infer_out):
-        if not self.config.get("lazy_load"):
-            cur_phase_idx, cur_phase = active_weight
-        else:
-            (_, cur_phase_idx), cur_phase = active_weight
-
+    def infer_phase(self, cur_phase_idx, cur_phase, x, pre_infer_out):
         if cur_phase_idx == 0:
-            if hasattr(cur_phase, "before_proj"):
+            if hasattr(cur_phase, "before_proj") and cur_phase.before_proj.weight is not None:
                 x = cur_phase.before_proj.apply(x) + pre_infer_out.x
             (
                 self.phase_params["shift_msa"],
@@ -211,11 +199,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                 self.phase_params["c_shift_msa"],
                 self.phase_params["c_scale_msa"],
             )
-            x = self.post_process(
-                x,
-                self.phase_params["y"],
-                self.phase_params["c_gate_msa"],
-            )
+            x = self.post_process(x, self.phase_params["y"], self.phase_params["c_gate_msa"], pre_infer_out)
             if hasattr(cur_phase, "after_proj"):
                 pre_infer_out.adapter_args["hints"].append(cur_phase.after_proj.apply(x))
         elif cur_phase_idx == 3:
...
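The hunks above rename weights_stream_mgr to offload_manager and replace the active_weights/to_cuda/swap_weights protocol with buffer-oriented calls. The interface as it is used by these call sites can be summarized as below; this is a reconstruction from the callers only, with illustrative bodies, and the real WeightAsyncStreamManager in lightx2v certainly differs internally.

```python
import torch


class OffloadManagerSketch:
    """Interface reconstructed from the call sites in this diff
    (init_cuda_buffer, init_first_buffer, prefetch_weights, swap_blocks,
    swap_phases, cuda_buffers, compute_stream). Bodies are placeholders."""

    def __init__(self, offload_granularity="block"):
        self.offload_granularity = offload_granularity
        self.compute_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.load_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.cuda_buffers = []  # persistent GPU-resident buffer modules

    def init_cuda_buffer(self, block_buffers, phase_buffers):
        # Adopt the pre-allocated buffers created by the weights class.
        self.cuda_buffers = list(block_buffers or phase_buffers or [])

    def init_first_buffer(self, blocks, adapter_idx=None):
        # Synchronously load block 0 into buffer 0 before the loop starts.
        self._copy_into(self.cuda_buffers[0], blocks[0])

    def prefetch_weights(self, next_idx, blocks, adapter_idx=None):
        # Asynchronously load block `next_idx` into the spare buffer.
        self._copy_into(self.cuda_buffers[-1], blocks[next_idx])

    def swap_blocks(self):
        # Make the freshly loaded buffer the active one.
        self.cuda_buffers.reverse()

    def _copy_into(self, buffer, block):
        pass  # host-to-device copies into the buffer's tensors would go here
```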
@@ -88,7 +88,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         return x

     def infer_block(self, block, x, pre_infer_out):
-        if hasattr(block.compute_phases[0], "before_proj"):
+        if hasattr(block.compute_phases[0], "before_proj") and block.compute_phases[0].before_proj.weight is not None:
             x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
...
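The extra `weight is not None` guard exists because the shared offload buffer block is allocated uniformly: it always exposes a before_proj slot, but the slot is only populated for blocks that actually carry that projection. A tiny sketch with made-up names shows why hasattr alone no longer suffices.

```python
# Why `hasattr(...)` alone is not enough: the shared offload buffer always has
# a `before_proj` attribute, but its weight is only filled for blocks that
# really carry one. (BufferLinear/BufferBlock are stand-in names.)
class BufferLinear:
    def __init__(self):
        self.weight = None  # empty slot until a real block's weight is copied in

    def apply(self, x):
        return x * self.weight


class BufferBlock:
    def __init__(self):
        self.before_proj = BufferLinear()


buf = BufferBlock()
x = 3.0
if hasattr(buf, "before_proj") and buf.before_proj.weight is not None:
    x = buf.before_proj.apply(x)  # skipped while the slot is empty
print(x)  # -> 3.0
```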
@@ -22,12 +22,12 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def infer_vace_blocks(self, vace_blocks, pre_infer_out):
         pre_infer_out.adapter_args["hints"] = []
         self.infer_state = "vace"
-        if hasattr(self, "weights_stream_mgr"):
-            self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
         self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
-        if hasattr(self, "weights_stream_mgr"):
-            self.weights_stream_mgr.init(self.blocks_num, self.phases_num, self.offload_ratio)
+        if hasattr(self, "offload_manager"):
+            self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
...
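infer_vace_blocks now points the manager at a separate, pre-built set of VACE buffers before running the VACE blocks and restores the base buffers afterwards, instead of re-initializing the manager with block counts. A minimal sketch of that switch, with a trivial stand-in manager, is below.

```python
# Switching the active buffer set around the VACE pass (illustrative only).
class TinyManager:
    def __init__(self):
        self.cuda_buffers = []

    def init_cuda_buffer(self, block_buffers, phase_buffers):
        self.cuda_buffers = block_buffers if block_buffers is not None else phase_buffers


mgr = TinyManager()
base_buffers, vace_buffers = ["base0", "base1"], ["vace0", "vace1"]

mgr.init_cuda_buffer(vace_buffers, None)  # before running vace_blocks
# ... run the VACE blocks against mgr.cuda_buffers ...
mgr.init_cuda_buffer(base_buffers, None)  # restore for the base blocks
```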
@@ -365,6 +365,8 @@ class WanModel(CompiledMethodsMixin):
         self.pre_infer = self.pre_infer_class(self.config)
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
+        if hasattr(self.transformer_infer, "offload_manager"):
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
...
@@ -19,6 +19,14 @@ class WanVaceModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)

+    def _init_infer(self):
+        super()._init_infer()
+        if hasattr(self.transformer_infer, "offload_manager"):
+            self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
+            self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
+            self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
+            self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
+
     def _init_infer_class(self):
         self.pre_infer_class = WanPreInfer
         self.post_infer_class = WanPostInfer
...
@@ -6,7 +6,12 @@ from lightx2v.common.modules.weight_module import WeightModule
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
-from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import (
+    ATTN_WEIGHT_REGISTER,
+    LN_WEIGHT_REGISTER,
+    MM_WEIGHT_REGISTER,
+    RMS_WEIGHT_REGISTER,
+)


 class WanAnimateTransformerWeights(WanTransformerWeights):
@@ -18,39 +23,74 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
                 self.blocks[i].compute_phases.append(WanAnimateFuserBlock(self.config, i // 5, "face_adapter.fuser_blocks", self.mm_type))
             else:
                 self.blocks[i].compute_phases.append(WeightModule())
+        self._add_animate_fuserblock_to_offload_buffers()
+
+    def _add_animate_fuserblock_to_offload_buffers(self):
+        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+            for i in range(self.offload_blocks_num):
+                self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
+        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+            self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))


 class WanAnimateFuserBlock(WeightModule):
-    def __init__(self, config, block_index, block_prefix, mm_type):
+    def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
         super().__init__()
         self.config = config
+        self.is_post_adapter = True
         lazy_load = config.get("lazy_load", False)
         if lazy_load:
-            lazy_load_path = os.path.join(config.dit_quantized_ckpt, f"{block_prefix[:-1]}_{block_index}.safetensors")
+            lazy_load_path = os.path.join(
+                config.dit_quantized_ckpt,
+                f"{block_prefix[:-1]}_{block_index}.safetensors",
+            )
             lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         else:
             lazy_load_file = None
         self.add_module(
             "linear1_kv",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_kv.weight", f"{block_prefix}.{block_index}.linear1_kv.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear1_kv.weight",
+                f"{block_prefix}.{block_index}.linear1_kv.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "linear1_q",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear1_q.weight", f"{block_prefix}.{block_index}.linear1_q.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear1_q.weight",
+                f"{block_prefix}.{block_index}.linear1_q.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "linear2",
-            MM_WEIGHT_REGISTER[mm_type](f"{block_prefix}.{block_index}.linear2.weight", f"{block_prefix}.{block_index}.linear2.bias", lazy_load, lazy_load_file),
+            MM_WEIGHT_REGISTER[mm_type](
+                f"{block_prefix}.{block_index}.linear2.weight",
+                f"{block_prefix}.{block_index}.linear2.bias",
+                is_offload_buffer,
+                lazy_load,
+                lazy_load_file,
+                self.is_post_adapter,
+            ),
         )
         self.add_module(
             "q_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.q_norm.weight",
+                is_offload_buffer,
                 lazy_load,
                 lazy_load_file,
+                self.is_post_adapter,
             ),
         )
@@ -58,8 +98,10 @@ class WanAnimateFuserBlock(WeightModule):
             "k_norm",
             RMS_WEIGHT_REGISTER["sgl-kernel"](
                 f"{block_prefix}.{block_index}.k_norm.weight",
+                is_offload_buffer,
                 lazy_load,
                 lazy_load_file,
+                self.is_post_adapter,
             ),
         )
@@ -67,6 +109,7 @@ class WanAnimateFuserBlock(WeightModule):
             "pre_norm_feat",
             LN_WEIGHT_REGISTER["Default"](),
         )
+
         self.add_module(
             "pre_norm_motion",
             LN_WEIGHT_REGISTER["Default"](),
...
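Two details in this file are easy to miss: base block i maps onto face-adapter fuser block i // 5 (one fuser per group of five base blocks, matching the block_idx // 5 indexing in WanAnimateTransformerInfer), and every shared offload buffer gets a fuser phase appended so it can host any block, including those that carry the fourth compute phase. A tiny sketch with plain lists standing in for WeightModuleList/compute_phases:

```python
# Offload buffers must be able to host *any* base block, and animate blocks
# carry a 4th compute phase (the face-adapter fuser). Plain lists stand in
# for WeightModuleList / compute_phases; all values are illustrative.
phases_num = 4
offload_block_buffers = [["mod", "attn", "ffn"] for _ in range(2)]

for buf in offload_block_buffers:
    buf.append("fuser buffer slot")  # phase index 3, refilled per block at run time

assert all(len(buf) == phases_num for buf in offload_block_buffers)
assert [i // 5 for i in (0, 4, 5, 19)] == [0, 0, 1, 3]  # block -> fuser-block mapping
```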
@@ -18,14 +18,46 @@ class WanAudioTransformerWeights(WanTransformerWeights):
                     self.task,
                     self.mm_type,
                     self.config,
+                    False,
                     self.blocks[i].lazy_load,
                     self.blocks[i].lazy_load_file,
                 )
             )
+        self._add_audio_adapter_ca_to_offload_buffers()
+
+    def _add_audio_adapter_ca_to_offload_buffers(self):
+        if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
+            for i in range(self.offload_blocks_num):
+                offload_buffer = self.offload_block_buffers[i]
+                adapter_ca = WanAudioAdapterCA(
+                    block_index=i,
+                    block_prefix=f"ca",
+                    task=self.task,
+                    mm_type=self.mm_type,
+                    config=self.config,
+                    is_offload_buffer=True,
+                    lazy_load=offload_buffer.lazy_load,
+                    lazy_load_file=offload_buffer.lazy_load_file,
+                )
+                offload_buffer.compute_phases.append(adapter_ca)
+        elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
+            adapter_ca = WanAudioAdapterCA(
+                block_index=0,
+                block_prefix=f"ca",
+                task=self.task,
+                mm_type=self.mm_type,
+                config=self.config,
+                is_offload_buffer=True,
+                lazy_load=self.blocks[0].lazy_load,
+                lazy_load_file=self.blocks[0].lazy_load_file,
+            )
+            self.offload_phase_buffers.append(adapter_ca)


 class WanAudioAdapterCA(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -39,6 +71,7 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_q.weight",
                 f"{block_prefix}.{block_index}.to_q.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -49,6 +82,7 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_kv.weight",
                 f"{block_prefix}.{block_index}.to_kv.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -59,6 +93,7 @@ class WanAudioAdapterCA(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{block_index}.to_out.weight",
                 f"{block_prefix}.{block_index}.to_out.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -69,6 +104,7 @@ class WanAudioAdapterCA(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.norm_kv.weight",
                 f"{block_prefix}.{block_index}.norm_kv.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -83,6 +119,7 @@ class WanAudioAdapterCA(WeightModule):
             "shift_scale_gate",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{block_index}.shift_scale_gate",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...
@@ -24,6 +24,7 @@ class WanTransformerWeights(WeightModule):
         if config.get("do_mm_calib", False):
             self.mm_type = "Calib"
         self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
+        self.register_offload_buffers(config)
         self.add_module("blocks", self.blocks)
         # non blocks weights
@@ -31,6 +32,35 @@ class WanTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

+    def register_offload_buffers(self, config):
+        if config["cpu_offload"]:
+            if config["offload_granularity"] == "block":
+                self.offload_blocks_num = 2
+                self.offload_block_buffers = WeightModuleList(
+                    [
+                        WanTransformerAttentionBlock(
+                            i,
+                            self.task,
+                            self.mm_type,
+                            self.config,
+                            True,
+                        )
+                        for i in range(self.offload_blocks_num)
+                    ]
+                )
+                self.add_module("offload_block_buffers", self.offload_block_buffers)
+                self.offload_phase_buffers = None
+            elif config["offload_granularity"] == "phase":
+                self.offload_phase_buffers = WanTransformerAttentionBlock(
+                    0,
+                    self.task,
+                    self.mm_type,
+                    self.config,
+                    True,
+                ).compute_phases
+                self.add_module("offload_phase_buffers", self.offload_phase_buffers)
+                self.offload_block_buffers = None
+
     def clear(self):
         for block in self.blocks:
             for phase in block.compute_phases:
@@ -48,12 +78,21 @@ class WanTransformerWeights(WeightModule):

 class WanTransformerAttentionBlock(WeightModule):
-    def __init__(self, block_index, task, mm_type, config, block_prefix="blocks"):
+    def __init__(
+        self,
+        block_index,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer=False,
+        block_prefix="blocks",
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.is_offload_buffer = is_offload_buffer
         self.quant_method = config.get("quant_method", None)
         self.lazy_load = self.config.get("lazy_load", False)
@@ -71,6 +110,7 @@ class WanTransformerAttentionBlock(WeightModule):
                 task,
                 mm_type,
                 config,
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -80,6 +120,7 @@ class WanTransformerAttentionBlock(WeightModule):
                 task,
                 mm_type,
                 config,
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -89,6 +130,7 @@ class WanTransformerAttentionBlock(WeightModule):
                 task,
                 mm_type,
                 config,
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -99,7 +141,17 @@ class WanTransformerAttentionBlock(WeightModule):

 class WanSelfAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(
+        self,
+        block_index,
+        block_prefix,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer,
+        lazy_load,
+        lazy_load_file,
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -119,6 +171,7 @@ class WanSelfAttention(WeightModule):
             "modulation",
             TENSOR_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.modulation",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -134,6 +187,7 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.q.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -144,6 +198,7 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.k.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -153,6 +208,7 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.v.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -162,6 +218,7 @@ class WanSelfAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.self_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.self_attn.o.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -170,6 +227,7 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -178,6 +236,7 @@ class WanSelfAttention(WeightModule):
             "self_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -192,7 +251,12 @@ class WanSelfAttention(WeightModule):
                 context_length=self.config.get("svg_context_length", 0),
                 sparsity=self.config.get("svg_sparsity", 0.25),
             )
-        if self.config["self_attn_1_type"] in ["svg_attn", "radial_attn", "nbhd_attn", "nbhd_attn_flashinfer"]:
+        if self.config["self_attn_1_type"] in [
+            "svg_attn",
+            "radial_attn",
+            "nbhd_attn",
+            "nbhd_attn_flashinfer",
+        ]:
             attention_weights_cls.attnmap_frame_num = self.config["attnmap_frame_num"]
         # nbhd_attn setting
         if self.config["self_attn_1_type"] in ["nbhd_attn", "nbhd_attn_flashinfer"]:
@@ -204,13 +268,17 @@ class WanSelfAttention(WeightModule):
         self.add_module("self_attn_1", attention_weights_cls())
         if self.config["seq_parallel"]:
-            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
+            self.add_module(
+                "self_attn_1_parallel",
+                ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")](),
+            )
         if self.quant_method in ["advanced_ptq"]:
             self.add_module(
                 "smooth_norm1_weight",
                 TENSOR_REGISTER["Default"](
                     f"{block_prefix}.{self.block_index}.affine_norm1.weight",
+                    is_offload_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
@@ -219,6 +287,7 @@ class WanSelfAttention(WeightModule):
                 "smooth_norm1_bias",
                 TENSOR_REGISTER["Default"](
                     f"{block_prefix}.{self.block_index}.affine_norm1.bias",
+                    is_offload_buffer,
                     self.lazy_load,
                     self.lazy_load_file,
                 ),
@@ -226,7 +295,17 @@ class WanSelfAttention(WeightModule):

 class WanCrossAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
+    def __init__(
+        self,
+        block_index,
+        block_prefix,
+        task,
+        mm_type,
+        config,
+        is_offload_buffer,
+        lazy_load,
+        lazy_load_file,
+    ):
         super().__init__()
         self.block_index = block_index
         self.mm_type = mm_type
@@ -245,6 +324,7 @@ class WanCrossAttention(WeightModule):
             LN_WEIGHT_REGISTER["Default"](
                 f"{block_prefix}.{self.block_index}.norm3.weight",
                 f"{block_prefix}.{self.block_index}.norm3.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -254,6 +334,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -263,6 +344,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -272,6 +354,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -281,6 +364,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -289,6 +373,7 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_q",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -297,6 +382,7 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_k",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -309,6 +395,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -318,6 +405,7 @@ class WanCrossAttention(WeightModule):
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
                 f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
@@ -326,6 +414,7 @@ class WanCrossAttention(WeightModule):
             "cross_attn_norm_k_img",
             RMS_WEIGHT_REGISTER[self.attn_rms_type](
                 f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
+                is_offload_buffer,
                 self.lazy_load,
                 self.lazy_load_file,
             ),
...@@ -334,7 +423,17 @@ class WanCrossAttention(WeightModule): ...@@ -334,7 +423,17 @@ class WanCrossAttention(WeightModule):
class WanFFN(WeightModule): class WanFFN(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file): def __init__(
self,
block_index,
block_prefix,
task,
mm_type,
config,
is_offload_buffer,
lazy_load,
lazy_load_file,
):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
self.mm_type = mm_type self.mm_type = mm_type
...@@ -354,6 +453,7 @@ class WanFFN(WeightModule): ...@@ -354,6 +453,7 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.0.weight", f"{block_prefix}.{self.block_index}.ffn.0.weight",
f"{block_prefix}.{self.block_index}.ffn.0.bias", f"{block_prefix}.{self.block_index}.ffn.0.bias",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -363,6 +463,7 @@ class WanFFN(WeightModule): ...@@ -363,6 +463,7 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.2.weight", f"{block_prefix}.{self.block_index}.ffn.2.weight",
f"{block_prefix}.{self.block_index}.ffn.2.bias", f"{block_prefix}.{self.block_index}.ffn.2.bias",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -373,6 +474,7 @@ class WanFFN(WeightModule): ...@@ -373,6 +474,7 @@ class WanFFN(WeightModule):
"smooth_norm2_weight", "smooth_norm2_weight",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.weight", f"{block_prefix}.{self.block_index}.affine_norm3.weight",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -381,6 +483,7 @@ class WanFFN(WeightModule): ...@@ -381,6 +483,7 @@ class WanFFN(WeightModule):
"smooth_norm2_bias", "smooth_norm2_bias",
TENSOR_REGISTER["Default"]( TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.bias", f"{block_prefix}.{self.block_index}.affine_norm3.bias",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
......
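Across these weight wrappers the diff threads a new is_offload_buffer positional argument in front of lazy_load. A minimal sketch of how such a flag could be consumed inside a weight class follows; the class name, load() signature, and buffer-allocation details are assumptions for illustration, not this project's actual MM_WEIGHT_REGISTER API.
import torch

# Hypothetical wrapper, only to illustrate the intent of is_offload_buffer:
# buffers are allocated once on the GPU and refilled host-to-device per block,
# so inference never needs a device-to-host copy of the weights.
class BufferedWeight:
    def __init__(self, weight_name, bias_name, is_offload_buffer, lazy_load, lazy_load_file):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.is_offload_buffer = is_offload_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.weight = None
        self.bias = None

    def load(self, state_dict):
        if self.is_offload_buffer:
            # Reusable device-side staging buffer: allocate empty GPU tensors once.
            self.weight = torch.empty_like(state_dict[self.weight_name], device="cuda")
            self.bias = torch.empty_like(state_dict[self.bias_name], device="cuda")
        else:
            # Regular weight: keep the CPU tensor pinned so later H2D copies can be async.
            self.weight = state_dict[self.weight_name].pin_memory()
            self.bias = state_dict[self.bias_name].pin_memory()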
...@@ -13,16 +13,31 @@ class WanVaceTransformerWeights(WanTransformerWeights): ...@@ -13,16 +13,31 @@ class WanVaceTransformerWeights(WanTransformerWeights):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.patch_size = (1, 2, 2) self.patch_size = (1, 2, 2)
self.register_offload_buffers(config)
self.vace_blocks = WeightModuleList( self.vace_blocks = WeightModuleList(
[WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config["vace_layers"]))] [WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
) )
self.add_module("vace_blocks", self.vace_blocks) self.add_module("vace_blocks", self.vace_blocks)
self.add_module( self.add_module(
"vace_patch_embedding", "vace_patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("vace_patch_embedding.weight", "vace_patch_embedding.bias", stride=self.patch_size), CONV3D_WEIGHT_REGISTER["Default"]("vace_patch_embedding.weight", "vace_patch_embedding.bias", stride=self.patch_size),
) )
def register_offload_buffers(self, config):
super().register_offload_buffers(config)
if config["cpu_offload"]:
if config["offload_granularity"] == "block":
self.vace_offload_block_buffers = WeightModuleList(
[
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
]
)
self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers)
self.vace_offload_phase_buffers = None
elif config["offload_granularity"] == "phase":
raise NotImplementedError
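The branches above only build offload buffers when cpu_offload is enabled with block granularity; phase granularity is left unimplemented for the VACE blocks. An illustrative config fragment that would exercise this path (only the two keys checked above are meaningful; the surrounding dict shape is an assumption):
config = {
    "cpu_offload": True,
    "offload_granularity": "block",  # "phase" raises NotImplementedError here
    # ... other runner/model keys ...
}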
def clear(self): def clear(self):
super().clear() super().clear()
for vace_block in self.vace_blocks: for vace_block in self.vace_blocks:
...@@ -39,14 +54,15 @@ class WanVaceTransformerWeights(WanTransformerWeights): ...@@ -39,14 +54,15 @@ class WanVaceTransformerWeights(WanTransformerWeights):
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock): class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix): def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix):
super().__init__(block_index, task, mm_type, config, block_prefix) super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix)
if base_block_idx == 0: if base_block_idx == 0:
self.compute_phases[0].add_module( self.compute_phases[0].add_module(
"before_proj", "before_proj",
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.before_proj.weight", f"{block_prefix}.{self.block_index}.before_proj.weight",
f"{block_prefix}.{self.block_index}.before_proj.bias", f"{block_prefix}.{self.block_index}.before_proj.bias",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
...@@ -57,6 +73,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock): ...@@ -57,6 +73,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
MM_WEIGHT_REGISTER[self.mm_type]( MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.after_proj.weight", f"{block_prefix}.{self.block_index}.after_proj.weight",
f"{block_prefix}.{self.block_index}.after_proj.bias", f"{block_prefix}.{self.block_index}.after_proj.bias",
is_offload_buffer,
self.lazy_load, self.lazy_load,
self.lazy_load_file, self.lazy_load_file,
), ),
......
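Both the base transformer weights and the VACE variant now keep exactly two offload block buffers. A plausible way such a pair is used at inference time is classic double buffering: while one GPU buffer runs block i, the weights of block i + 1 are copied host-to-device into the other buffer on a side stream, so no device-to-host transfer is needed. The sketch below is an assumption about the runner's loop; copy_from and forward are hypothetical helpers, not this repo's API.
import torch

def run_blocks_with_double_buffer(cpu_blocks, offload_block_buffers, x):
    # Assumes offload_block_buffers has length 2 and each buffer can receive
    # any block's weights via copy_from (host -> device only).
    copy_stream = torch.cuda.Stream()
    offload_block_buffers[0].copy_from(cpu_blocks[0])  # prefetch block 0
    for i in range(len(cpu_blocks)):
        cur = offload_block_buffers[i % 2]
        nxt = offload_block_buffers[(i + 1) % 2]
        if i + 1 < len(cpu_blocks):
            # The idle buffer was last read by block i - 1; wait for that compute
            # before overwriting it, then overlap the copy with block i's compute.
            copy_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(copy_stream):
                nxt.copy_from(cpu_blocks[i + 1])
        x = cur.forward(x)
        # Ensure block i + 1's weights have landed before the next iteration uses them.
        torch.cuda.current_stream().wait_stream(copy_stream)
    return x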
...@@ -78,21 +78,27 @@ class MultiDistillModelStruct(MultiModelStruct): ...@@ -78,21 +78,27 @@ class MultiDistillModelStruct(MultiModelStruct):
class Wan22MoeDistillRunner(WanDistillRunner): class Wan22MoeDistillRunner(WanDistillRunner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
if not os.path.isdir(self.high_noise_model_path):
self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None): if self.config.get("dit_quantized", False) and self.config.get("high_noise_quantized_ckpt", None):
self.high_noise_model_path = self.config["high_noise_quantized_ckpt"] self.high_noise_model_path = self.config["high_noise_quantized_ckpt"]
elif self.config.get("high_noise_original_ckpt", None): elif self.config.get("high_noise_original_ckpt", None):
self.high_noise_model_path = self.config["high_noise_original_ckpt"] self.high_noise_model_path = self.config["high_noise_original_ckpt"]
else:
self.high_noise_model_path = os.path.join(self.config["model_path"], "high_noise_model")
if not os.path.isdir(self.high_noise_model_path):
self.high_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "high_noise_model")
if not os.path.isdir(self.high_noise_model_path):
raise FileNotFoundError(f"High Noise Model does not find")
self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
if not os.path.isdir(self.low_noise_model_path):
self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None): if self.config.get("dit_quantized", False) and self.config.get("low_noise_quantized_ckpt", None):
self.low_noise_model_path = self.config["low_noise_quantized_ckpt"] self.low_noise_model_path = self.config["low_noise_quantized_ckpt"]
elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None): elif not self.config.get("dit_quantized", False) and self.config.get("low_noise_original_ckpt", None):
self.low_noise_model_path = self.config["low_noise_original_ckpt"] self.low_noise_model_path = self.config["low_noise_original_ckpt"]
else:
self.low_noise_model_path = os.path.join(self.config["model_path"], "low_noise_model")
if not os.path.isdir(self.low_noise_model_path):
self.low_noise_model_path = os.path.join(self.config["model_path"], "distill_models", "low_noise_model")
if not os.path.isdir(self.low_noise_model_path):
raise FileNotFoundError(f"Low noise model not found: {self.low_noise_model_path}")
def load_transformer(self): def load_transformer(self):
use_high_lora, use_low_lora = False, False use_high_lora, use_low_lora = False, False
......
decord
peft
onnxruntime
pandas
matplotlib
-e git+https://github.com/facebookresearch/sam2.git@0e78a118995e66bb27d78518c4bd9a3e95b4e266#egg=SAM-2
loguru
sentencepiece