Unverified commit 2620398e authored by Gu Shiqiao, committed by GitHub

update disk offload (#589)

parent 22484a22
......@@ -755,7 +755,7 @@ def auto_configure(resolution):
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["sgl", "vllm", "q8f"]
quant_op_priority = ["vllm", "sgl", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
......@@ -890,10 +890,7 @@ def auto_configure(resolution):
)
def main():
with gr.Blocks(
title="Lightx2v (Lightweight Video Inference and Generation Engine)",
css="""
css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
......@@ -961,10 +958,13 @@ def main():
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
""",
) as demo:
gr.Markdown(f"# 🎬 LightX2V Video Generator")
"""
def main():
with gr.Blocks(title="Lightx2v (Lightweight Video Inference and Generation Engine)") as demo:
gr.Markdown(f"# 🎬 LightX2V Video Generator")
gr.HTML(f"<style>{css}</style>")
# Main layout: left and right columns
with gr.Row():
# Left: configuration and input area
......
......@@ -755,7 +755,7 @@ def auto_configure(resolution):
if is_ada_architecture_gpu():
quant_op_priority = ["q8f", "vllm", "sgl"]
else:
quant_op_priority = ["sgl", "vllm", "q8f"]
quant_op_priority = ["vllm", "sgl", "q8f"]
for op in attn_priority:
if dict(available_attn_ops).get(op):
......@@ -890,10 +890,7 @@ def auto_configure(resolution):
)
def main():
with gr.Blocks(
title="Lightx2v (轻量级视频推理和生成引擎)",
css="""
css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
......@@ -961,10 +958,13 @@ def main():
border-radius: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.1);
}
""",
) as demo:
gr.Markdown(f"# 🎬 LightX2V 视频生成器")
"""
def main():
with gr.Blocks(title="Lightx2v (轻量级视频推理和生成引擎)") as demo:
gr.Markdown(f"# 🎬 LightX2V 视频生成器")
gr.HTML(f"<style>{css}</style>")
# Main layout: left and right columns
with gr.Row():
# Left: configuration and input area
......
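Note on the demo changes above: both demo files move the stylesheet out of the `css=` argument of `gr.Blocks(...)` into a module-level `css` string that is injected with `gr.HTML(f"<style>{css}</style>")` inside the layout. A minimal sketch of that pattern follows; the CSS content and widget names here are illustrative, not the project's actual stylesheet or layout.

```python
# Minimal sketch: define the stylesheet at module level and inject it as an
# HTML component instead of passing css= to gr.Blocks (pattern from the diff).
import gradio as gr

css = """
.main-content { max-width: 1600px; margin: auto; padding: 20px; }
.warning { color: #ff6b6b; font-weight: bold; }
"""

def main():
    with gr.Blocks(title="Demo") as demo:
        gr.Markdown("# Demo")
        gr.HTML(f"<style>{css}</style>")  # style injected as a component
        with gr.Row():
            gr.Textbox(label="Prompt")  # placeholder for the real layout
    demo.launch()

if __name__ == "__main__":
    main()
```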
......@@ -14,11 +14,11 @@
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/path/to/LightX2V
lightx2v_path=/data/video_gen/lightx2v_debug/LightX2V
# Model path configuration
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
model_path=/path/to/models
model_path=/models/
# Server configuration
server_name="0.0.0.0"
......
import time
from concurrent.futures import ThreadPoolExecutor
import torch
......@@ -116,12 +115,12 @@ class WeightAsyncStreamManager(object):
self.prefetch_futures.append(future)
def swap_cpu_buffers(self):
wait_start = time.time()
already_done = all(f.done() for f in self.prefetch_futures)
# wait_start = time.time()
# already_done = all(f.done() for f in self.prefetch_futures)
for f in self.prefetch_futures:
f.result()
wait_time = time.time() - wait_start
logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
# wait_time = time.time() - wait_start
# logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
def __del__(self):
......
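For context on `swap_cpu_buffers`: the change only comments out the timing instrumentation, leaving the waits on the prefetch futures and the buffer swap. A minimal sketch of that double-buffered prefetch pattern is below; the class and `load_block` callable are hypothetical stand-ins, not the project's `WeightAsyncStreamManager`.

```python
# Double-buffer prefetch sketch (hypothetical names, assuming the pattern in the diff).
from concurrent.futures import ThreadPoolExecutor

class DoubleBufferPrefetcher:
    def __init__(self, load_block):
        self.load_block = load_block          # callable that fills a CPU buffer for a block index
        self.executor = ThreadPoolExecutor(max_workers=2)
        self.cpu_buffers = [{}, {}]           # [active buffer, prefetch buffer]
        self.prefetch_futures = []

    def prefetch(self, block_idx):
        # Fill the back buffer asynchronously while the front buffer is consumed.
        future = self.executor.submit(self.load_block, block_idx, self.cpu_buffers[1])
        self.prefetch_futures.append(future)

    def swap_cpu_buffers(self):
        # Block until outstanding prefetch work finishes, then swap front/back buffers.
        for f in self.prefetch_futures:
            f.result()
        self.prefetch_futures.clear()
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
```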
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
......@@ -130,6 +131,9 @@ class MMWeight(MMWeightTemplate):
def _get_source_tensor(self, source_name, weight_dict=None):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
return lazy_load_file.get_tensor(source_name)
......@@ -150,6 +154,9 @@ class MMWeight(MMWeightTemplate):
def _load_cpu_pin_buffers(self):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
......@@ -210,7 +217,9 @@ class MMWeight(MMWeightTemplate):
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
......@@ -294,6 +303,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def _load_cuda_buffers(self, weight_dict):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
......@@ -334,6 +346,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def _get_cpu_pin_tensor_pair(self, source, is_lazy):
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
weight_tensor = source.get_tensor(self.weight_name)
......@@ -353,6 +368,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if self.bias_name is None:
return None
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
bias_tensor = source.get_tensor(self.bias_name)
......@@ -673,7 +691,9 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
......
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
......@@ -55,6 +56,9 @@ class LNWeightTemplate(metaclass=ABCMeta):
if name is None:
return None
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(name)
......@@ -155,6 +159,9 @@ class LNWeightTemplate(metaclass=ABCMeta):
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
......@@ -167,6 +174,9 @@ class LNWeightTemplate(metaclass=ABCMeta):
del weight_tensor
if self.bias_name is not None:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
assert adapter_block_index is not None
......
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
......@@ -48,6 +49,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.weight_name)
......@@ -111,6 +115,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
......
import os
import re
from pathlib import Path
import torch
from safetensors import safe_open
......@@ -41,6 +42,9 @@ class DefaultTensor:
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name)
......@@ -96,6 +100,9 @@ class DefaultTensor:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
......
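The weight classes above (MMWeight, MMWeightQuantTemplate, LNWeightTemplate, RMSWeightTemplate, DefaultTensor) all gain the same branch: if `lazy_load_file` already points to a single `.safetensors` file it is used directly, otherwise it is treated as a directory of per-block shards named `block_<index>.safetensors`. A standalone sketch of that resolution is below; `resolve_lazy_load_path` and `load_block_tensor` are hypothetical helper names, not part of the codebase.

```python
# Sketch of the path-resolution branch repeated across the weight classes in this diff.
import os
from pathlib import Path
from safetensors import safe_open

def resolve_lazy_load_path(lazy_load_file, block_index):
    if Path(lazy_load_file).is_file():
        # Single consolidated safetensors file: read tensors from it directly.
        return lazy_load_file
    # Otherwise assume a directory of per-block shards.
    return os.path.join(lazy_load_file, f"block_{block_index}.safetensors")

def load_block_tensor(lazy_load_file, tensor_name):
    # Block index is taken from the second dotted component,
    # e.g. "blocks.3.attn.q.weight" -> "3", matching the split('.')[1] in the diff.
    block_index = tensor_name.split(".")[1]
    path = resolve_lazy_load_path(lazy_load_file, block_index)
    with safe_open(path, framework="pt", device="cpu") as f:
        return f.get_tensor(tensor_name)
```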
......@@ -14,8 +14,13 @@ import numpy as np
import torch
import torch.distributed as dist
import zmq
from bson import BSON
from loguru import logger
try:
from bson import BSON
except ImportError:
BSON = None
logger.warning("BSON is not installed")
from scipy.signal import resample
......
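Since `bson` is now optional, call sites presumably need to handle `BSON is None`. A hedged sketch of such a guard, assuming pymongo's `bson` package (where `BSON.encode` exists); `encode_payload` is an illustrative name, not a function from this repository.

```python
# Hypothetical guarded use of the optional BSON dependency.
try:
    from bson import BSON
except ImportError:
    BSON = None

def encode_payload(payload: dict) -> bytes:
    if BSON is None:
        raise RuntimeError("BSON serialization requested but the bson package is not installed")
    return bytes(BSON.encode(payload))
```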
......@@ -168,17 +168,19 @@ class WanModel(CompiledMethodsMixin):
safetensors_path = self.model_path
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
if self.lazy_load:
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
else:
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
if self.lazy_load:
self.lazy_load_path = safetensors_path
safetensors_files = [safetensors_path]
weight_dict = {}
for file_path in safetensors_files:
......@@ -210,18 +212,20 @@ class WanModel(CompiledMethodsMixin):
return weight_dict
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)
if self.lazy_load:
self.lazy_load_path = safetensors_path
non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
if os.path.exists(non_block_file):
safetensors_files = [non_block_file]
else:
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path. Lazy load mode only supports loading chunked model weights.")
raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
else:
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
if self.lazy_load:
self.lazy_load_path = safetensors_path
safetensors_files = [safetensors_path]
safetensors_path = os.path.dirname(safetensors_path)
weight_dict = {}
for safetensor_path in safetensors_files:
......
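The loader rework in `WanModel` distinguishes three cases: a directory in lazy-load mode (which must contain `non_block.safetensors`, with per-block shards loaded on demand), a directory in normal mode (glob all shards), and a single checkpoint file. A free-function approximation of that selection logic is below; the function name and return shape are mine, adapted from the second hunk, not the project's actual API.

```python
# Sketch of the safetensors file-selection logic introduced in this diff (hypothetical helper).
import glob
import os

def select_safetensors_files(safetensors_path, lazy_load):
    """Return (lazy_load_root, files_to_load_eagerly)."""
    if os.path.isdir(safetensors_path):
        if lazy_load:
            # Disk-offload mode expects pre-chunked weights: non_block.safetensors for
            # shared weights plus block_<i>.safetensors shards loaded lazily per block.
            non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
            if not os.path.exists(non_block_file):
                raise ValueError(f"Non-block file not found in {safetensors_path}. Please check the model path.")
            return safetensors_path, [non_block_file]
        return safetensors_path, glob.glob(os.path.join(safetensors_path, "*.safetensors"))
    # Single-file checkpoint: the lazy-load root becomes the containing directory.
    return os.path.dirname(safetensors_path), [safetensors_path]
```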