Unverified commit 74eeb429 authored by Gu Shiqiao, committed by GitHub

reconstruct disk offload and fix lightx2v_platform bugs (#558)


Co-authored-by: helloyongyang <yongyang1030@163.com>
parent f7cdbcb5
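
The diff below splits the single is_offload_buffer constructor flag into the pair create_cuda_buffer / create_cpu_buffer across the weight classes, so a weight can own a device-side staging buffer, a pinned host-side staging buffer, or both (the latter is needed for disk offload with lazy_load). A minimal sketch of the convention, using a hypothetical simplified weight class rather than the real registry entries:

class ExampleWeight:  # hypothetical stand-in for an MM_WEIGHT_REGISTER entry
    def __init__(self, weight_name, bias_name=None,
                 create_cuda_buffer=False, create_cpu_buffer=False,
                 lazy_load=False, lazy_load_file=None):
        self.weight_name = weight_name
        self.bias_name = bias_name
        # device-side staging buffer used for async prefetch during offloaded inference
        self.create_cuda_buffer = create_cuda_buffer
        # pinned host-side staging buffer, only needed when weights are lazily read from disk
        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file  # an open safetensors handle when lazy_load is True
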
......@@ -37,3 +37,13 @@ class QwenImagePostWeights(WeightModule):
self.lazy_load_file,
),
)
def to_cpu(self, non_blocking=True):
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=non_blocking)
def to_cuda(self, non_blocking=True):
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=non_blocking)
......@@ -28,3 +28,13 @@ class QwenImagePreWeights(WeightModule):
self.add_module(
"time_text_embed_timestep_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_text_embed.timestep_embedder.linear_2.weight", "time_text_embed.timestep_embedder.linear_2.bias")
)
def to_cpu(self, non_blocking=True):
for module in self._modules.values():
if module is not None and hasattr(module, "to_cpu"):
module.to_cpu(non_blocking=non_blocking)
def to_cuda(self, non_blocking=True):
for module in self._modules.values():
if module is not None and hasattr(module, "to_cuda"):
module.to_cuda(non_blocking=non_blocking)
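
The to_cpu / to_cuda helpers added above walk the registered sub-modules and forward the non_blocking flag. A minimal sketch of a leaf weight they would eventually reach (assumed simplified; not the real registry class):

import torch

class LeafWeightSketch:  # hypothetical leaf weight
    def __init__(self, tensor):
        # pin the host copy so non_blocking host-to-device transfers can actually overlap with compute
        self.weight = tensor.pin_memory() if tensor.device.type == "cpu" else tensor

    def to_cuda(self, non_blocking=True):
        self.weight = self.weight.to("cuda", non_blocking=non_blocking)

    def to_cpu(self, non_blocking=True):
        # device-to-host copies into pageable memory synchronize regardless of non_blocking
        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
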
......@@ -15,7 +15,7 @@ class QwenImageTransformerWeights(WeightModule):
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "transformer_blocks") for i in range(self.blocks_num))
blocks = WeightModuleList(QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, False, "transformer_blocks") for i in range(self.blocks_num))
self.register_offload_buffers(config)
self.add_module("blocks", blocks)
......@@ -23,17 +23,17 @@ class QwenImageTransformerWeights(WeightModule):
if config["cpu_offload"]:
if config["offload_granularity"] == "block":
self.offload_blocks_num = 2
self.offload_block_buffers = WeightModuleList(
[QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, "transformer_blocks") for i in range(self.offload_blocks_num)]
self.offload_block_cuda_buffers = WeightModuleList(
[QwenImageTransformerAttentionBlock(i, self.task, self.mm_type, self.config, True, False, "transformer_blocks") for i in range(self.offload_blocks_num)]
)
self.add_module("offload_block_buffers", self.offload_block_buffers)
self.offload_phase_buffers = None
self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
self.offload_phase_cuda_buffers = None
else:
raise NotImplementedError
class QwenImageTransformerAttentionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config, is_offload_buffer=False, block_prefix="transformer_blocks"):
def __init__(self, block_index, task, mm_type, config, create_cuda_buffer=False, create_cpu_buffer=False, block_prefix="transformer_blocks"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -55,14 +55,15 @@ class QwenImageTransformerAttentionBlock(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.img_mod.1.weight",
f"{block_prefix}.{self.block_index}.img_mod.1.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"img_norm1",
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
)
self.attn = QwenImageCrossAttention(
block_index=block_index,
......@@ -70,7 +71,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
task=config["task"],
mm_type=mm_type,
config=config,
is_offload_buffer=is_offload_buffer,
create_cuda_buffer=create_cuda_buffer,
create_cpu_buffer=create_cpu_buffer,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
......@@ -78,7 +80,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
self.add_module(
"img_norm2",
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
)
img_mlp = QwenImageFFN(
block_index=block_index,
......@@ -87,7 +89,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
task=config["task"],
mm_type=mm_type,
config=config,
is_offload_buffer=is_offload_buffer,
create_cuda_buffer=create_cuda_buffer,
create_cpu_buffer=create_cpu_buffer,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
......@@ -99,20 +102,21 @@ class QwenImageTransformerAttentionBlock(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.txt_mod.1.weight",
f"{block_prefix}.{self.block_index}.txt_mod.1.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"txt_norm1",
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
)
# Text doesn't need separate attention - it's handled by img_attn joint computation
self.add_module(
"txt_norm2",
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=is_offload_buffer, eps=1e-6),
LN_WEIGHT_REGISTER["Default"](create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer, eps=1e-6),
)
txt_mlp = QwenImageFFN(
block_index=block_index,
......@@ -121,7 +125,8 @@ class QwenImageTransformerAttentionBlock(WeightModule):
task=config["task"],
mm_type=mm_type,
config=config,
is_offload_buffer=is_offload_buffer,
create_cuda_buffer=create_cuda_buffer,
create_cpu_buffer=create_cpu_buffer,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
......@@ -129,7 +134,7 @@ class QwenImageTransformerAttentionBlock(WeightModule):
class QwenImageCrossAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -146,12 +151,12 @@ class QwenImageCrossAttention(WeightModule):
# norm_q
self.add_module(
"norm_q",
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=is_offload_buffer),
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
)
# norm_k
self.add_module(
"norm_k",
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=is_offload_buffer),
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
)
# to_q
self.add_module(
......@@ -159,7 +164,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.to_q.weight",
f"{block_prefix}.{self.block_index}.attn.to_q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -170,7 +176,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.to_k.weight",
f"{block_prefix}.{self.block_index}.attn.to_k.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -181,7 +188,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.to_v.weight",
f"{block_prefix}.{self.block_index}.attn.to_v.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -192,7 +200,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.add_q_proj.weight",
f"{block_prefix}.{self.block_index}.attn.add_q_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -203,7 +212,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.add_k_proj.weight",
f"{block_prefix}.{self.block_index}.attn.add_k_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -214,7 +224,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.add_v_proj.weight",
f"{block_prefix}.{self.block_index}.attn.add_v_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -225,7 +236,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.to_out.0.weight",
f"{block_prefix}.{self.block_index}.attn.to_out.0.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -236,7 +248,8 @@ class QwenImageCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.attn.to_add_out.weight",
f"{block_prefix}.{self.block_index}.attn.to_add_out.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -244,12 +257,12 @@ class QwenImageCrossAttention(WeightModule):
# norm_added_q
self.add_module(
"norm_added_q",
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=is_offload_buffer),
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_q.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
)
# norm_added_k
self.add_module(
"norm_added_k",
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=is_offload_buffer),
RMS_WEIGHT_REGISTER["fp32_variance"](f"{block_prefix}.{block_index}.attn.norm_added_k.weight", create_cuda_buffer=create_cuda_buffer, create_cpu_buffer=create_cpu_buffer),
)
# attn
self.add_module("calculate", ATTN_WEIGHT_REGISTER[self.attn_type]())
......@@ -266,7 +279,7 @@ class QwenImageCrossAttention(WeightModule):
class QwenImageFFN(WeightModule):
def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, ffn_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -281,7 +294,8 @@ class QwenImageFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.weight",
f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.0.proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -291,7 +305,8 @@ class QwenImageFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.weight",
f"{block_prefix}.{self.block_index}.{ffn_prefix}.net.2.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......
......@@ -7,6 +7,7 @@ import torch.nn.functional as F
from lightx2v.common.transformer_infer.transformer_infer import BaseTaylorCachingTransformerInfer
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
class WanTransformerInferCaching(WanOffloadTransformerInfer):
......@@ -56,7 +57,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.accumulated_rel_l1_distance_even = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_even += rescale_func(((modulated_inp - self.previous_e0_even.cuda()).abs().mean() / self.previous_e0_even.cuda().abs().mean()).cpu().item())
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_e0_even.to(AI_DEVICE)).abs().mean() / self.previous_e0_even.to(AI_DEVICE).abs().mean()).cpu().item()
)
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
......@@ -72,7 +75,7 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
self.accumulated_rel_l1_distance_odd = 0
else:
rescale_func = np.poly1d(self.coefficients)
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.cuda()).abs().mean() / self.previous_e0_odd.cuda().abs().mean()).cpu().item())
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_e0_odd.to(AI_DEVICE)).abs().mean() / self.previous_e0_odd.to(AI_DEVICE).abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
......@@ -149,9 +152,9 @@ class WanTransformerInferTeaCaching(WanTransformerInferCaching):
def infer_using_cache(self, x):
if self.scheduler.infer_condition:
x.add_(self.previous_residual_even.cuda())
x.add_(self.previous_residual_even.to(AI_DEVICE))
else:
x.add_(self.previous_residual_odd.cuda())
x.add_(self.previous_residual_odd.to(AI_DEVICE))
return x
def clear(self):
......@@ -1075,7 +1078,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
def infer_using_cache(self, x):
residual_x = self.residual_cache[self.scheduler.infer_condition]
x.add_(residual_x.cuda())
x.add_(residual_x.to(AI_DEVICE))
return x
def clear(self):
......
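
The caching paths above replace hard-coded .cuda() calls with .to(AI_DEVICE), where AI_DEVICE is imported from lightx2v_platform.base.global_var. A minimal standalone illustration, using a local stand-in for that global since its value depends on the platform backend:

import torch

AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # stand-in for the platform global

previous_residual = torch.randn(4, 8)               # cached residual kept on the host
x = torch.randn(4, 8, device=AI_DEVICE)
x.add_(previous_residual.to(AI_DEVICE))             # device-agnostic version of x.add_(residual.cuda())
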
import torch
from lightx2v.common.offload.manager import (
LazyWeightAsyncStreamManager,
WeightAsyncStreamManager,
)
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
class WanOffloadTransformerInfer(WanTransformerInfer):
def __init__(self, config):
super().__init__(config)
if self.config.get("cpu_offload", False):
if "offload_ratio" in self.config:
self.offload_ratio = self.config["offload_ratio"]
else:
self.offload_ratio = 1
offload_granularity = self.config.get("offload_granularity", "block")
if offload_granularity == "block":
if not self.config.get("lazy_load", False):
self.infer_func = self.infer_with_blocks_offload
else:
self.infer_func = self.infer_with_blocks_lazy_offload
self.infer_func = self.infer_with_blocks_offload
elif offload_granularity == "phase":
if not self.config.get("lazy_load", False):
self.infer_func = self.infer_with_phases_offload
else:
self.infer_func = self.infer_with_phases_lazy_offload
self.infer_func = self.infer_with_phases_offload
self.phase_params = {
"shift_msa": None,
"scale_msa": None,
......@@ -41,121 +31,54 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self.infer_func = self.infer_without_offload
if offload_granularity != "model":
if not self.config.get("lazy_load", False):
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
else:
self.offload_manager = LazyWeightAsyncStreamManager(
blocks_num=self.blocks_num,
offload_ratio=self.offload_ratio,
phases_num=self.phases_num,
num_disk_workers=self.config.get("num_disk_workers", 2),
max_memory=self.config.get("max_memory", 2),
offload_gra=offload_granularity,
)
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
if block_idx == 0:
if self.offload_manager.need_init_first_buffer:
self.offload_manager.init_first_buffer(blocks)
if block_idx < len(blocks) - 1:
self.offload_manager.prefetch_weights(block_idx + 1, blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
self.offload_manager.prefetch_weights((block_idx + 1) % len(blocks), blocks)
with torch_device_module.stream(self.offload_manager.compute_stream):
x = self.infer_block(self.offload_manager.cuda_buffers[0], x, pre_infer_out)
self.offload_manager.swap_blocks()
return x
def infer_with_phases_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
if self.clean_cuda_cache:
del (
self.phase_params["attn_out"],
self.phase_params["y_out"],
self.phase_params["y"],
)
torch.cuda.empty_cache()
if self.clean_cuda_cache:
self.clear_offload_params(pre_infer_out)
return x
def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
self.offload_manager.prefetch_weights_from_disk(blocks)
for block_idx in range(len(blocks)):
self.block_idx = block_idx
if block_idx == 0:
block = self.offload_manager.pin_memory_buffer.get(block_idx)
block.to_cuda()
self.offload_manager.cuda_buffers[0] = (block_idx, block)
if block_idx < len(blocks) - 1:
self.offload_manager.prefetch_weights(block_idx + 1, blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
x = self.infer_block(blocks[block_idx], x, pre_infer_out)
self.offload_manager.swap_blocks()
if block_idx == len(blocks) - 1:
self.offload_manager.pin_memory_buffer.pop_front()
self.offload_manager._async_prefetch_block(blocks)
if self.clean_cuda_cache:
del (
pre_infer_out.embed0,
pre_infer_out.freqs,
pre_infer_out.context,
)
torch.cuda.empty_cache()
torch_device_module.empty_cache()
return x
def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
self.offload_manager.prefetch_weights_from_disk(blocks)
def infer_with_phases_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
self.offload_manager._async_prefetch_block(blocks)
x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
if self.clean_cuda_cache:
del (
self.phase_params["attn_out"],
self.phase_params["y_out"],
self.phase_params["y"],
)
torch.cuda.empty_cache()
torch_device_module.empty_cache()
if self.clean_cuda_cache:
self.clear_offload_params(pre_infer_out)
return x
def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
def infer_phases(self, block_idx, blocks, x, pre_infer_out):
for phase_idx in range(self.phases_num):
if block_idx == 0 and phase_idx == 0:
if lazy:
obj_key = (block_idx, phase_idx)
phase = self.offload_manager.pin_memory_buffer.get(obj_key)
phase.to_cuda()
self.offload_manager.cuda_buffers[0] = (obj_key, phase)
else:
self.offload_manager.init_first_buffer(blocks)
is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
if not is_last_phase:
next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
with torch.cuda.stream(self.offload_manager.compute_stream):
if self.offload_manager.need_init_first_buffer:
self.offload_manager.init_first_buffer(blocks)
next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
with torch_device_module.stream(self.offload_manager.compute_stream):
x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
self.offload_manager.swap_phases()
......@@ -176,10 +99,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
self.phase_params["y_out"] = self.infer_self_attn(
cur_phase,
pre_infer_out.grid_sizes.tuple,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
self.phase_params["shift_msa"],
self.phase_params["scale_msa"],
)
......@@ -219,7 +139,6 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
)
del (
pre_infer_out.embed0,
pre_infer_out.freqs,
pre_infer_out.context,
)
torch.cuda.empty_cache()
torch_device_module.empty_cache()
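
The block-offload loop above follows a double-buffer (ping-pong) pattern: while block i runs on the compute stream, block (i + 1) % N is prefetched into the second device buffer, and the buffers are then swapped. A minimal sketch under assumed helper names (load_from / forward are placeholders; the real logic lives in WeightAsyncStreamManager):

import torch

device = "cuda"                                      # stand-in for AI_DEVICE
torch_device_module = getattr(torch, device)         # same trick as above: torch.cuda, torch.npu, ...
compute_stream = torch_device_module.Stream()
load_stream = torch_device_module.Stream()

def run_blocks_offloaded(blocks, x, cuda_buffers):
    # cuda_buffers holds two device-resident block copies
    for i in range(len(blocks)):
        with torch_device_module.stream(load_stream):
            cuda_buffers[1].load_from(blocks[(i + 1) % len(blocks)])  # async H2D prefetch of next block
        with torch_device_module.stream(compute_stream):
            x = cuda_buffers[0].forward(x)                            # compute on the resident block
        compute_stream.wait_stream(load_stream)                       # prefetched weights become visible
        cuda_buffers[0], cuda_buffers[1] = cuda_buffers[1], cuda_buffers[0]  # swap_blocks()
    return x
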
......@@ -6,6 +6,7 @@ import torch
from lightx2v.models.networks.wan.infer.module_io import GridOutput
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
def sinusoidal_embedding_1d(dim, position):
......@@ -50,7 +51,7 @@ class WanSFPreInfer(WanPreInfer):
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).cuda()
).to(AI_DEVICE)
def time_embedding(self, weights, embed):
embed = weights.time_embedding_0.apply(embed)
......
......@@ -9,6 +9,7 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config["vace_layers"])}
def infer(self, weights, pre_infer_out):
self.get_scheduler_values()
pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
x = self.infer_main_blocks(weights.blocks, pre_infer_out)
......@@ -23,11 +24,11 @@ class WanVaceTransformerInfer(WanOffloadTransformerInfer):
pre_infer_out.adapter_args["hints"] = []
self.infer_state = "vace"
if hasattr(self, "offload_manager"):
self.offload_manager.init_cuda_buffer(self.vace_offload_block_buffers, self.vace_offload_phase_buffers)
self.offload_manager.init_cuda_buffer(self.vace_offload_block_cuda_buffers, self.vace_offload_phase_cuda_buffers)
self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
self.infer_state = "base"
if hasattr(self, "offload_manager"):
self.offload_manager.init_cuda_buffer(self.offload_block_buffers, self.offload_phase_buffers)
self.offload_manager.init_cuda_buffer(self.offload_block_cuda_buffers, self.offload_phase_cuda_buffers)
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
......
......@@ -47,7 +47,10 @@ class WanModel(CompiledMethodsMixin):
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.model_type = model_type
self.remove_keys = []
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
self.remove_keys.extend(["blocks."])
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
......@@ -146,7 +149,7 @@ class WanModel(CompiledMethodsMixin):
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config["parallel"]:
if self.device.type != "cpu" and dist.is_initialized():
device = dist.get_rank()
else:
device = str(self.device)
......@@ -169,6 +172,10 @@ class WanModel(CompiledMethodsMixin):
else:
safetensors_files = [safetensors_path]
if self.lazy_load:
assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
self.lazy_load_path = safetensors_files[0]
weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
......@@ -205,6 +212,10 @@ class WanModel(CompiledMethodsMixin):
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)
if self.lazy_load:
assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
self.lazy_load_path = safetensors_files[0]
weight_dict = {}
for safetensor_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
......@@ -237,28 +248,6 @@ class WanModel(CompiledMethodsMixin):
return weight_dict
def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer): # Need rewrite
lazy_load_model_path = self.dit_quantized_ckpt
logger.info(f"Loading splited quant model from {lazy_load_model_path}")
pre_post_weight_dict = {}
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
for k in f.keys():
if f.get_tensor(k).dtype in [
torch.float16,
torch.bfloat16,
torch.float,
]:
if unified_dtype or all(s not in k for s in sensitive_layer):
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_DTYPE()).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).to(GET_SENSITIVE_DTYPE()).to(self.device)
else:
pre_post_weight_dict[k] = f.get_tensor(k).to(self.device)
return pre_post_weight_dict
def _load_gguf_ckpt(self, gguf_path):
state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device)
return state_dict
......@@ -285,10 +274,7 @@ class WanModel(CompiledMethodsMixin):
weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
else:
# Load quantized weights
if not self.config.get("lazy_load", False):
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
else:
weight_dict = self._load_quant_split_ckpt(unified_dtype, sensitive_layer)
weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)
if self.config.get("device_mesh") is not None and self.config.get("load_from_rank0", False):
weight_dict = self._load_weights_from_rank0(weight_dict, is_weight_loader)
......@@ -302,7 +288,10 @@ class WanModel(CompiledMethodsMixin):
# Initialize weight containers
self.pre_weight = self.pre_weight_class(self.config)
self.transformer_weights = self.transformer_weight_class(self.config)
if self.lazy_load:
self.transformer_weights = self.transformer_weight_class(self.config, self.lazy_load_path)
else:
self.transformer_weights = self.transformer_weight_class(self.config)
if not self._should_init_empty_model():
self._apply_weights()
......@@ -383,7 +372,9 @@ class WanModel(CompiledMethodsMixin):
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
if self.lazy_load:
self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......
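
_init_infer now hands the offload manager separate CUDA buffers and, in lazy-load mode, CPU buffers as well. A minimal sketch of the manager side of that handshake (assumed simplified; the real class is WeightAsyncStreamManager):

class OffloadManagerSketch:  # hypothetical; mirrors init_cuda_buffer / init_cpu_buffer above
    def init_cuda_buffer(self, block_cuda_buffers, phase_cuda_buffers):
        # device-resident ping-pong buffers used by prefetch_weights / prefetch_phase
        self.cuda_buffers = block_cuda_buffers if block_cuda_buffers is not None else phase_cuda_buffers

    def init_cpu_buffer(self, block_cpu_buffers, phase_cpu_buffers):
        # pinned host buffers that stage weights read lazily from the safetensors file
        self.cpu_buffers = block_cpu_buffers if block_cpu_buffers is not None else phase_cpu_buffers
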
......@@ -22,10 +22,15 @@ class WanVaceModel(WanModel):
def _init_infer(self):
super()._init_infer()
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_block_buffers = self.transformer_weights.offload_block_buffers
self.transformer_infer.offload_phase_buffers = self.transformer_weights.offload_phase_buffers
self.transformer_infer.vace_offload_block_buffers = self.transformer_weights.vace_offload_block_buffers
self.transformer_infer.vace_offload_phase_buffers = self.transformer_weights.vace_offload_phase_buffers
self.transformer_infer.offload_block_cuda_buffers = self.transformer_weights.offload_block_cuda_buffers
self.transformer_infer.offload_phase_cuda_buffers = self.transformer_weights.offload_phase_cuda_buffers
self.transformer_infer.vace_offload_block_cuda_buffers = self.transformer_weights.vace_offload_block_cuda_buffers
self.transformer_infer.vace_offload_phase_cuda_buffers = self.transformer_weights.vace_offload_phase_cuda_buffers
if self.lazy_load:
self.transformer_infer.offload_block_cpu_buffers = self.transformer_weights.offload_block_cpu_buffers
self.transformer_infer.offload_phase_cpu_buffers = self.transformer_weights.offload_phase_cpu_buffers
self.transformer_infer.vace_offload_block_cpu_buffers = self.transformer_weights.vace_offload_block_cpu_buffers
self.transformer_infer.vace_offload_phase_cpu_buffers = self.transformer_weights.vace_offload_phase_cpu_buffers
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
......
......@@ -26,15 +26,19 @@ class WanAnimateTransformerWeights(WanTransformerWeights):
self._add_animate_fuserblock_to_offload_buffers()
def _add_animate_fuserblock_to_offload_buffers(self):
if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
for i in range(self.offload_blocks_num):
self.offload_block_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
self.offload_phase_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, is_offload_buffer=True))
self.offload_block_cuda_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
if self.lazy_load:
self.offload_block_cpu_buffers[i].compute_phases.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))
elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
self.offload_phase_cuda_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cuda_buffer=True))
if self.lazy_load:
self.offload_phase_cpu_buffers.append(WanAnimateFuserBlock(self.config, 0, "face_adapter.fuser_blocks", self.mm_type, create_cpu_buffer=True))
class WanAnimateFuserBlock(WeightModule):
def __init__(self, config, block_index, block_prefix, mm_type, is_offload_buffer=False):
def __init__(self, config, block_index, block_prefix, mm_type, create_cuda_buffer=False, create_cpu_buffer=False):
super().__init__()
self.config = config
self.is_post_adapter = True
......@@ -53,7 +57,8 @@ class WanAnimateFuserBlock(WeightModule):
MM_WEIGHT_REGISTER[mm_type](
f"{block_prefix}.{block_index}.linear1_kv.weight",
f"{block_prefix}.{block_index}.linear1_kv.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
self.is_post_adapter,
......@@ -65,7 +70,8 @@ class WanAnimateFuserBlock(WeightModule):
MM_WEIGHT_REGISTER[mm_type](
f"{block_prefix}.{block_index}.linear1_q.weight",
f"{block_prefix}.{block_index}.linear1_q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
self.is_post_adapter,
......@@ -76,7 +82,8 @@ class WanAnimateFuserBlock(WeightModule):
MM_WEIGHT_REGISTER[mm_type](
f"{block_prefix}.{block_index}.linear2.weight",
f"{block_prefix}.{block_index}.linear2.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
self.is_post_adapter,
......@@ -87,7 +94,8 @@ class WanAnimateFuserBlock(WeightModule):
"q_norm",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{block_index}.q_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
self.is_post_adapter,
......@@ -98,7 +106,8 @@ class WanAnimateFuserBlock(WeightModule):
"k_norm",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"{block_prefix}.{block_index}.k_norm.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
self.is_post_adapter,
......
......@@ -19,6 +19,7 @@ class WanAudioTransformerWeights(WanTransformerWeights):
self.mm_type,
self.config,
False,
False,
self.blocks[i].lazy_load,
self.blocks[i].lazy_load_file,
)
......@@ -27,37 +28,66 @@ class WanAudioTransformerWeights(WanTransformerWeights):
self._add_audio_adapter_ca_to_offload_buffers()
def _add_audio_adapter_ca_to_offload_buffers(self):
if hasattr(self, "offload_block_buffers") and self.offload_block_buffers is not None:
if hasattr(self, "offload_block_cuda_buffers") and self.offload_block_cuda_buffers is not None:
for i in range(self.offload_blocks_num):
offload_buffer = self.offload_block_buffers[i]
offload_buffer = self.offload_block_cuda_buffers[i]
adapter_ca = WanAudioAdapterCA(
block_index=i,
block_prefix=f"ca",
task=self.task,
mm_type=self.mm_type,
config=self.config,
is_offload_buffer=True,
create_cuda_buffer=True,
create_cpu_buffer=False,
lazy_load=offload_buffer.lazy_load,
lazy_load_file=offload_buffer.lazy_load_file,
)
offload_buffer.compute_phases.append(adapter_ca)
if self.lazy_load:
offload_buffer = self.offload_block_cpu_buffers[i]
adapter_ca = WanAudioAdapterCA(
block_index=i,
block_prefix=f"ca",
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
lazy_load=offload_buffer.lazy_load,
lazy_load_file=offload_buffer.lazy_load_file,
)
offload_buffer.compute_phases.append(adapter_ca)
elif hasattr(self, "offload_phase_buffers") and self.offload_phase_buffers is not None:
elif hasattr(self, "offload_phase_cuda_buffers") and self.offload_phase_cuda_buffers is not None:
adapter_ca = WanAudioAdapterCA(
block_index=0,
block_prefix=f"ca",
task=self.task,
mm_type=self.mm_type,
config=self.config,
is_offload_buffer=True,
create_cuda_buffer=True,
create_cpu_buffer=False,
lazy_load=self.blocks[0].lazy_load,
lazy_load_file=self.blocks[0].lazy_load_file,
)
self.offload_phase_buffers.append(adapter_ca)
self.offload_phase_cuda_buffers.append(adapter_ca)
if self.lazy_load:
adapter_ca = WanAudioAdapterCA(
block_index=0,
block_prefix=f"ca",
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
lazy_load=self.blocks[0].lazy_load,
lazy_load_file=self.blocks[0].lazy_load_file,
)
self.offload_phase_cpu_buffers.append(adapter_ca)
class WanAudioAdapterCA(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, is_offload_buffer, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -71,7 +101,8 @@ class WanAudioAdapterCA(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_q.weight",
f"{block_prefix}.{block_index}.to_q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -82,7 +113,8 @@ class WanAudioAdapterCA(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_kv.weight",
f"{block_prefix}.{block_index}.to_kv.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -93,7 +125,8 @@ class WanAudioAdapterCA(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{block_index}.to_out.weight",
f"{block_prefix}.{block_index}.to_out.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -104,7 +137,8 @@ class WanAudioAdapterCA(WeightModule):
LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{block_index}.norm_kv.weight",
f"{block_prefix}.{block_index}.norm_kv.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -119,7 +153,8 @@ class WanAudioAdapterCA(WeightModule):
"shift_scale_gate",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{block_index}.shift_scale_gate",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......
import os
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanFFN,
......@@ -31,9 +27,9 @@ class WanActionTransformerWeights(WeightModule):
block_list = []
for i in range(self.blocks_num):
if i in action_blocks:
block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config))
else:
block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, False, "blocks"))
block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config))
self.blocks = WeightModuleList(block_list)
self.add_module("blocks", self.blocks)
......@@ -42,11 +38,6 @@ class WanActionTransformerWeights(WeightModule):
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def clear(self):
for block in self.blocks:
for phase in block.compute_phases:
phase.clear()
def non_block_weights_to_cuda(self):
self.norm.to_cuda()
self.head.to_cuda()
......@@ -66,34 +57,16 @@ class WanTransformerActionBlock(WeightModule):
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
assert not self.config.get("lazy_load", False)
self.compute_phases = WeightModuleList(
[
WanSelfAttention(
block_index,
block_prefix,
task,
mm_type,
config,
False,
self.lazy_load,
self.lazy_load_file,
),
WanSelfAttention(block_index, block_prefix, task, mm_type, config),
WanActionCrossAttention(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanActionModule(
block_index,
......@@ -101,8 +74,6 @@ class WanTransformerActionBlock(WeightModule):
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanFFN(
block_index,
......@@ -110,9 +81,6 @@ class WanTransformerActionBlock(WeightModule):
task,
mm_type,
config,
False,
self.lazy_load,
self.lazy_load_file,
),
]
)
......@@ -121,7 +89,7 @@ class WanTransformerActionBlock(WeightModule):
class WanActionModule(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -129,9 +97,6 @@ class WanActionModule(WeightModule):
self.config = config
self.quant_method = config.get("quant_method", None)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.attn_rms_type = "self_forcing"
self.add_module(
......@@ -139,8 +104,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight",
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -148,8 +111,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight",
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias",
self.lazy_load,
self.lazy_load_file,
),
)
......@@ -158,8 +119,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
......@@ -168,8 +127,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
......@@ -180,8 +137,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
......@@ -191,8 +146,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
......@@ -201,8 +154,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
......@@ -211,8 +162,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -220,8 +169,6 @@ class WanActionModule(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -229,22 +176,18 @@ class WanActionModule(WeightModule):
LN_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias",
self.lazy_load,
self.lazy_load_file,
eps=1e-6,
),
)
class WanActionCrossAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.config.get("sf_config", False):
self.attn_rms_type = "self_forcing"
......@@ -256,8 +199,6 @@ class WanActionCrossAttention(WeightModule):
LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -265,8 +206,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -274,8 +213,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -283,8 +220,6 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
......@@ -292,24 +227,18 @@ class WanActionCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
import os
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
......@@ -13,7 +11,7 @@ from lightx2v.utils.registry_factory import (
class WanTransformerWeights(WeightModule):
def __init__(self, config):
def __init__(self, config, lazy_load_path=None):
super().__init__()
self.blocks_num = config["num_layers"]
self.task = config["task"]
......@@ -23,7 +21,27 @@ class WanTransformerWeights(WeightModule):
assert config.get("dit_quantized") is True
if config.get("do_mm_calib", False):
self.mm_type = "Calib"
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.lazy_load = self.config.get("lazy_load", False)
if not self.lazy_load:
self.lazy_load_file = None
else:
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
self.blocks = WeightModuleList(
[
WanTransformerAttentionBlock(
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
for i in range(self.blocks_num)
]
)
self.register_offload_buffers(config)
self.add_module("blocks", self.blocks)
......@@ -36,35 +54,74 @@ class WanTransformerWeights(WeightModule):
if config["cpu_offload"]:
if config["offload_granularity"] == "block":
self.offload_blocks_num = 2
self.offload_block_buffers = WeightModuleList(
self.offload_block_cuda_buffers = WeightModuleList(
[
WanTransformerAttentionBlock(
i,
self.task,
self.mm_type,
self.config,
True,
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=True,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
for i in range(self.offload_blocks_num)
]
)
self.add_module("offload_block_buffers", self.offload_block_buffers)
self.offload_phase_buffers = None
self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
self.offload_phase_cuda_buffers = None
if self.lazy_load:
self.offload_blocks_num = 2
self.offload_block_cpu_buffers = WeightModuleList(
[
WanTransformerAttentionBlock(
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
)
for i in range(self.offload_blocks_num)
]
)
self.add_module("offload_block_cpu_buffers", self.offload_block_cpu_buffers)
self.offload_phase_cpu_buffers = None
elif config["offload_granularity"] == "phase":
self.offload_phase_buffers = WanTransformerAttentionBlock(
0,
self.task,
self.mm_type,
self.config,
True,
self.offload_phase_cuda_buffers = WanTransformerAttentionBlock(
block_index=0,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=True,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
).compute_phases
self.add_module("offload_phase_buffers", self.offload_phase_buffers)
self.offload_block_buffers = None
def clear(self):
for block in self.blocks:
for phase in block.compute_phases:
phase.clear()
self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
self.offload_block_cuda_buffers = None
if self.lazy_load:
self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
block_index=0,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
).compute_phases
self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
self.offload_block_cpu_buffers = None
def non_block_weights_to_cuda(self):
self.norm.to_cuda()
......@@ -84,23 +141,23 @@ class WanTransformerAttentionBlock(WeightModule):
task,
mm_type,
config,
is_offload_buffer=False,
create_cuda_buffer=False,
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=False,
lazy_load_file=None,
):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.is_offload_buffer = is_offload_buffer
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.quant_method = config.get("quant_method", None)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.compute_phases = WeightModuleList(
[
......@@ -110,7 +167,8 @@ class WanTransformerAttentionBlock(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -120,7 +178,8 @@ class WanTransformerAttentionBlock(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -130,7 +189,8 @@ class WanTransformerAttentionBlock(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -148,9 +208,10 @@ class WanSelfAttention(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
lazy_load,
lazy_load_file,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
):
super().__init__()
self.block_index = block_index
......@@ -171,7 +232,8 @@ class WanSelfAttention(WeightModule):
"modulation",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.modulation",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -187,7 +249,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.q.weight",
f"{block_prefix}.{self.block_index}.self_attn.q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -198,7 +261,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.k.weight",
f"{block_prefix}.{self.block_index}.self_attn.k.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -208,7 +272,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.v.weight",
f"{block_prefix}.{self.block_index}.self_attn.v.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -218,7 +283,8 @@ class WanSelfAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.self_attn.o.weight",
f"{block_prefix}.{self.block_index}.self_attn.o.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -227,7 +293,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -236,7 +303,8 @@ class WanSelfAttention(WeightModule):
"self_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -278,7 +346,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_weight",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm1.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -287,7 +356,8 @@ class WanSelfAttention(WeightModule):
"smooth_norm1_bias",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm1.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -302,9 +372,10 @@ class WanCrossAttention(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
lazy_load,
lazy_load_file,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
):
super().__init__()
self.block_index = block_index
......@@ -324,7 +395,8 @@ class WanCrossAttention(WeightModule):
LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -334,7 +406,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -344,7 +417,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -354,7 +428,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -364,7 +439,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -373,7 +449,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -382,7 +459,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -395,7 +473,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -405,7 +484,8 @@ class WanCrossAttention(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -414,7 +494,8 @@ class WanCrossAttention(WeightModule):
"cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -430,9 +511,10 @@ class WanFFN(WeightModule):
task,
mm_type,
config,
is_offload_buffer,
lazy_load,
lazy_load_file,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
):
super().__init__()
self.block_index = block_index
......@@ -453,7 +535,8 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.0.weight",
f"{block_prefix}.{self.block_index}.ffn.0.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -463,7 +546,8 @@ class WanFFN(WeightModule):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.ffn.2.weight",
f"{block_prefix}.{self.block_index}.ffn.2.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -474,7 +558,8 @@ class WanFFN(WeightModule):
"smooth_norm2_weight",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.weight",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -483,7 +568,8 @@ class WanFFN(WeightModule):
"smooth_norm2_bias",
TENSOR_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.affine_norm3.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......
......@@ -15,7 +15,7 @@ class WanVaceTransformerWeights(WanTransformerWeights):
self.patch_size = (1, 2, 2)
self.register_offload_buffers(config)
self.vace_blocks = WeightModuleList(
[WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
[WanVaceTransformerAttentionBlock(self.config["vace_layers"][i], i, self.task, self.mm_type, self.config, False, False, "vace_blocks") for i in range(len(self.config["vace_layers"]))]
)
self.add_module("vace_blocks", self.vace_blocks)
self.add_module(
......@@ -27,23 +27,17 @@ class WanVaceTransformerWeights(WanTransformerWeights):
super().register_offload_buffers(config)
if config["cpu_offload"]:
if config["offload_granularity"] == "block":
self.vace_offload_block_buffers = WeightModuleList(
self.vace_offload_block_cuda_buffers = WeightModuleList(
[
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, "vace_blocks"),
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
WanVaceTransformerAttentionBlock(self.config["vace_layers"][0], 0, self.task, self.mm_type, self.config, True, False, "vace_blocks"),
]
)
self.add_module("vace_offload_block_buffers", self.vace_offload_block_buffers)
self.vace_offload_phase_buffers = None
self.add_module("vace_offload_block_cuda_buffers", self.vace_offload_block_cuda_buffers)
self.vace_offload_phase_cuda_buffers = None
elif config["offload_granularity"] == "phase":
raise NotImplementedError
def clear(self):
super().clear()
for vace_block in self.vace_blocks:
for vace_phase in vace_block.compute_phases:
vace_phase.clear()
def non_block_weights_to_cuda(self):
super().non_block_weights_to_cuda()
self.vace_patch_embedding.to_cuda()
......@@ -54,15 +48,16 @@ class WanVaceTransformerWeights(WanTransformerWeights):
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
def __init__(self, base_block_idx, block_index, task, mm_type, config, is_offload_buffer, block_prefix):
super().__init__(block_index, task, mm_type, config, is_offload_buffer, block_prefix)
def __init__(self, base_block_idx, block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix):
super().__init__(block_index, task, mm_type, config, create_cuda_buffer, create_cpu_buffer, block_prefix)
if base_block_idx == 0:
self.compute_phases[0].add_module(
"before_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.before_proj.weight",
f"{block_prefix}.{self.block_index}.before_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......@@ -73,7 +68,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.after_proj.weight",
f"{block_prefix}.{self.block_index}.after_proj.bias",
is_offload_buffer,
create_cuda_buffer,
create_cpu_buffer,
self.lazy_load,
self.lazy_load_file,
),
......
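Note (not part of the diff): register_offload_buffers above allocates exactly two CUDA block buffers, which matches a ping-pong prefetch pattern for block-granularity CPU offload. The streaming manager itself (weights_stream_mgr) is not shown in this commit; the sketch below only illustrates the general idea on plain tensors, with hypothetical helper names:

import torch

def run_blocks_double_buffered(cpu_blocks, cuda_buffers, compute):
    # cpu_blocks: list of pinned CPU tensors standing in for block weights.
    # cuda_buffers: two preallocated CUDA tensors of the same shape.
    copy_stream = torch.cuda.Stream()
    cuda_buffers[0].copy_(cpu_blocks[0], non_blocking=True)  # preload block 0
    for i in range(len(cpu_blocks)):
        cur, nxt = cuda_buffers[i % 2], cuda_buffers[(i + 1) % 2]
        if i + 1 < len(cpu_blocks):
            # Don't overwrite nxt until any compute that used it has finished.
            copy_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(copy_stream):
                nxt.copy_(cpu_blocks[i + 1], non_blocking=True)
        compute(cur)  # overlaps with the prefetch running on copy_stream
        # Ensure the prefetch is done before the next iteration computes on nxt.
        torch.cuda.current_stream().wait_stream(copy_stream)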
......@@ -41,7 +41,8 @@ class DefaultRunner(BaseRunner):
self.load_model()
elif self.config.get("lazy_load", False):
assert self.config.get("cpu_offload", False)
self.model.set_scheduler(self.scheduler) # set scheduler to model
if hasattr(self, "model"):
self.model.set_scheduler(self.scheduler) # set scheduler to model
if self.config["task"] == "i2v":
self.run_input_encoder = self._run_input_encoder_local_i2v
elif self.config["task"] == "flf2v":
......@@ -184,11 +185,6 @@ class DefaultRunner(BaseRunner):
del self.inputs
self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
self.model.transformer_infer.weights_stream_mgr.clear()
if hasattr(self.model.transformer_weights, "clear"):
self.model.transformer_weights.clear()
self.model.pre_weight.clear()
del self.model
if self.config.get("do_mm_calib", False):
calib_path = os.path.join(os.getcwd(), "calib.pt")
......@@ -279,6 +275,7 @@ class DefaultRunner(BaseRunner):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer()
self.model.set_scheduler(self.scheduler)
self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
......
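Note (not part of the diff): the hasattr guard added above covers the case where lazy_load (together with cpu_offload) defers building the transformer, so there is no self.model to attach the scheduler to at init time; the scheduler is attached again once load_transformer() runs on the denoising path. A minimal sketch of that control flow, with stubbed-out loading and assuming this reading of the diff:

class LazyLoadRunnerSketch:
    # Illustrative only; stands in for DefaultRunner's init/run split.
    def __init__(self, config, scheduler):
        self.config = config
        self.scheduler = scheduler
        if not config.get("lazy_load", False):
            self.model = self.load_transformer()
        if hasattr(self, "model"):            # skipped under lazy_load
            self.model.set_scheduler(self.scheduler)

    def load_transformer(self):
        raise NotImplementedError             # real loading lives elsewhere

    def run_denoise(self):
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()   # (re)built on demand
            self.model.set_scheduler(self.scheduler)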
......@@ -24,6 +24,7 @@ from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import load_weights, remove_substrings_from_keys
from lightx2v_platform.base.global_var import AI_DEVICE
@RUNNER_REGISTER("wan2.2_animate")
......@@ -182,7 +183,7 @@ class WanAnimateRunner(WanRunner):
],
dim=1,
)
.cuda()
.to(AI_DEVICE)
.unsqueeze(0)
)
mask_pixel_values = 1 - mask_pixel_values
......@@ -210,7 +211,7 @@ class WanAnimateRunner(WanRunner):
],
dim=1,
)
.cuda()
.to(AI_DEVICE)
.unsqueeze(0)
)
msk_reft = self.get_i2v_mask(self.latent_t, self.latent_h, self.latent_w, self.mask_reft_len)
......@@ -330,7 +331,7 @@ class WanAnimateRunner(WanRunner):
dtype=GET_DTYPE(),
) # c t h w
else:
refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().cuda() # c t h w
refer_t_pixel_values = self.gen_video[0, :, -self.config["refert_num"] :].transpose(0, 1).clone().detach().to(AI_DEVICE) # c t h w
bg_pixel_values, mask_pixel_values = None, None
if self.config["replace_flag"] if "replace_flag" in self.config else False:
......@@ -408,8 +409,8 @@ class WanAnimateRunner(WanRunner):
return model
def load_encoders(self):
motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).cuda()
motion_encoder = Generator(size=512, style_dim=512, motion_dim=20).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4).eval().requires_grad_(False).to(GET_DTYPE()).to(AI_DEVICE)
motion_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["motion_encoder"]), "motion_encoder.")
face_weight_dict = remove_substrings_from_keys(load_weights(self.config["model_path"], include_keys=["face_encoder"]), "face_encoder.")
motion_encoder.load_state_dict(motion_weight_dict)
......
......@@ -435,7 +435,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def process_single_mask(self, mask_file):
mask_img = load_image(mask_file)
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
mask_img = TF.to_tensor(mask_img).sub_(0.5).div_(0.5).unsqueeze(0).to(AI_DEVICE)
if mask_img.shape[1] == 3: # If it is an RGB three-channel image
mask_img = mask_img[:, :1] # Only take the first channel
......
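Note (not part of the diff): the runner hunks above and below replace hard-coded .cuda() calls with .to(AI_DEVICE), imported from lightx2v_platform.base.global_var. The definition of AI_DEVICE is not included in this commit; the snippet below only illustrates the pattern with a placeholder value, so the same runner code can target non-CUDA backends:

import torch

# Placeholder definition; the real AI_DEVICE comes from
# lightx2v_platform.base.global_var and may point at another backend.
AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

mask = torch.ones(1, 3, 64, 64)
mask = mask.to(AI_DEVICE)                        # instead of mask.cuda()
keyboard_cond = torch.tensor([0, 1, 0]).to(AI_DEVICE)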
......@@ -13,6 +13,7 @@ from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class VAEWrapper:
......@@ -90,8 +91,8 @@ def get_current_action(mode="universal"):
flag = 1
except Exception as e:
pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda()
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).to(AI_DEVICE)
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
elif mode == "gta_drive":
print()
print("-" * 30)
......@@ -118,8 +119,8 @@ def get_current_action(mode="universal"):
flag = 1
except Exception as e:
pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda()
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda()
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).to(AI_DEVICE)
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).to(AI_DEVICE)
elif mode == "templerun":
print()
print("-" * 30)
......@@ -142,7 +143,7 @@ def get_current_action(mode="universal"):
flag = 1
except Exception as e:
pass
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).to(AI_DEVICE)
if mode != "templerun":
return {"mouse": mouse_cond, "keyboard": keyboard_cond}
......
......@@ -164,7 +164,7 @@ class WanRunner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.init_device)
vae_device = torch.device(AI_DEVICE)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
......@@ -178,7 +178,7 @@ class WanRunner(DefaultRunner):
}
if self.config.get("use_tae", False):
tae_path = find_torch_model_path(self.config, "tae_path", self.tiny_vae_name)
vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to("cuda")
vae_decoder = self.tiny_vae_cls(vae_path=tae_path, device=self.init_device, need_scaled=self.config.get("need_scaled", False)).to(AI_DEVICE)
else:
vae_decoder = self.vae_cls(**vae_config)
return vae_decoder
......
......@@ -2,6 +2,8 @@ from typing import List, Tuple, Union
import torch
from lightx2v_platform.base.global_var import AI_DEVICE
def _to_tuple(x, dim=2):
if isinstance(x, int):
......