Unverified commit f3b4ba24, authored by Gu Shiqiao and committed via GitHub

[Recon] Reconstruct disk-cpu-cuda offload system (#578)

parent 67d6c6c1
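The diff below wires weights through a disk -> pinned-CPU -> CUDA pipeline: per-block safetensors files are read on a background thread pool into double-buffered pinned CPU tensors, which are then copied into CUDA buffers on a dedicated load stream. As a minimal sketch, the configuration keys the new code reads are collected here (key names are taken from the config.get(...) calls in this commit; the values are illustrative only, and a real config contains many more entries):

offload_config = {
    "cpu_offload": True,
    "offload_granularity": "phase",   # lazy disk prefetch is only wired up for "phase"
    "lazy_load": True,                # stream weights from per-block safetensors on disk
    "num_disk_workers": 4,            # thread-pool size used by init_lazy_load (default 4)
    "warm_up_cpu_buffers": False,     # optionally pre-read every block once at startup
}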
from concurrent.futures import ThreadPoolExecutor
import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm
from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
......@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
self.offload_granularity = offload_granularity
self.init_stream = torch_device_module.Stream(priority=0)
self.need_init_first_buffer = True
self.lazy_load = False
torch_version = parse(torch.__version__.split("+")[0])
if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
self.cuda_load_stream = torch_device_module.Stream(priority=1)
......@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
def init_first_buffer(self, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.init_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
else:
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
......@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
......@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
def swap_phases(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
@ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
def warm_up_cpu_buffers(self, blocks_num):
logger.info("🔥 Warming up cpu buffers...")
for i in tqdm(range(blocks_num)):
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(0, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(1, None)
logger.info("✅ CPU buffers warm-up completed.")
def init_lazy_load(self, num_workers=6):
self.lazy_load = True
self.executor = ThreadPoolExecutor(max_workers=num_workers)
self.prefetch_futures = []
self.prefetch_block_idx = -1
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
self.prefetch_block_idx = block_idx
self.prefetch_futures = []
for phase in self.cpu_buffers[1]:
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
import time
wait_start = time.time()
already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
f.result()
wait_time = time.time() - wait_start
logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def shutdown(self, wait=True):
"""Shutdown the thread pool executor and wait for all pending tasks to complete."""
if hasattr(self, "executor") and self.executor is not None:
# Wait for all pending futures to complete before shutting down
if hasattr(self, "prefetch_futures"):
for f in self.prefetch_futures:
try:
if not f.done():
f.result()
except Exception:
pass
self.executor.shutdown(wait=wait)
self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.")
def __del__(self):
"""Cleanup method to ensure executor is shut down when object is destroyed."""
try:
if hasattr(self, "executor") and self.executor is not None:
self.executor.shutdown(wait=False)
except Exception:
pass
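The manager above keeps two sets of pinned CPU phase buffers: cpu_buffers[0] feeds the CUDA copies while the thread pool refills cpu_buffers[1] from disk, and swap_cpu_buffers waits on the outstanding futures and flips the roles. The following self-contained toy (illustrative names, not the project's API) shows the same double-buffer prefetch pattern in isolation:

from concurrent.futures import ThreadPoolExecutor
import time

class ToyPhase:
    """Stands in for a compute phase whose weights live in pinned CPU memory."""
    def __init__(self):
        self.block = None
    def load_from_disk(self, block_idx):
        time.sleep(0.01)            # placeholder for the safetensors read
        self.block = block_idx

executor = ThreadPoolExecutor(max_workers=4)
buffers = [[ToyPhase(), ToyPhase()], [ToyPhase(), ToyPhase()]]  # [active, refilling]
futures = []

def start_prefetch(block_idx):
    futures.clear()
    for phase in buffers[1]:        # refill the back buffer in the background
        futures.append(executor.submit(phase.load_from_disk, block_idx))

def swap():
    for f in futures:               # wait until the back buffer is fully loaded
        f.result()
    buffers[0], buffers[1] = buffers[1], buffers[0]

start_prefetch(0); swap()           # prime the active buffer with block 0
for block_idx in range(1, 4):
    start_prefetch(block_idx)       # disk I/O overlaps the "compute" below
    for phase in buffers[0]:
        assert phase.block == block_idx - 1   # consume the current block here
    swap()                          # the freshly loaded block becomes active
executor.shutdown()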
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
......@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
def _get_source_tensor(self, source_name, weight_dict=None):
if self.lazy_load:
return self.lazy_load_file.get_tensor(source_name)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
return lazy_load_file.get_tensor(source_name)
return weight_dict[source_name]
def _create_pin_tensor(self, tensor, transpose=False):
......@@ -145,15 +149,18 @@ class MMWeight(MMWeightTemplate):
self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
if self.bias_name is not None:
bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
self.pin_bias = self._create_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias = self._create_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
......@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
......@@ -208,9 +211,16 @@ class MMWeight(MMWeightTemplate):
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
......@@ -283,9 +293,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self._load_default_tensors(weight_dict)
def _load_cuda_buffers(self, weight_dict):
source = self.lazy_load_file if self.lazy_load else weight_dict
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
else:
source = weight_dict
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
def _get_cuda_tensor_pair(self, source, is_lazy):
if is_lazy:
......@@ -318,30 +334,38 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def _get_cpu_pin_tensor_pair(self, source, is_lazy):
if is_lazy:
weight_tensor = source.get_tensor(self.weight_name)
scale_tensor = source.get_tensor(self.weight_scale_name)
scale_dtype = torch.float
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
weight_tensor = source.get_tensor(self.weight_name)
scale_tensor = source.get_tensor(self.weight_scale_name)
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
else:
weight_tensor = source[self.weight_name]
scale_tensor = source[self.weight_scale_name]
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
return pin_weight, pin_scale
def _get_cpu_pin_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
bias_tensor = source.get_tensor(self.bias_name)
if not self.bias_force_fp32:
bias_tensor = bias_tensor.to(self.infer_dtype)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
bias_tensor = source.get_tensor(self.bias_name)
if not self.bias_force_fp32:
bias_tensor = bias_tensor.to(self.infer_dtype)
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
else:
bias_tensor = source[self.bias_name]
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
def _create_pin_tensor(self, tensor, dtype=None):
dtype = dtype or tensor.dtype
......@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
if self.weight_need_transpose:
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_tensor
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
......@@ -661,9 +674,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_scale_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
@MM_WEIGHT_REGISTER("fp8-vllm")
......
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
......@@ -53,9 +55,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
if name is None:
return None
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[name]
return tensor
......@@ -151,24 +155,28 @@ class LNWeightTemplate(metaclass=ABCMeta):
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias.copy_(bias_tensor)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
......
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
......@@ -46,9 +48,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.weight_name]
return tensor
......@@ -107,9 +111,10 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
......
import os
import re
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
......@@ -39,9 +41,11 @@ class DefaultTensor:
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
tensor = self.lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.tensor_name]
return tensor
......@@ -92,7 +96,8 @@ class DefaultTensor:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
del tensor
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
from loguru import logger
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
flash_attn_varlen_qkvpacked_func = None
logger.info("flash_attn_varlen_qkvpacked_func not available")
try:
from flash_attn.bert_padding import pad_input, unpad_input
except ImportError:
pad_input = None
unpad_input = None
logger.info("flash_attn.bert_padding not available")
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
......
......@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
if offload_granularity != "model":
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load and offload_granularity == "phase":
self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
......@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
def infer_with_phases_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
if self.lazy_load:
next_prefetch = (block_idx + 1) % len(blocks)
self.offload_manager.start_prefetch_block(next_prefetch)
x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
if self.clean_cuda_cache:
del (
......@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self.offload_manager.init_first_buffer(blocks)
next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
if self.lazy_load:
if phase_idx == self.phases_num - 1:
self.offload_manager.swap_cpu_buffers()
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
with torch_device_module.stream(self.offload_manager.compute_stream):
x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
......
......@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)
if self.clean_cuda_cache:
del norm1_out, norm1_weight, norm1_bias
del norm1_out, shift_msa, scale_msa
torch.cuda.empty_cache()
if self.config["seq_parallel"]:
......@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
y = phase.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
del norm2_out, x, norm2_weight, norm2_bias
del norm2_out, x
torch.cuda.empty_cache()
y = torch.nn.functional.gelu(y, approximate="tanh")
if self.clean_cuda_cache:
......
......@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
rope_func,
):
seq_len = cos_sin_cache.size(0)
x_q = torch.empty_like(xq)
x_k = torch.empty_like(xk)
xq_output_chunks = []
xk_output_chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
xq_chunk = xq[start:end]
xk_chunk = xk[start:end]
cos_sin_chunk = cos_sin_cache[start:end]
xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
xq_output_chunks.append(xq_chunk)
xk_output_chunks.append(xk_chunk)
torch.cuda.empty_cache()
x_q = torch.cat(xq_output_chunks, dim=0)
del xq_output_chunks
torch.cuda.empty_cache()
x_k = torch.cat(xk_output_chunks, dim=0)
del xk_output_chunks
torch.cuda.empty_cache()
return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
xq_chunk_out, xk_chunk_out = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
x_q[start:end].copy_(xq_chunk_out, non_blocking=True)
x_k[start:end].copy_(xk_chunk_out, non_blocking=True)
del xq_chunk_out, xk_chunk_out
target_dtype = GET_DTYPE()
if x_q.dtype != target_dtype:
x_q = x_q.to(target_dtype)
if x_k.dtype != target_dtype:
x_k = x_k.to(target_dtype)
return x_q, x_k
def apply_wan_rope_with_flashinfer(
......
......@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_files = [safetensors_path]
if self.lazy_load:
assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
self.lazy_load_path = safetensors_files[0]
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}")
weight_dict = {}
for file_path in safetensors_files:
......@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config.get("dit_quantized_ckpt", None):
safetensors_path = self.config["dit_quantized_ckpt"]
else:
......@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_path = os.path.dirname(safetensors_path)
if self.lazy_load:
assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
self.lazy_load_path = safetensors_files[0]
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path")
weight_dict = {}
for safetensor_path in safetensors_files:
......@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
if self.lazy_load:
self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
self._init_offload_manager()
def _init_offload_manager(self):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
if self.lazy_load:
self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
if self.config.get("warm_up_cpu_buffers", False):
self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
......
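The lazy-load branch above expects the checkpoint directory to already be split into block_{i}.safetensors files plus a non_block.safetensors for everything else, with block indices recovered from parameter names such as blocks.{i}.... As a hedged illustration (not part of this commit; the "blocks." prefix and paths are assumptions taken from the key patterns used here), a single-file checkpoint could be split into that layout like so:

import os
from collections import defaultdict
from safetensors import safe_open
from safetensors.torch import save_file

def split_checkpoint(src_file, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    groups = defaultdict(dict)
    with safe_open(src_file, framework="pt", device="cpu") as f:
        for name in f.keys():
            if name.startswith("blocks."):
                group = f"block_{name.split('.')[1]}"   # e.g. blocks.3.attn.q.weight -> block_3
            else:
                group = "non_block"
            groups[group][name] = f.get_tensor(name)
    for group, tensors in groups.items():
        save_file(tensors, os.path.join(out_dir, f"{group}.safetensors"))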
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import (
ATTN_WEIGHT_REGISTER,
......@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
if config.get("do_mm_calib", False):
self.mm_type = "Calib"
self.lazy_load = self.config.get("lazy_load", False)
if not self.lazy_load:
self.lazy_load_file = None
else:
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
self.blocks = WeightModuleList(
[
WanTransformerAttentionBlock(
......@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
lazy_load_path=lazy_load_path,
)
for i in range(self.blocks_num)
]
)
self.register_offload_buffers(config)
self.register_offload_buffers(config, lazy_load_path)
self.add_module("blocks", self.blocks)
# non blocks weights
......@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def register_offload_buffers(self, config):
def register_offload_buffers(self, config, lazy_load_path):
if config["cpu_offload"]:
if config["offload_granularity"] == "block":
self.offload_blocks_num = 2
......@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
lazy_load_path=lazy_load_path,
)
for i in range(self.offload_blocks_num)
]
......@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
lazy_load_path=lazy_load_path,
)
for i in range(self.offload_blocks_num)
]
......@@ -104,22 +98,27 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
lazy_load_path=lazy_load_path,
).compute_phases
self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
self.offload_block_cuda_buffers = None
if self.lazy_load:
self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
block_index=0,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
).compute_phases
self.offload_phase_cpu_buffers = WeightModuleList(
[
WanTransformerAttentionBlock(
block_index=i,
task=self.task,
mm_type=self.mm_type,
config=self.config,
create_cuda_buffer=False,
create_cpu_buffer=True,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_path=lazy_load_path,
).compute_phases
for i in range(2)
]
)
self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
self.offload_block_cpu_buffers = None
......@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=False,
lazy_load_file=None,
lazy_load_path=None,
):
super().__init__()
self.block_index = block_index
......@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
self.quant_method = config.get("quant_method", None)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.lazy_load:
self.lazy_load_file = lazy_load_path
else:
self.lazy_load_file = None
self.compute_phases = WeightModuleList(
[
......
......@@ -185,7 +185,19 @@ class DefaultRunner(BaseRunner):
del self.inputs
self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.model
if hasattr(self.model, "model") and len(self.model.model) == 2: # MultiModelStruct
for model in self.model.model:
if hasattr(model.transformer_infer, "offload_manager"):
del model.transformer_infer.offload_manager
torch.cuda.empty_cache()
gc.collect()
del model
else:
if hasattr(self.model.transformer_infer, "offload_manager"):
del self.model.transformer_infer.offload_manager
torch.cuda.empty_cache()
gc.collect()
del self.model
if self.config.get("do_mm_calib", False):
calib_path = os.path.join(os.getcwd(), "calib.pt")
torch.save(CALIB, calib_path)
......
......@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
self.to_cuda(model_index=1)
self.cur_model_index = 1
def infer(self, inputs):
self.get_current_model_index()
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.model[self.cur_model_index].infer(inputs)
else:
if self.model[self.cur_model_index] is not None:
self.model[self.cur_model_index].infer(inputs)
else:
if self.cur_model_index == 0:
high_noise_model = WanDistillModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_noise_model.set_scheduler(self.scheduler)
self.model[0] = high_noise_model
self.model[0].infer(inputs)
elif self.cur_model_index == 1:
low_noise_model = WanDistillModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_noise_model.set_scheduler(self.scheduler)
self.model[1] = low_noise_model
self.model[1].infer(inputs)
@RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner):
......@@ -101,61 +130,68 @@ class Wan22MoeDistillRunner(WanDistillRunner):
raise FileNotFoundError(f"Low Noise Model does not find")
def load_transformer(self):
use_high_lora, use_low_lora = False, False
if self.config.get("lora_configs") and self.config["lora_configs"]:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
use_high_lora = True
elif lora_config.get("name", "") == "low_noise_model":
use_low_lora = True
if use_high_lora:
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = high_lora_wrapper.load_lora(lora_path)
high_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
use_high_lora, use_low_lora = False, False
if self.config.get("lora_configs") and self.config["lora_configs"]:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
use_high_lora = True
elif lora_config.get("name", "") == "low_noise_model":
use_low_lora = True
if use_high_lora:
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = high_lora_wrapper.load_lora(lora_path)
high_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
high_noise_model = WanDistillModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
if use_low_lora:
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "low_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = low_lora_wrapper.load_lora(lora_path)
low_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
low_noise_model = WanDistillModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
else:
high_noise_model = WanDistillModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
if use_low_lora:
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "low_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = low_lora_wrapper.load_lora(lora_path)
low_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
low_noise_model = WanDistillModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
model_struct = MultiDistillModelStruct([None, None], self.config, self.config["boundary_step_index"])
model_struct.low_noise_model_path = self.low_noise_model_path
model_struct.high_noise_model_path = self.high_noise_model_path
model_struct.init_device = self.init_device
return model_struct
def init_scheduler(self):
if self.config["feature_caching"] == "NoCaching":
......
......@@ -468,11 +468,37 @@ class MultiModelStruct:
def set_scheduler(self, shared_scheduler):
self.scheduler = shared_scheduler
for model in self.model:
model.set_scheduler(shared_scheduler)
if model is not None:
model.set_scheduler(shared_scheduler)
def infer(self, inputs):
self.get_current_model_index()
self.model[self.cur_model_index].infer(inputs)
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.model[self.cur_model_index].infer(inputs)
else:
if self.model[self.cur_model_index] is not None:
self.model[self.cur_model_index].infer(inputs)
else:
if self.cur_model_index == 0:
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_noise_model.set_scheduler(self.scheduler)
self.model[0] = high_noise_model
self.model[0].infer(inputs)
elif self.cur_model_index == 1:
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_noise_model.set_scheduler(self.scheduler)
self.model[1] = low_noise_model
self.model[1].infer(inputs)
@ProfilingContext4DebugL2("Swtich models in infer_main costs")
def get_current_model_index(self):
......@@ -526,40 +552,47 @@ class Wan22MoeRunner(WanRunner):
def load_transformer(self):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False)
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
base_name = os.path.basename(lora_path)
if base_name.startswith("high"):
lora_wrapper = WanLoraWrapper(high_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
elif base_name.startswith("low"):
lora_wrapper = WanLoraWrapper(low_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
raise ValueError(f"Unsupported LoRA path: {lora_path}")
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False)
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
base_name = os.path.basename(lora_path)
if base_name.startswith("high"):
lora_wrapper = WanLoraWrapper(high_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
elif base_name.startswith("low"):
lora_wrapper = WanLoraWrapper(low_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
raise ValueError(f"Unsupported LoRA path: {lora_path}")
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
else:
model_struct = MultiModelStruct([None, None], self.config, self.config["boundary"])
model_struct.low_noise_model_path = self.low_noise_model_path
model_struct.high_noise_model_path = self.high_noise_model_path
model_struct.init_device = self.init_device
return model_struct
@RUNNER_REGISTER("wan2.2")
......
import asyncio
import threading
import time
from functools import wraps
......@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
_excluded_time_local = threading.local()
def _get_excluded_time_stack():
if not hasattr(_excluded_time_local, "stack"):
_excluded_time_local.stack = []
return _excluded_time_local.stack
class _ProfilingContext:
......@@ -32,11 +40,14 @@ class _ProfilingContext:
def __enter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
_get_excluded_time_stack().append(0.0)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
total_elapsed = time.perf_counter() - self.start_time
excluded = _get_excluded_time_stack().pop()
elapsed = total_elapsed - excluded
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
......@@ -49,11 +60,14 @@ class _ProfilingContext:
async def __aenter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
_get_excluded_time_stack().append(0.0)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
total_elapsed = time.perf_counter() - self.start_time
excluded = _get_excluded_time_stack().pop()
elapsed = total_elapsed - excluded
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
......@@ -103,6 +117,65 @@ class _NullContext:
return func
class _ExcludedProfilingContext:
"""用于标记应该从外层 profiling 中排除的时间段"""
def __init__(self, name=None):
self.name = name
if dist.is_initialized():
self.rank_info = f"Rank {dist.get_rank()}"
else:
self.rank_info = "Single GPU"
def __enter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
stack = _get_excluded_time_stack()
for i in range(len(stack)):
stack[i] += elapsed
if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
return False
async def __aenter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
stack = _get_excluded_time_stack()
for i in range(len(stack)):
stack[i] += elapsed
if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
return False
def __call__(self, func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
async with self:
return await func(*args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return sync_wrapper
class _ProfilingContextL1(_ProfilingContext):
"""Level 1 profiling context with Level1_Log prefix."""
......@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
"""
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext # if user >= 2, enable profiling
ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext
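For context, the excluded-time bookkeeping works by having every profiling context push an accumulator onto a thread-local stack; an excluded span adds its elapsed time to all open accumulators, and each outer context reports total minus excluded. A standalone toy illustration of that mechanism (independent of the classes above, with invented names):

import threading
import time

_local = threading.local()

def _stack():
    if not hasattr(_local, "stack"):
        _local.stack = []
    return _local.stack

class Timed:
    def __init__(self, name):
        self.name = name
    def __enter__(self):
        self.t0 = time.perf_counter()
        _stack().append(0.0)
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        total = time.perf_counter() - self.t0
        excluded = _stack().pop()
        print(f"{self.name}: {total - excluded:.3f}s (excluded {excluded:.3f}s)")

class Excluded:
    def __enter__(self):
        self.t0 = time.perf_counter()
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        dt = time.perf_counter() - self.t0
        stack = _stack()
        for i in range(len(stack)):
            stack[i] += dt          # charge this span to every open timer

with Timed("denoise step"):
    time.sleep(0.05)                # counted by the outer timer
    with Excluded():
        time.sleep(0.10)            # e.g. disk warm-up, subtracted from the outer timer
    time.sleep(0.05)                # counted by the outer timer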