Unverified Commit f3b4ba24 authored by Gu Shiqiao, committed by GitHub

[Recon] Reconstruct disk-cpu-cuda offload system (#578)

parent 67d6c6c1
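For orientation, the reconstructed disk-cpu-cuda offload path is driven by a handful of config keys that appear throughout this diff. The sketch below only collects them in one place; the grouping and the default values are illustrative, not a documented schema.

# Hypothetical config sketch; keys are taken from the code in this commit, defaults are guesses.
offload_config = {
    "cpu_offload": True,             # build the WeightAsyncStreamManager
    "offload_granularity": "phase",  # disk prefetch is only wired up for "phase"
    "lazy_load": True,               # read block weights from per-block safetensors on demand
    "num_disk_workers": 4,           # thread pool size for disk reads
    "warm_up_cpu_buffers": False,    # optionally pre-read every block once at init
    "unload_modules": False,         # rebuild/release whole models between noise stages
}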
from concurrent.futures import ThreadPoolExecutor
import torch
from loguru import logger
from packaging.version import parse
from tqdm import tqdm
from lightx2v.utils.profiler import ExcludedProfilingContext
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
self.offload_granularity = offload_granularity
self.init_stream = torch_device_module.Stream(priority=0)
self.need_init_first_buffer = True
self.lazy_load = False
torch_version = parse(torch.__version__.split("+")[0])
if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
self.cuda_load_stream = torch_device_module.Stream(priority=1)
@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
def init_first_buffer(self, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.init_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
else:
if self.offload_granularity == "block":
self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
with torch_device_module.stream(self.cuda_load_stream):
if hasattr(self, "cpu_buffers"):
self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
else:
self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
def swap_phases(self):
self.cuda_load_stream.synchronize()
self.compute_stream.synchronize()
@ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
def warm_up_cpu_buffers(self, blocks_num):
logger.info("🔥 Warming up cpu buffers...")
for i in tqdm(range(blocks_num)):
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(i, None)
for phase in self.cpu_buffers[0]:
phase.load_state_dict_from_disk(0, None)
for phase in self.cpu_buffers[1]:
phase.load_state_dict_from_disk(1, None)
logger.info("✅ CPU buffers warm-up completed.")
def init_lazy_load(self, num_workers=6):
self.lazy_load = True
self.executor = ThreadPoolExecutor(max_workers=num_workers)
self.prefetch_futures = []
self.prefetch_block_idx = -1
def start_prefetch_block(self, block_idx, adapter_block_idx=None):
self.prefetch_block_idx = block_idx
self.prefetch_futures = []
for phase in self.cpu_buffers[1]:
future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
import time
wait_start = time.time()
already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
f.result()
wait_time = time.time() - wait_start
logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def shutdown(self, wait=True):
"""Shutdown the thread pool executor and wait for all pending tasks to complete."""
if hasattr(self, "executor") and self.executor is not None:
# Wait for all pending futures to complete before shutting down
if hasattr(self, "prefetch_futures"):
for f in self.prefetch_futures:
try:
if not f.done():
f.result()
except Exception:
pass
self.executor.shutdown(wait=wait)
self.executor = None
logger.debug("ThreadPoolExecutor shut down successfully.")
def __del__(self):
"""Cleanup method to ensure executor is shut down when object is destroyed."""
try:
if hasattr(self, "executor") and self.executor is not None:
self.executor.shutdown(wait=False)
except Exception:
pass
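Taken together, the additions above implement a double-buffered disk prefetch: cpu_buffers[0] feeds the current block's CUDA copies while a thread pool fills cpu_buffers[1] with the next block, and swap_cpu_buffers flips the two. A rough usage sketch, assuming the CUDA/CPU buffers have already been attached via init_cuda_buffer/init_cpu_buffer, with run_phases_for_block standing in for the actual compute:

# Sketch only: overlap disk reads for block i+1 with compute on block i.
manager = WeightAsyncStreamManager(offload_granularity="phase")
manager.init_lazy_load(num_workers=4)                            # thread pool for safetensors reads
for block_idx in range(blocks_num):                              # blocks_num comes from the model
    manager.start_prefetch_block((block_idx + 1) % blocks_num)   # fill cpu_buffers[1] in the background
    run_phases_for_block(block_idx)                              # hypothetical compute that consumes cpu_buffers[0]
    manager.swap_cpu_buffers()                                   # wait on the futures, then make the prefetched set current
manager.shutdown(wait=True)                                      # drain the executor before teardown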
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
def _get_source_tensor(self, source_name, weight_dict=None):
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
return lazy_load_file.get_tensor(source_name)
return weight_dict[source_name]
def _create_pin_tensor(self, tensor, transpose=False):
@@ -145,15 +149,18 @@ class MMWeight(MMWeightTemplate):
self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias = self._create_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
@@ -208,9 +211,16 @@ class MMWeight(MMWeightTemplate):
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
@@ -283,9 +293,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self._load_default_tensors(weight_dict)
def _load_cuda_buffers(self, weight_dict):
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
else:
source = weight_dict
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
def _get_cuda_tensor_pair(self, source, is_lazy):
if is_lazy:
@@ -318,30 +334,38 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def _get_cpu_pin_tensor_pair(self, source, is_lazy):
if is_lazy:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
weight_tensor = source.get_tensor(self.weight_name)
scale_tensor = source.get_tensor(self.weight_scale_name)
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
else:
weight_tensor = source[self.weight_name]
scale_tensor = source[self.weight_scale_name]
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
return pin_weight, pin_scale
def _get_cpu_pin_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
bias_tensor = source.get_tensor(self.bias_name)
if not self.bias_force_fp32:
bias_tensor = bias_tensor.to(self.infer_dtype)
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
else:
bias_tensor = source[self.bias_name]
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
def _create_pin_tensor(self, tensor, dtype=None):
dtype = dtype or tensor.dtype
@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
@@ -661,9 +674,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_scale_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
@MM_WEIGHT_REGISTER("fp8-vllm")
...
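All of the lazy-load branches above assume the checkpoint has been split into one safetensors file per transformer block, with block-independent weights kept separately, and that tensor names start with the block prefix and index, so the owning file can be derived from the name. A minimal, self-contained read sketch of that assumed layout (paths and the example tensor name are placeholders):

# Assumed layout:
#   <ckpt_dir>/non_block.safetensors        # embeddings, head, modulation, ...
#   <ckpt_dir>/block_0.safetensors, block_1.safetensors, ...
import os
from safetensors import safe_open

def read_block_tensor(ckpt_dir, name):
    # a name like "blocks.3.<submodule>.<param>" puts the block id in field 1
    block_idx = name.split(".")[1]
    path = os.path.join(ckpt_dir, f"block_{block_idx}.safetensors")
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(name)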
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
@@ -53,9 +55,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
if name is None:
return None
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[name]
return tensor
@@ -151,24 +155,28 @@ class LNWeightTemplate(metaclass=ABCMeta):
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
...
import os
import re
from abc import ABCMeta, abstractmethod
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
@@ -46,9 +48,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.weight_name]
return tensor
@@ -107,9 +111,10 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
...
import os
import re
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
@@ -39,9 +41,11 @@ class DefaultTensor:
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.tensor_name]
return tensor
@@ -92,7 +96,8 @@ class DefaultTensor:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
del tensor
import torch
from einops import rearrange
from loguru import logger
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
flash_attn_varlen_qkvpacked_func = None
logger.info("flash_attn_varlen_qkvpacked_func not available")
try:
from flash_attn.bert_padding import pad_input, unpad_input
except ImportError:
pad_input = None
unpad_input = None
logger.info("flash_attn.bert_padding not available")
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
...
@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
if offload_granularity != "model":
self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load and offload_granularity == "phase":
self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))
def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
def infer_with_phases_offload(self, blocks, x, pre_infer_out):
for block_idx in range(len(blocks)):
self.block_idx = block_idx
if self.lazy_load:
next_prefetch = (block_idx + 1) % len(blocks)
self.offload_manager.start_prefetch_block(next_prefetch)
x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
if self.clean_cuda_cache:
del (
@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self.offload_manager.init_first_buffer(blocks)
next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
next_phase_idx = (phase_idx + 1) % self.phases_num
if self.lazy_load:
if phase_idx == self.phases_num - 1:
self.offload_manager.swap_cpu_buffers()
self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
with torch_device_module.stream(self.offload_manager.compute_stream):
x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
...
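Pieced together from the hunks above, each phase step now overlaps three resources: the thread pool pulls the next block from disk into pinned CPU buffers, cuda_load_stream copies the next phase from pinned CPU to the GPU, and compute_stream runs the current phase. A simplified schematic of that schedule, with manager, phases_num and infer_phase standing in for the attributes used above, and with first-buffer init, adapter indices and cache cleanup omitted:

# Schematic only; mirrors infer_with_phases_offload / infer_phases above.
for block_idx in range(len(blocks)):
    if manager.lazy_load:
        manager.start_prefetch_block((block_idx + 1) % len(blocks))  # disk -> cpu_buffers[1]
    for phase_idx in range(phases_num):
        next_block = (block_idx + 1) % len(blocks) if phase_idx == phases_num - 1 else block_idx
        next_phase = (phase_idx + 1) % phases_num
        if manager.lazy_load and phase_idx == phases_num - 1:
            manager.swap_cpu_buffers()                               # prefetched CPU set becomes current
        manager.prefetch_phase(next_block, next_phase, blocks)       # pinned CPU -> CUDA on the load stream
        with torch_device_module.stream(manager.compute_stream):
            x = infer_phase(phase_idx, manager.cuda_buffers[phase_idx], x)
        manager.swap_phases()                                        # synchronize load and compute streams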
@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)
if self.clean_cuda_cache:
del norm1_out, shift_msa, scale_msa
torch.cuda.empty_cache()
if self.config["seq_parallel"]:
@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
y = phase.ffn_0.apply(norm2_out)
if self.clean_cuda_cache:
del norm2_out, x
torch.cuda.empty_cache()
y = torch.nn.functional.gelu(y, approximate="tanh")
if self.clean_cuda_cache:
...
@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
rope_func,
):
seq_len = cos_sin_cache.size(0)
x_q = torch.empty_like(xq)
x_k = torch.empty_like(xk)
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
xq_chunk = xq[start:end]
xk_chunk = xk[start:end]
cos_sin_chunk = cos_sin_cache[start:end]
xq_chunk_out, xk_chunk_out = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
x_q[start:end].copy_(xq_chunk_out, non_blocking=True)
x_k[start:end].copy_(xk_chunk_out, non_blocking=True)
del xq_chunk_out, xk_chunk_out
torch.cuda.empty_cache()
target_dtype = GET_DTYPE()
if x_q.dtype != target_dtype:
x_q = x_q.to(target_dtype)
if x_k.dtype != target_dtype:
x_k = x_k.to(target_dtype)
return x_q, x_k
def apply_wan_rope_with_flashinfer(
...
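The rewrite above replaces the old collect-then-concatenate chunking with preallocated full-size outputs that each processed chunk is copied into, so peak memory no longer holds both the chunk list and the concatenated result. A small self-contained illustration of the same pattern, with a toy elementwise transform standing in for rope_func:

import torch

def chunked_apply(x, chunk_size, fn):
    out = torch.empty_like(x)                   # one full-size allocation up front
    for start in range(0, x.size(0), chunk_size):
        end = min(start + chunk_size, x.size(0))
        out[start:end].copy_(fn(x[start:end]))  # write each processed chunk in place
    return out

y = chunked_apply(torch.randn(10, 4), chunk_size=3, fn=torch.sin)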
@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_files = [safetensors_path]
if self.lazy_load:
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}")
weight_dict = {}
for file_path in safetensors_files:
@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
if self.config.get("dit_quantized_ckpt", None):
safetensors_path = self.config["dit_quantized_ckpt"]
else:
@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_path = os.path.dirname(safetensors_path)
if self.lazy_load:
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path")
weight_dict = {}
for safetensor_path in safetensors_files:
@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
self.post_infer = self.post_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
if hasattr(self.transformer_infer, "offload_manager"):
self._init_offload_manager()
def _init_offload_manager(self):
self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
if self.lazy_load:
self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
if self.config.get("warm_up_cpu_buffers", False):
self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
...
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.utils.registry_factory import (
ATTN_WEIGHT_REGISTER,
@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
if config.get("do_mm_calib", False):
self.mm_type = "Calib"
self.lazy_load = self.config.get("lazy_load", False)
self.blocks = WeightModuleList(
[
WanTransformerAttentionBlock(
@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False,
block_prefix="blocks",
lazy_load=self.lazy_load,
lazy_load_path=lazy_load_path,
)
for i in range(self.blocks_num)
]
)
self.register_offload_buffers(config, lazy_load_path)
self.add_module("blocks", self.blocks)
# non blocks weights
@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def register_offload_buffers(self, config, lazy_load_path):
if config["cpu_offload"]: if config["cpu_offload"]:
if config["offload_granularity"] == "block": if config["offload_granularity"] == "block":
self.offload_blocks_num = 2 self.offload_blocks_num = 2
...@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule): ...@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False, create_cpu_buffer=False,
block_prefix="blocks", block_prefix="blocks",
lazy_load=self.lazy_load, lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file, lazy_load_path=lazy_load_path,
) )
for i in range(self.offload_blocks_num) for i in range(self.offload_blocks_num)
] ]
...@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule): ...@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=True, create_cpu_buffer=True,
block_prefix="blocks", block_prefix="blocks",
lazy_load=self.lazy_load, lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file, lazy_load_path=lazy_load_path,
) )
for i in range(self.offload_blocks_num) for i in range(self.offload_blocks_num)
] ]
...@@ -104,22 +98,27 @@ class WanTransformerWeights(WeightModule): ...@@ -104,22 +98,27 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer=False, create_cpu_buffer=False,
block_prefix="blocks", block_prefix="blocks",
lazy_load=self.lazy_load, lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file, lazy_load_path=lazy_load_path,
).compute_phases ).compute_phases
self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers) self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
self.offload_block_cuda_buffers = None self.offload_block_cuda_buffers = None
if self.lazy_load: if self.lazy_load:
self.offload_phase_cpu_buffers = WanTransformerAttentionBlock( self.offload_phase_cpu_buffers = WeightModuleList(
block_index=0, [
task=self.task, WanTransformerAttentionBlock(
mm_type=self.mm_type, block_index=i,
config=self.config, task=self.task,
create_cuda_buffer=False, mm_type=self.mm_type,
create_cpu_buffer=True, config=self.config,
block_prefix="blocks", create_cuda_buffer=False,
lazy_load=self.lazy_load, create_cpu_buffer=True,
lazy_load_file=self.lazy_load_file, block_prefix="blocks",
).compute_phases lazy_load=self.lazy_load,
lazy_load_path=lazy_load_path,
).compute_phases
for i in range(2)
]
)
self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers) self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
self.offload_block_cpu_buffers = None self.offload_block_cpu_buffers = None
...@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
create_cpu_buffer=False, create_cpu_buffer=False,
block_prefix="blocks", block_prefix="blocks",
lazy_load=False, lazy_load=False,
lazy_load_file=None, lazy_load_path=None,
): ):
super().__init__() super().__init__()
self.block_index = block_index self.block_index = block_index
...@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule): ...@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
self.quant_method = config.get("quant_method", None) self.quant_method = config.get("quant_method", None)
self.lazy_load = lazy_load self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file if self.lazy_load:
self.lazy_load_file = lazy_load_path
else:
self.lazy_load_file = None
self.compute_phases = WeightModuleList(
[
...
@@ -185,7 +185,19 @@ class DefaultRunner(BaseRunner):
del self.inputs
self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model, "model") and len(self.model.model) == 2:  # MultiModelStruct
for model in self.model.model:
if hasattr(model.transformer_infer, "offload_manager"):
del model.transformer_infer.offload_manager
torch.cuda.empty_cache()
gc.collect()
del model
else:
if hasattr(self.model.transformer_infer, "offload_manager"):
del self.model.transformer_infer.offload_manager
torch.cuda.empty_cache()
gc.collect()
del self.model
if self.config.get("do_mm_calib", False): if self.config.get("do_mm_calib", False):
calib_path = os.path.join(os.getcwd(), "calib.pt") calib_path = os.path.join(os.getcwd(), "calib.pt")
torch.save(CALIB, calib_path) torch.save(CALIB, calib_path)
......
@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
self.to_cuda(model_index=1)
self.cur_model_index = 1
def infer(self, inputs):
self.get_current_model_index()
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.model[self.cur_model_index].infer(inputs)
else:
if self.model[self.cur_model_index] is not None:
self.model[self.cur_model_index].infer(inputs)
else:
if self.cur_model_index == 0:
high_noise_model = WanDistillModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_noise_model.set_scheduler(self.scheduler)
self.model[0] = high_noise_model
self.model[0].infer(inputs)
elif self.cur_model_index == 1:
low_noise_model = WanDistillModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_noise_model.set_scheduler(self.scheduler)
self.model[1] = low_noise_model
self.model[1].infer(inputs)
@RUNNER_REGISTER("wan2.2_moe_distill") @RUNNER_REGISTER("wan2.2_moe_distill")
class Wan22MoeDistillRunner(WanDistillRunner): class Wan22MoeDistillRunner(WanDistillRunner):
...@@ -101,61 +130,68 @@ class Wan22MoeDistillRunner(WanDistillRunner): ...@@ -101,61 +130,68 @@ class Wan22MoeDistillRunner(WanDistillRunner):
raise FileNotFoundError(f"Low Noise Model does not find") raise FileNotFoundError(f"Low Noise Model does not find")
def load_transformer(self): def load_transformer(self):
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
use_high_lora, use_low_lora = False, False
if self.config.get("lora_configs") and self.config["lora_configs"]:
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
use_high_lora = True
elif lora_config.get("name", "") == "low_noise_model":
use_low_lora = True
if use_high_lora:
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_lora_wrapper = WanLoraWrapper(high_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "high_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = high_lora_wrapper.load_lora(lora_path)
high_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
high_noise_model = WanDistillModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
if use_low_lora:
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_lora_wrapper = WanLoraWrapper(low_noise_model)
for lora_config in self.config["lora_configs"]:
if lora_config.get("name", "") == "low_noise_model":
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = low_lora_wrapper.load_lora(lora_path)
low_lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
else:
low_noise_model = WanDistillModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
else:
model_struct = MultiDistillModelStruct([None, None], self.config, self.config["boundary_step_index"])
model_struct.low_noise_model_path = self.low_noise_model_path
model_struct.high_noise_model_path = self.high_noise_model_path
model_struct.init_device = self.init_device
return model_struct
def init_scheduler(self):
if self.config["feature_caching"] == "NoCaching":
...
@@ -468,11 +468,37 @@ class MultiModelStruct:
def set_scheduler(self, shared_scheduler):
self.scheduler = shared_scheduler
for model in self.model:
if model is not None:
model.set_scheduler(shared_scheduler)
def infer(self, inputs):
self.get_current_model_index()
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
self.model[self.cur_model_index].infer(inputs)
else:
if self.model[self.cur_model_index] is not None:
self.model[self.cur_model_index].infer(inputs)
else:
if self.cur_model_index == 0:
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
high_noise_model.set_scheduler(self.scheduler)
self.model[0] = high_noise_model
self.model[0].infer(inputs)
elif self.cur_model_index == 1:
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
low_noise_model.set_scheduler(self.scheduler)
self.model[1] = low_noise_model
self.model[1].infer(inputs)
@ProfilingContext4DebugL2("Swtich models in infer_main costs")
def get_current_model_index(self):
@@ -526,40 +552,47 @@ class Wan22MoeRunner(WanRunner):
def load_transformer(self):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
high_noise_model = WanModel(
self.high_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_high_noise",
)
low_noise_model = WanModel(
self.low_noise_model_path,
self.config,
self.init_device,
model_type="wan2.2_moe_low_noise",
)
if self.config.get("lora_configs") and self.config["lora_configs"]:
assert not self.config.get("dit_quantized", False)
for lora_config in self.config["lora_configs"]:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
base_name = os.path.basename(lora_path)
if base_name.startswith("high"):
lora_wrapper = WanLoraWrapper(high_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
elif base_name.startswith("low"):
lora_wrapper = WanLoraWrapper(low_noise_model)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
raise ValueError(f"Unsupported LoRA path: {lora_path}")
return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
else:
model_struct = MultiModelStruct([None, None], self.config, self.config["boundary"])
model_struct.low_noise_model_path = self.low_noise_model_path
model_struct.high_noise_model_path = self.high_noise_model_path
model_struct.init_device = self.init_device
return model_struct
@RUNNER_REGISTER("wan2.2") @RUNNER_REGISTER("wan2.2")
......
import asyncio
import threading
import time
from functools import wraps
@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
torch_device_module = getattr(torch, AI_DEVICE)
_excluded_time_local = threading.local()
def _get_excluded_time_stack():
if not hasattr(_excluded_time_local, "stack"):
_excluded_time_local.stack = []
return _excluded_time_local.stack
class _ProfilingContext:
@@ -32,11 +40,14 @@ class _ProfilingContext:
def __enter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
_get_excluded_time_stack().append(0.0)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
total_elapsed = time.perf_counter() - self.start_time
excluded = _get_excluded_time_stack().pop()
elapsed = total_elapsed - excluded
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
@@ -49,11 +60,14 @@ class _ProfilingContext:
async def __aenter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
_get_excluded_time_stack().append(0.0)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
total_elapsed = time.perf_counter() - self.start_time
excluded = _get_excluded_time_stack().pop()
elapsed = total_elapsed - excluded
if self.enable_recorder and self.metrics_func:
if self.metrics_labels:
self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
@@ -103,6 +117,65 @@ class _NullContext:
return func
class _ExcludedProfilingContext:
"""用于标记应该从外层 profiling 中排除的时间段"""
def __init__(self, name=None):
self.name = name
if dist.is_initialized():
self.rank_info = f"Rank {dist.get_rank()}"
else:
self.rank_info = "Single GPU"
def __enter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
stack = _get_excluded_time_stack()
for i in range(len(stack)):
stack[i] += elapsed
if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
return False
async def __aenter__(self):
torch_device_module.synchronize()
self.start_time = time.perf_counter()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch_device_module.synchronize()
elapsed = time.perf_counter() - self.start_time
stack = _get_excluded_time_stack()
for i in range(len(stack)):
stack[i] += elapsed
if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
return False
def __call__(self, func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
async with self:
return await func(*args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return sync_wrapper
class _ProfilingContextL1(_ProfilingContext):
"""Level 1 profiling context with Level1_Log prefix."""
@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
"""
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext
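A usage sketch for the new exclusion mechanism: time spent inside ExcludedProfilingContext is accumulated on a thread-local stack and subtracted from every enclosing _ProfilingContext, so the outer timer below reports roughly 0.2 s rather than 0.5 s. This assumes PROFILING_DEBUG_LEVEL >= 1 (otherwise both names resolve to _NullContext) and an available torch device for the synchronize calls; time.sleep stands in for real work.

import time
from lightx2v.utils.profiler import ExcludedProfilingContext, ProfilingContext4DebugL1

with ProfilingContext4DebugL1("outer step"):
    time.sleep(0.2)  # counted by the outer timer
    with ExcludedProfilingContext("disk warm-up"):
        time.sleep(0.3)  # logged on its own and subtracted from the outer step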