Commit 92539ed8 authored by gushiqiao

Update gradio and offload

parent 8e941d39
@@ -109,6 +109,24 @@ def get_cpu_memory():
    return available_bytes / 1024**3


+def cleanup_memory():
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+    try:
+        if hasattr(psutil, "virtual_memory"):
+            if os.name == "posix":
+                try:
+                    os.system("sync")
+                except:  # noqa
+                    pass
+    except:  # noqa
+        pass
+
+
def generate_unique_filename(base_dir="./saved_videos"):
    os.makedirs(base_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

@@ -147,7 +165,6 @@ for op_name, is_installed in available_attn_ops:

def run_inference(
-    model_type,
    prompt,
    negative_prompt,
    save_video_path,

@@ -173,6 +190,8 @@ def run_inference(
    cpu_offload,
    offload_granularity,
    offload_ratio,
+    t5_cpu_offload,
+    unload_modules,
    t5_offload_granularity,
    attention_type,
    quant_op,

@@ -181,6 +200,8 @@ def run_inference(
    clean_cuda_cache,
    image_path=None,
):
+    cleanup_memory()
+
    quant_op = quant_op.split("(")[0].strip()
    attention_type = attention_type.split("(")[0].strip()

@@ -192,7 +213,7 @@ def run_inference(
        model_config = json.load(f)

    if task == "t2v":
-        if model_type == "Wan2.1 1.3B":
+        if model_size == "1.3b":
            # 1.3B
            coefficient = [
                [

@@ -287,6 +308,7 @@ def run_inference(
    needs_reinit = (
        lazy_load
+        or unload_modules
        or global_runner is None
        or current_config is None
        or cur_dit_quant_scheme is None

@@ -325,6 +347,8 @@ def run_inference(
        if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
            with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
                quant_model_config = json.load(f)
+        else:
+            quant_model_config = {}
    else:
        mm_type = "Default"
        dit_quantized_ckpt = None

@@ -355,6 +379,8 @@ def run_inference(
        "coefficients": coefficient[0] if use_ret_steps else coefficient[1],
        "use_ret_steps": use_ret_steps,
        "teacache_thresh": teacache_thresh,
+        "t5_cpu_offload": t5_cpu_offload,
+        "unload_modules": unload_modules,
        "t5_quantized": is_t5_quant,
        "t5_quantized_ckpt": t5_quant_ckpt,
        "t5_quant_scheme": t5_quant_scheme,

@@ -425,15 +451,25 @@ def run_inference(
    asyncio.run(runner.run_pipeline())

-    if lazy_load:
-        del runner
-        torch.cuda.empty_cache()
-        gc.collect()
+    del config, args, model_config, quant_model_config
+    if "dit_quantized_ckpt" in locals():
+        del dit_quantized_ckpt
+    if "t5_quant_ckpt" in locals():
+        del t5_quant_ckpt
+    if "clip_quant_ckpt" in locals():
+        del clip_quant_ckpt
+    cleanup_memory()

    return save_video_path


-def auto_configure(enable_auto_config, model_type, resolution):
+def handle_lazy_load_change(lazy_load_enabled):
+    """Handle lazy_load checkbox change to automatically enable unload_modules"""
+    return gr.update(value=lazy_load_enabled)
+
+
+def auto_configure(enable_auto_config, resolution):
    default_config = {
        "torch_compile_val": False,
        "lazy_load_val": False,

@@ -443,6 +479,8 @@ def auto_configure(enable_auto_config, model_type, resolution):
        "cpu_offload_val": False,
        "offload_granularity_val": "block",
        "offload_ratio_val": 1,
+        "t5_cpu_offload_val": False,
+        "unload_modules_val": False,
        "t5_offload_granularity_val": "model",
        "attention_type_val": attn_op_choices[0][1],
        "quant_op_val": quant_op_choices[0][1],

@@ -499,7 +537,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
    else:
        res = "480p"

-    if model_type in ["Wan2.1 14B"]:
+    if model_size == "14b":
        is_14b = True
    else:
        is_14b = False

@@ -507,13 +545,14 @@ def auto_configure(enable_auto_config, model_type, resolution):
    if res == "720p" and is_14b:
        gpu_rules = [
            (80, {}),
-            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
-            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
-            (32, {"cpu_offload_val": True, "offload_ratio_val": 1}),
+            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
+            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
+            (32, {"cpu_offload_val": True, "offload_ratio_val": 1, "t5_cpu_offload_val": True}),
            (
                24,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -524,6 +563,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                16,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -537,6 +577,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                12,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -552,6 +593,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                8,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -564,6 +606,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "use_tiny_vae_val": True,
                },
            ),

@@ -572,13 +615,14 @@ def auto_configure(enable_auto_config, model_type, resolution):
    elif is_14b:
        gpu_rules = [
            (80, {}),
-            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.2}),
-            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
-            (24, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
+            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.2, "t5_cpu_offload_val": True}),
+            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
+            (24, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
            (
                16,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -591,6 +635,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
            (
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -600,6 +645,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "rotary_chunk_val": True,
                    "rotary_chunk_size_val": 10000,
                    "use_tiny_vae_val": True,

@@ -607,6 +653,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                if res == "540p"
                else {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -616,6 +663,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "use_tiny_vae_val": True,
                }
            ),

@@ -623,7 +671,17 @@ def auto_configure(enable_auto_config, model_type, resolution):
        ]
    else:
-        gpu_rules = {}
+        gpu_rules = [
+            (24, {}),
+            (
+                8,
+                {
+                    "t5_cpu_offload_val": True,
+                    "t5_offload_granularity_val": "block",
+                    "t5_quant_scheme_val": quant_type,
+                },
+            ),
+        ]

    if is_14b:
        cpu_rules = [

@@ -637,11 +695,22 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "t5_quant_scheme_val": quant_type,
                    "clip_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                },
            ),
        ]
    else:
-        cpu_rules = {}
+        cpu_rules = [
+            (64, {}),
+            (
+                16,
+                {
+                    "t5_quant_scheme_val": quant_type,
+                    "unload_modules_val": True,
+                    "use_tiny_vae_val": True,
+                },
+            ),
+        ]

    for threshold, updates in gpu_rules:
        if gpu_memory >= threshold:

@@ -680,20 +749,6 @@ def main():
        with gr.Group():
            gr.Markdown("## 📥 Input Parameters")
-            with gr.Row():
-                if task == "i2v":
-                    model_type = gr.Dropdown(
-                        choices=["Wan2.1 14B"],
-                        value="Wan2.1 14B",
-                        label="Model Type",
-                    )
-                else:
-                    model_type = gr.Dropdown(
-                        choices=["Wan2.1 14B", "Wan2.1 1.3B"],
-                        value="Wan2.1 14B",
-                        label="Model Type",
-                    )
            if task == "i2v":
                with gr.Row():
                    image_path = gr.Image(

@@ -849,6 +904,11 @@ def main():
                        info="Controls the chunk size for applying rotary embeddings. Larger values may improve performance but increase memory usage. Only effective if 'rotary_chunk' is checked.",
                    )
+                    unload_modules = gr.Checkbox(
+                        label="Unload Modules",
+                        value=False,
+                        info="Unload modules (T5, CLIP, DIT, etc.) after inference to reduce GPU/CPU memory usage",
+                    )
                    clean_cuda_cache = gr.Checkbox(
                        label="Clean CUDA Memory Cache",
                        value=False,

@@ -883,6 +943,12 @@ def main():
                        value=1.0,
                        info="Controls how much of the Dit model is offloaded to the CPU",
                    )
+                    t5_cpu_offload = gr.Checkbox(
+                        label="T5 CPU Offloading",
+                        value=False,
+                        info="Offload the T5 Encoder model to CPU to reduce GPU memory usage",
+                    )
                    t5_offload_granularity = gr.Dropdown(
                        label="T5 Encoder Offload Granularity",
                        choices=["model", "block"],

@@ -971,7 +1037,7 @@ def main():
        enable_auto_config.change(
            fn=auto_configure,
-            inputs=[enable_auto_config, model_type, resolution],
+            inputs=[enable_auto_config, resolution],
            outputs=[
                torch_compile,
                lazy_load,

@@ -981,6 +1047,8 @@ def main():
                cpu_offload,
                offload_granularity,
                offload_ratio,
+                t5_cpu_offload,
+                unload_modules,
                t5_offload_granularity,
                attention_type,
                quant_op,

@@ -995,11 +1063,16 @@ def main():
                use_ret_steps,
            ],
        )
+        lazy_load.change(
+            fn=handle_lazy_load_change,
+            inputs=[lazy_load],
+            outputs=[unload_modules],
+        )

        if task == "i2v":
            infer_btn.click(
                fn=run_inference,
                inputs=[
-                    model_type,
                    prompt,
                    negative_prompt,
                    save_video_path,

@@ -1025,6 +1098,8 @@ def main():
                    cpu_offload,
                    offload_granularity,
                    offload_ratio,
+                    t5_cpu_offload,
+                    unload_modules,
                    t5_offload_granularity,
                    attention_type,
                    quant_op,

@@ -1039,7 +1114,6 @@ def main():
            infer_btn.click(
                fn=run_inference,
                inputs=[
-                    model_type,
                    prompt,
                    negative_prompt,
                    save_video_path,

@@ -1065,6 +1139,8 @@ def main():
                    cpu_offload,
                    offload_granularity,
                    offload_ratio,
+                    t5_cpu_offload,
+                    unload_modules,
                    t5_offload_granularity,
                    attention_type,
                    quant_op,

@@ -1088,14 +1164,16 @@ if __name__ == "__main__":
        default="wan2.1",
        help="Model class to use",
    )
+    parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="Model type to use")
    parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="Specify the task type. 'i2v' for image-to-video translation, 't2v' for text-to-video generation.")
    parser.add_argument("--server_port", type=int, default=7862, help="Server port")
    parser.add_argument("--server_name", type=str, default="0.0.0.0", help="Server ip")
    args = parser.parse_args()

-    global model_path, model_cls
+    global model_path, model_cls, model_size
    model_path = args.model_path
    model_cls = args.model_cls
+    model_size = args.model_size
    task = args.task

    main()
@@ -109,6 +109,26 @@ def get_cpu_memory():
    return available_bytes / 1024**3


+def cleanup_memory():
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+    try:
+        import psutil
+
+        if hasattr(psutil, "virtual_memory"):
+            if os.name == "posix":
+                try:
+                    os.system("sync")
+                except:  # noqa
+                    pass
+    except:  # noqa
+        pass
+
+
def generate_unique_filename(base_dir="./saved_videos"):
    os.makedirs(base_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

@@ -147,7 +167,6 @@ for op_name, is_installed in available_attn_ops:

def run_inference(
-    model_type,
    prompt,
    negative_prompt,
    save_video_path,

@@ -173,6 +192,8 @@ def run_inference(
    cpu_offload,
    offload_granularity,
    offload_ratio,
+    t5_cpu_offload,
+    unload_modules,
    t5_offload_granularity,
    attention_type,
    quant_op,

@@ -181,6 +202,8 @@ def run_inference(
    clean_cuda_cache,
    image_path=None,
):
+    cleanup_memory()
+
    quant_op = quant_op.split("(")[0].strip()
    attention_type = attention_type.split("(")[0].strip()

@@ -192,7 +215,7 @@ def run_inference(
        model_config = json.load(f)

    if task == "t2v":
-        if model_type == "Wan2.1 1.3B":
+        if model_size == "1.3b":
            # 1.3B
            coefficient = [
                [

@@ -287,6 +310,7 @@ def run_inference(
    needs_reinit = (
        lazy_load
+        or unload_modules
        or global_runner is None
        or current_config is None
        or cur_dit_quant_scheme is None

@@ -325,6 +349,8 @@ def run_inference(
        if os.path.exists(os.path.join(dit_quantized_ckpt, "config.json")):
            with open(os.path.join(dit_quantized_ckpt, "config.json"), "r") as f:
                quant_model_config = json.load(f)
+        else:
+            quant_model_config = {}
    else:
        mm_type = "Default"
        dit_quantized_ckpt = None

@@ -355,6 +381,8 @@ def run_inference(
        "coefficients": coefficient[0] if use_ret_steps else coefficient[1],
        "use_ret_steps": use_ret_steps,
        "teacache_thresh": teacache_thresh,
+        "t5_cpu_offload": t5_cpu_offload,
+        "unload_modules": unload_modules,
        "t5_quantized": is_t5_quant,
        "t5_quantized_ckpt": t5_quant_ckpt,
        "t5_quant_scheme": t5_quant_scheme,

@@ -425,15 +453,25 @@ def run_inference(
    asyncio.run(runner.run_pipeline())

-    if lazy_load:
-        del runner
-        torch.cuda.empty_cache()
-        gc.collect()
+    del config, args, model_config, quant_model_config
+    if "dit_quantized_ckpt" in locals():
+        del dit_quantized_ckpt
+    if "t5_quant_ckpt" in locals():
+        del t5_quant_ckpt
+    if "clip_quant_ckpt" in locals():
+        del clip_quant_ckpt
+    cleanup_memory()

    return save_video_path


-def auto_configure(enable_auto_config, model_type, resolution):
+def handle_lazy_load_change(lazy_load_enabled):
+    """Handle lazy_load checkbox change to automatically enable unload_modules"""
+    return gr.update(value=lazy_load_enabled)
+
+
+def auto_configure(enable_auto_config, resolution):
    default_config = {
        "torch_compile_val": False,
        "lazy_load_val": False,

@@ -443,6 +481,8 @@ def auto_configure(enable_auto_config, model_type, resolution):
        "cpu_offload_val": False,
        "offload_granularity_val": "block",
        "offload_ratio_val": 1,
+        "t5_cpu_offload_val": False,
+        "unload_modules_val": False,
        "t5_offload_granularity_val": "model",
        "attention_type_val": attn_op_choices[0][1],
        "quant_op_val": quant_op_choices[0][1],

@@ -499,7 +539,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
    else:
        res = "480p"

-    if model_type in ["Wan2.1 14B"]:
+    if model_size == "14b":
        is_14b = True
    else:
        is_14b = False

@@ -507,13 +547,14 @@ def auto_configure(enable_auto_config, model_type, resolution):
    if res == "720p" and is_14b:
        gpu_rules = [
            (80, {}),
-            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
-            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
-            (32, {"cpu_offload_val": True, "offload_ratio_val": 1}),
+            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
+            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
+            (32, {"cpu_offload_val": True, "offload_ratio_val": 1, "t5_cpu_offload_val": True}),
            (
                24,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -524,6 +565,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                16,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -537,6 +579,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                12,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -552,6 +595,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                8,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -564,6 +608,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "use_tiny_vae_val": True,
                },
            ),

@@ -572,13 +617,14 @@ def auto_configure(enable_auto_config, model_type, resolution):
    elif is_14b:
        gpu_rules = [
            (80, {}),
-            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.2}),
-            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.5}),
-            (24, {"cpu_offload_val": True, "offload_ratio_val": 0.8}),
+            (48, {"cpu_offload_val": True, "offload_ratio_val": 0.2, "t5_cpu_offload_val": True}),
+            (40, {"cpu_offload_val": True, "offload_ratio_val": 0.5, "t5_cpu_offload_val": True}),
+            (24, {"cpu_offload_val": True, "offload_ratio_val": 0.8, "t5_cpu_offload_val": True}),
            (
                16,
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -591,6 +637,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
            (
                {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -600,6 +647,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "rotary_chunk_val": True,
                    "rotary_chunk_size_val": 10000,
                    "use_tiny_vae_val": True,

@@ -607,6 +655,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                if res == "540p"
                else {
                    "cpu_offload_val": True,
+                    "t5_cpu_offload_val": True,
                    "offload_ratio_val": 1,
                    "t5_offload_granularity_val": "block",
                    "precision_mode_val": "bf16",

@@ -616,6 +665,7 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "clip_quant_scheme_val": quant_type,
                    "dit_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                    "use_tiny_vae_val": True,
                }
            ),

@@ -623,7 +673,17 @@ def auto_configure(enable_auto_config, model_type, resolution):
        ]
    else:
-        gpu_rules = {}
+        gpu_rules = [
+            (24, {}),
+            (
+                8,
+                {
+                    "t5_cpu_offload_val": True,
+                    "t5_offload_granularity_val": "block",
+                    "t5_quant_scheme_val": quant_type,
+                },
+            ),
+        ]

    if is_14b:
        cpu_rules = [

@@ -637,11 +697,22 @@ def auto_configure(enable_auto_config, model_type, resolution):
                    "t5_quant_scheme_val": quant_type,
                    "clip_quant_scheme_val": quant_type,
                    "lazy_load_val": True,
+                    "unload_modules_val": True,
                },
            ),
        ]
    else:
-        cpu_rules = {}
+        cpu_rules = [
+            (64, {}),
+            (
+                16,
+                {
+                    "t5_quant_scheme_val": quant_type,
+                    "unload_modules_val": True,
+                    "use_tiny_vae_val": True,
+                },
+            ),
+        ]

    for threshold, updates in gpu_rules:
        if gpu_memory >= threshold:

@@ -680,20 +751,6 @@ def main():
        with gr.Group():
            gr.Markdown("## 📥 输入参数")
-            with gr.Row():
-                if task == "i2v":
-                    model_type = gr.Dropdown(
-                        choices=["Wan2.1 14B"],
-                        value="Wan2.1 14B",
-                        label="模型类型",
-                    )
-                else:
-                    model_type = gr.Dropdown(
-                        choices=["Wan2.1 14B", "Wan2.1 1.3B"],
-                        value="Wan2.1 14B",
-                        label="模型类型",
-                    )
            if task == "i2v":
                with gr.Row():
                    image_path = gr.Image(

@@ -846,7 +903,11 @@ def main():
                        step=100,
                        info="控制应用旋转编码的块大小。较大的值可能提高性能但增加内存使用。仅在'rotary_chunk'勾选时有效。",
                    )
+                    unload_modules = gr.Checkbox(
+                        label="卸载模块",
+                        value=False,
+                        info="推理后卸载模块(T5、CLIP、DIT等)以减少GPU/CPU内存使用",
+                    )
                    clean_cuda_cache = gr.Checkbox(
                        label="清理CUDA内存缓存",
                        value=False,

@@ -881,6 +942,11 @@ def main():
                        value=1.0,
                        info="控制将多少Dit模型卸载到CPU",
                    )
+                    t5_cpu_offload = gr.Checkbox(
+                        label="T5 CPU卸载",
+                        value=False,
+                        info="将T5编码器模型卸载到CPU以减少GPU内存使用",
+                    )
                    t5_offload_granularity = gr.Dropdown(
                        label="T5编码器卸载粒度",
                        choices=["model", "block"],

@@ -969,7 +1035,7 @@ def main():
        enable_auto_config.change(
            fn=auto_configure,
-            inputs=[enable_auto_config, model_type, resolution],
+            inputs=[enable_auto_config, resolution],
            outputs=[
                torch_compile,
                lazy_load,

@@ -979,6 +1045,8 @@ def main():
                cpu_offload,
                offload_granularity,
                offload_ratio,
+                t5_cpu_offload,
+                unload_modules,
                t5_offload_granularity,
                attention_type,
                quant_op,

@@ -993,11 +1061,16 @@ def main():
                use_ret_steps,
            ],
        )
+        lazy_load.change(
+            fn=handle_lazy_load_change,
+            inputs=[lazy_load],
+            outputs=[unload_modules],
+        )

        if task == "i2v":
            infer_btn.click(
                fn=run_inference,
                inputs=[
-                    model_type,
                    prompt,
                    negative_prompt,
                    save_video_path,

@@ -1023,6 +1096,8 @@ def main():
                    cpu_offload,
                    offload_granularity,
                    offload_ratio,
+                    t5_cpu_offload,
+                    unload_modules,
                    t5_offload_granularity,
                    attention_type,
                    quant_op,

@@ -1037,7 +1112,6 @@ def main():
            infer_btn.click(
                fn=run_inference,
                inputs=[
-                    model_type,
                    prompt,
                    negative_prompt,
                    save_video_path,

@@ -1063,6 +1137,8 @@ def main():
                    cpu_offload,
                    offload_granularity,
                    offload_ratio,
+                    t5_cpu_offload,
+                    unload_modules,
                    t5_offload_granularity,
                    attention_type,
                    quant_op,

@@ -1086,14 +1162,16 @@ if __name__ == "__main__":
        default="wan2.1",
        help="要使用的模型类别",
    )
+    parser.add_argument("--model_size", type=str, required=True, choices=["14b", "1.3b"], help="模型大小:14b 或 1.3b")
    parser.add_argument("--task", type=str, required=True, choices=["i2v", "t2v"], help="指定任务类型。'i2v'用于图像到视频转换,'t2v'用于文本到视频生成。")
    parser.add_argument("--server_port", type=int, default=7862, help="服务器端口")
    parser.add_argument("--server_name", type=str, default="0.0.0.0", help="服务器IP")
    args = parser.parse_args()

-    global model_path, model_cls
+    global model_path, model_cls, model_size
    model_path = args.model_path
    model_cls = args.model_cls
+    model_size = args.model_size
    task = args.task

    main()
@@ -15,16 +15,19 @@
# Lightx2v project root directory path
# Example: /home/user/lightx2v or /data/video_gen/lightx2v
lightx2v_path=/path/to/lightx2v

# Model path configuration
# Image-to-video model path (for i2v tasks)
# Example: /path/to/Wan2.1-I2V-14B-720P-Lightx2v
-i2v_model_path=/path/to/Wan2.1-I2V-14B-720P-Lightx2v
+i2v_model_path=/path/to/Wan2.1-I2V-14B-720P-Lightx2v-Step-Distill

# Text-to-video model path (for t2v tasks)
# Example: /path/to/Wan2.1-T2V-1.3B
t2v_model_path=/path/to/Wan2.1-T2V-1.3B

+# Model size configuration
+# Default model size (14b, 1.3b)
+model_size="14b"
+
# Server configuration
server_name="0.0.0.0"
server_port=8032

@@ -65,6 +68,10 @@ while [[ $# -gt 0 ]]; do
        export CUDA_VISIBLE_DEVICES=$gpu_id
        shift 2
        ;;
+    --model_size)
+        model_size="$2"
+        shift 2
+        ;;
    --help)
        echo "🎬 Lightx2v Gradio Demo Startup Script"
        echo "=========================================="

@@ -79,6 +86,10 @@ while [[ $# -gt 0 ]]; do
        echo "  en: English interface"
        echo "  --port PORT        Server port (default: 8032)"
        echo "  --gpu GPU_ID       GPU device ID (default: 0)"
+        echo "  --model_size MODEL_SIZE"
+        echo "                     Model size (default: 14b)"
+        echo "                     14b: 14 billion parameters model"
+        echo "                     1.3b: 1.3 billion parameters model"
        echo "  --help             Show this help message"
        echo ""
        echo "🚀 Usage examples:"

@@ -86,6 +97,8 @@ while [[ $# -gt 0 ]]; do
        echo "  $0 --task i2v --lang zh --port 8032    # Start with specified parameters"
        echo "  $0 --task t2v --lang en --port 7860    # Text-to-video with English interface"
        echo "  $0 --task i2v --gpu 1 --port 8032      # Use GPU 1"
+        echo "  $0 --task t2v --model_size 1.3b       # Use 1.3B model"
+        echo "  $0 --task i2v --model_size 14b        # Use 14B model"
        echo ""
        echo "📝 Notes:"
        echo "  - Edit script to configure model paths before first use"

@@ -113,6 +126,12 @@ if [[ "$lang" != "zh" && "$lang" != "en" ]]; then
    exit 1
fi

+# Validate model size
+if [[ "$model_size" != "14b" && "$model_size" != "1.3b" ]]; then
+    echo "Error: Model size must be '14b' or '1.3b'"
+    exit 1
+fi
+
# Select model path based on task type
if [[ "$task" == "i2v" ]]; then
    model_path=$i2v_model_path

@@ -161,6 +180,7 @@ echo "=========================================="
echo "📁 Project path: $lightx2v_path"
echo "🤖 Model path: $model_path"
echo "🎯 Task type: $task"
+echo "🤖 Model size: $model_size"
echo "🌏 Interface language: $lang"
echo "🖥️ GPU device: $gpu_id"
echo "🌐 Server address: $server_name:$server_port"

@@ -190,7 +210,8 @@ python $demo_file \
    --model_path "$model_path" \
    --task "$task" \
    --server_name "$server_name" \
-    --server_port "$server_port"
+    --server_port "$server_port" \
+    --model_size "$model_size"

# Display final system resource usage
echo ""
...
@@ -11,6 +11,7 @@
    "sample_shift": 5,
    "enable_cfg": true,
    "cpu_offload": true,
+    "t5_cpu_offload": true,
    "offload_granularity": "block",
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
...
{
    "infer_steps": 4,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 5,
    "enable_cfg": true,
    "t5_cpu_offload": true,
    "t5_offload_granularity": "block",
    "t5_quantized": true,
    "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
    "t5_quant_scheme": "fp8",
    "unload_modules": true,
    "use_tiling_vae": true
}
@@ -13,6 +13,7 @@
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "block",
+    "t5_cpu_offload": true,
    "mm_config": {
        "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm",
        "weight_auto_quant": true
...
@@ -18,6 +18,7 @@
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
        "weight_auto_quant": false
    },
+    "t5_cpu_offload": true,
    "t5_quantized": true,
    "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
    "t5_quant_scheme": "fp8",
...
@@ -18,6 +18,7 @@
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm",
        "weight_auto_quant": false
    },
+    "t5_cpu_offload": true,
    "t5_quantized": true,
    "t5_quantized_ckpt": "/path/to/models_t5_umt5-xxl-enc-fp8.pth",
    "t5_quant_scheme": "fp8",
...
@@ -12,6 +12,7 @@
    "enable_cfg": true,
    "cpu_offload": true,
    "offload_granularity": "phase",
+    "t5_cpu_offload": true,
    "t5_offload_granularity": "block",
    "dit_quantized_ckpt": "/path/to/dit_int8",
    "mm_config": {
...
@@ -12,6 +12,7 @@
    "sample_shift": 8,
    "enable_cfg": true,
    "cpu_offload": true,
+    "t5_cpu_offload": true,
    "offload_granularity": "phase",
    "dit_quantized_ckpt": "/path/to/dit_int8",
    "mm_config": {
...
@@ -15,7 +15,7 @@ This project contains two main demo files:
- Python 3.10+ (recommended)
- CUDA 12.4+ (recommended)
- At least 8GB GPU VRAM
-- At least 16GB system memory
+- At least 16GB system memory (preferably at least 32GB)
- At least 128GB SSD solid-state drive (**💾 Strongly recommend using SSD solid-state drives to store model files! During "lazy loading" startup, significantly improves model loading speed and inference performance**)

### Install Dependencies

@@ -80,8 +80,9 @@ vim run_gradio.sh
bash run_gradio.sh

# 3. Or start with parameters (recommended)
-bash run_gradio.sh --task i2v --lang en --port 8032
-# bash run_gradio.sh --task t2v --lang en --port 8032
+bash run_gradio.sh --task i2v --lang en --model_size 14b --port 8032
+# bash run_gradio.sh --task i2v --lang en --model_size 14b --port 8032
+# bash run_gradio.sh --task i2v --lang en --model_size 1.3b --port 8032
```

#### Method 2: Direct Command Line Startup

@@ -90,6 +91,7 @@ bash run_gradio.sh --task i2v --lang en --port 8032
```bash
python gradio_demo.py \
    --model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v \
+    --model_size 14b \
    --task i2v \
    --server_name 0.0.0.0 \
    --server_port 7862

@@ -99,6 +101,7 @@ python gradio_demo.py \
```bash
python gradio_demo.py \
    --model_path /path/to/Wan2.1-T2V-1.3B \
+    --model_size 1.3b \
    --task t2v \
    --server_name 0.0.0.0 \
    --server_port 7862

@@ -108,6 +111,7 @@ python gradio_demo.py \
```bash
python gradio_demo_zh.py \
    --model_path /path/to/model \
+    --model_size 14b \
    --task i2v \
    --server_name 0.0.0.0 \
    --server_port 7862

@@ -119,6 +123,7 @@ python gradio_demo_zh.py \
|-----------|------|----------|---------|-------------|
| `--model_path` | str | ✅ | - | Model folder path |
| `--model_cls` | str | ❌ | wan2.1 | Model class (currently only supports wan2.1) |
+| `--model_size` | str | ✅ | - | Model size: `14b(t2v or i2v)` or `1.3b(t2v)` |
| `--task` | str | ✅ | - | Task type: `i2v` (image-to-video) or `t2v` (text-to-video) |
| `--server_port` | int | ❌ | 7862 | Server port |
| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |

@@ -127,10 +132,6 @@ python gradio_demo_zh.py \
### Basic Settings

-#### Model Type Selection
-
-- **Wan2.1 14B**: Large parameter count, high generation quality, suitable for high-quality video generation
-- **Wan2.1 1.3B**: Lightweight model, fast speed, suitable for rapid prototyping and testing
-
#### Input Parameters

- **Prompt**: Describe the expected video content
- **Negative Prompt**: Specify elements you don't want to appear

@@ -217,7 +218,7 @@ lightx2v/app/
## 🎨 Interface Description

### Basic Settings Tab
-- **Input Parameters**: Model type, prompts, resolution, and other basic settings
+- **Input Parameters**: Prompts, resolution, and other basic settings
- **Video Parameters**: FPS, frame count, CFG, and other video generation parameters
- **Output Settings**: Video save path configuration
...
-# 参数卸载
-xxx
+# Lightx2v Parameter Offloading Mechanism Documentation

## 📖 Overview
Lightx2v implements an advanced parameter offloading mechanism designed for large model inference under limited hardware resources. This system provides excellent speed-memory balance through intelligent management of model weights across different memory hierarchies.
**Core Features:**
- **Block/Phase Offloading**: Efficiently manages model weights in block/phase units for optimal memory usage
- **Block**: Basic computational unit of Transformer models, containing complete Transformer layers (self-attention, cross-attention, feed-forward networks, etc.), serving as larger memory management units
- **Phase**: Finer-grained computational stages within blocks, containing individual computational components (such as self-attention, cross-attention, feed-forward networks, etc.), providing more precise memory control
- **Multi-level Storage Support**: GPU → CPU → Disk hierarchy with intelligent caching
- **Asynchronous Operations**: Uses CUDA streams to overlap computation and data transfer
- **Disk/NVMe Serialization**: Supports secondary storage when memory is insufficient
## 🎯 Offloading Strategies
### Strategy 1: GPU-CPU Block/Phase Offloading
**Applicable Scenarios**: GPU VRAM insufficient but system memory adequate
**Working Principle**: Manages model weights in block or phase units between GPU and CPU memory, utilizing CUDA streams to overlap computation and data transfer. Blocks contain complete Transformer layers, while phases are individual computational components within blocks.
**Block vs Phase Explanation**:
- **Block Granularity**: Larger memory management units containing complete Transformer layers (self-attention, cross-attention, feed-forward networks, etc.), suitable for memory-sufficient scenarios, reducing management overhead
- **Phase Granularity**: Finer-grained memory management containing individual computational components (such as self-attention, cross-attention, feed-forward networks, etc.), suitable for memory-constrained scenarios, providing more flexible memory control
```
GPU-CPU Block/Phase Offloading Workflow:
╔═════════════════════════════════════════════════════════════════╗
║ 🎯 GPU Memory ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ║
║ │ 🔄 Current │ │ ⏳ Prefetch │ │ 📤 To Offload │ ║
║ │ block/phase N │◄──►│ block/phase N+1 │◄──►│ block/phase N-1 │ ║
║ └─────────────────┘ └─────────────────┘ └─────────────────┘ ║
║ │ │ │ ║
║ ▼ ▼ ▼ ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ Compute │ │ GPU Load │ │ CPU Load │ ║
║ │ Stream │ │ Stream │ │ Stream │ ║
║ │(priority=-1)│ │ (priority=0) │ │ (priority=0) │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 💾 CPU Memory ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 📥 Cache │ │ 📥 Cache │ │ 📥 Cache │ │ 📥 Cache │ ║
║ │ block/phase │ │ block/phase │ │ block/phase │ │ block/phase │ ║
║ │ N-2 │ │ N-1 │ │ N │ │ N+1 │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ ▲ ▲ ▲ ▲ ║
║ │ │ │ │ ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ CPU Load │ │ CPU Load │ │ CPU Load │ │ CPU Load │ ║
║ │ Stream │ │ Stream │ │ Stream │ │ Stream │ ║
║ │(priority=0) │ │(priority=0) │ │(priority=0) │ │(priority=0) │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ ║
║ 💡 CPU memory stores multiple blocks/phases, forming cache pool ║
║ 🔄 GPU load stream prefetches from CPU cache, CPU load stream ║
║ offloads to CPU cache ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 🔄 Swap Operation Flow ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ Step 1: Parallel Execution Phase ║
║ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ║
║ │ 🔄 Compute │ │ ⏳ Prefetch │ │ 📤 Offload │ ║
║ │ block/phase N │ │ block/phase N+1 │ │ block/phase N-1 │ ║
║ │ (Compute Stream)│ │ (GPU Load Stream)│ │ (CPU Load Stream)│ ║
║ └─────────────────┘ └─────────────────┘ └─────────────────┘ ║
║ ║
║ Step 2: Swap Rotation Phase ║
║ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ║
║ │ 🔄 Compute │ │ ⏳ Prefetch │ │ 📤 Offload │ ║
║ │ block/phase N+1 │ │ block/phase N+2 │ │ block/phase N │ ║
║ │ (Compute Stream)│ │ (GPU Load Stream)│ │ (CPU Load Stream)│ ║
║ └─────────────────┘ └─────────────────┘ └─────────────────┘ ║
║ ║
║ Swap Concept: Achieves continuous computation through position ║
║ rotation, avoiding repeated loading/unloading ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 💡 Swap Core Concept ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ 🔄 Traditional vs Swap Method Comparison: ║
║ ║
║ Traditional Method: ║
║ ┌─────────────┐ ┌──────────┐ ┌─────────┐ ┌────────┐ ║
║ │ Compute N │───►│ Offload N│───►│ Load N+1│───►│Compute │ ║
║ │ │ │ │ │ │ │N+1 │ ║
║ └─────────────┘ └──────────┘ └─────────┘ └────────┘ ║
║ ❌ Serial execution, waiting time, low efficiency ║
║ ║
║ Swap Method: ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ Compute N │ │ Prefetch │ │ Offload │ ║
║ │(Compute │ │N+1 │ │N-1 │ ║
║ │ Stream) │ │(GPU Load │ │(CPU Load │ ║
║ └─────────────┘ │ Stream) │ │ Stream) │ ║
║ └─────────────┘ └─────────────┘ ║
║ ✅ Parallel execution, no waiting time, high efficiency ║
║ ║
║ 🎯 Swap Advantages: ║
║ • Avoids repeated loading/unloading of same data ║
║ • Achieves continuous computation through position rotation ║
║ • Maximizes GPU utilization ║
║ • Reduces memory fragmentation ║
╚════════════════════════════════════════════════════════════════╝
```
**Key Features:**
- **Asynchronous Transfer**: Uses three CUDA streams with different priorities to parallelize computation and transfer
- Compute Stream (priority=-1): High priority, responsible for current computation
- GPU Load Stream (priority=0): Medium priority, responsible for prefetching from CPU to GPU
- CPU Load Stream (priority=0): Medium priority, responsible for offloading from GPU to CPU
- **Prefetch Mechanism**: Preloads the next block/phase to GPU
- **Intelligent Caching**: Maintains weight cache in CPU memory
- **Stream Synchronization**: Ensures correctness of data transfer and computation
- **Swap Operation**: Rotates block/phase positions after computation completion for continuous processing
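To make the three-stream rotation concrete, here is a minimal PyTorch sketch of the idea. It is an illustration, not Lightx2v's actual implementation: the function name `pipelined_forward`, the event bookkeeping, and the assumption that every block's weights start in pinned CPU memory are all ours.

```python
import torch
import torch.nn as nn


def pipelined_forward(blocks: list[nn.Module], x: torch.Tensor) -> torch.Tensor:
    """Sketch of GPU-CPU block offloading with three prioritized CUDA streams."""
    compute = torch.cuda.Stream(priority=-1)   # high priority: current block's math
    gpu_load = torch.cuda.Stream(priority=0)   # prefetch next block (CPU -> GPU)
    cpu_load = torch.cuda.Stream(priority=0)   # offload finished block (GPU -> CPU)

    # Pin CPU weights once so host->device copies can run asynchronously.
    for blk in blocks:
        for p in blk.parameters():
            p.data = p.data.pin_memory()

    ready = [torch.cuda.Event() for _ in blocks]  # "weights resident on GPU"
    done = [torch.cuda.Event() for _ in blocks]   # "block finished computing"

    def prefetch(i: int) -> None:
        with torch.cuda.stream(gpu_load):
            blocks[i].to("cuda", non_blocking=True)
            ready[i].record(gpu_load)

    prefetch(0)
    for i, blk in enumerate(blocks):
        if i + 1 < len(blocks):
            prefetch(i + 1)                    # overlap next H2D copy with compute
        with torch.cuda.stream(compute):
            compute.wait_event(ready[i])       # weights must arrive before use
            x = blk(x)
            done[i].record(compute)
        with torch.cuda.stream(cpu_load):
            cpu_load.wait_event(done[i])       # never evict weights still in use
            blocks[i].to("cpu")                # return block to the CPU cache pool
                                               # (synchronous here for simplicity)
    torch.cuda.current_stream().wait_stream(compute)
    return x
```

In the real offloader, the `offload_granularity` setting decides whether the swap unit handled this way is a whole block or a single phase.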
### Strategy 2: Disk-CPU-GPU Block/Phase Offloading (Lazy Loading)
**Applicable Scenarios**: Both GPU VRAM and system memory insufficient
**Working Principle**: Builds on Strategy 1 by adding disk storage, forming a three-level hierarchy (Disk → CPU → GPU). The CPU still serves as a cache pool, but its size is now configurable, which makes this strategy suitable for devices with limited CPU memory.
```
Disk-CPU-GPU Block/Phase Offloading Workflow:
╔═════════════════════════════════════════════════════════════════╗
║ 💿 SSD/NVMe Storage ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 📁 block_0 │ │ 📁 block_1 │ │ 📁 block_2 │ │ 📁 block_N │ ║
║ │ .safetensors│ │ .safetensors│ │ .safetensors│ │ .safetensors│ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ │ │ │ │ ║
║ ▼ ▼ ▼ ▼ ║
║ ┌─────────────────────────────────────────────────────────────┐ ║
║ │ 🎯 Disk Worker Thread Pool │ ║
║ │ │ ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ ║
║ │ │ Disk Thread │ │ Disk Thread │ │ Disk Thread │ │ ║
║ │ │ 1 │ │ 2 │ │ N │ │ ║
║ │ │(Async Load) │ │(Async Load) │ │(Async Load) │ │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ │ ║
║ │ │ │ │ │ ║
║ │ └───────────────┼───────────────┘ │ ║
║ │ ▼ │ ║
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
║ │ │ 📋 Priority Task Queue │ │ ║
║ │ │ (Manages disk loading task scheduling) │ │ ║
║ │ └─────────────────────────────────────────────────────────┘ │ ║
║ └─────────────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 💾 CPU Memory Buffer ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────────────────────────────────────────────────────┐ ║
║ │ 🎯 FIFO Intelligent Cache │ ║
║ │ │ ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ │ 📥 Cache │ │ 📥 Cache │ │ 📥 Cache │ │ 📥 Cache │ ║
║ │ │ block/phase │ │ block/phase │ │ block/phase │ │ block/phase │ ║
║ │ │ N-2 │ │ N-1 │ │ N │ │ N+1 │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ │ ▲ ▲ ▲ ▲ ║
║ │ │ │ │ │ ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ │ CPU Load │ │ CPU Load │ │ CPU Load │ │ CPU Load │ ║
║ │ │ Stream │ │ Stream │ │ Stream │ │ Stream │ ║
║ │ │(priority=0) │ │(priority=0) │ │(priority=0) │ │(priority=0) │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ │ │ ║
║ │ 💡 Configurable Size 🎯 FIFO Eviction 🔄 Cache Hit/Miss │ ║
║ └─────────────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 🎯 GPU Memory ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ ║
║ │ 🔄 Current │ │ ⏳ Prefetch │ │ 📤 To Offload │ ║
║ │ block/phase N │◄──►│ block/phase N+1 │◄──►│ block/phase N-1 │ ║
║ └─────────────────┘ └─────────────────┘ └─────────────────┘ ║
║ │ │ │ ║
║ ▼ ▼ ▼ ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ Compute │ │ GPU Load │ │ CPU Load │ ║
║ │ Stream │ │ Stream │ │ Stream │ ║
║ │(priority=-1)│ │ (priority=0) │ │ (priority=0) │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║ 🔄 Complete Workflow ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ Step 1: Cache Miss Handling ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 💿 Disk │───►│ 💾 CPU Cache│───►│ 🎯 GPU │ ║
║ │ (On-demand │ │ (FIFO │ │ Memory │ ║
║ │ loading) │ │ Management)│ │ (Compute │ ║
║ └─────────────┘ └─────────────┘ │ Execution) │ ║
║ └─────────────┘ ║
║ ║
║ Step 2: Cache Hit Handling ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 💿 Disk │ │ 💾 CPU Cache│───►│ 🎯 GPU │ ║
║ │ (Skip │ │ (Direct │ │ Memory │ ║
║ │ loading) │ │ Access) │ │ (Compute │ ║
║ └─────────────┘ └─────────────┘ │ Execution) │ ║
║ └─────────────┘ ║
║ ║
║ Step 3: Memory Management ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 💿 Disk │ │ 💾 CPU Cache│ │ 🎯 GPU │ ║
║ │ (Persistent │ │ (FIFO │ │ Memory │ ║
║ │ Storage) │ │ Eviction) │ │ (Swap │ ║
║ └─────────────┘ └─────────────┘ │ Rotation) │ ║
║ └─────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
Work Steps:
1. Disk Storage: Model weights are stored per block on SSD/NVMe, one .safetensors file per block
2. Task Scheduling: When a block/phase is needed, the priority task queue assigns a disk worker thread
3. Async Loading: Multiple disk threads read weight files from disk into the CPU memory buffer in parallel
4. Intelligent Caching: The CPU memory buffer manages its cache with a FIFO policy and a configurable size
5. Cache Hit: If the weights are already cached, they are transferred to the GPU directly, with no disk read
6. Prefetch Transfer: Cached weights are transferred asynchronously to GPU memory (on the GPU load stream)
7. Compute Execution: Weights on the GPU run their computation (on the compute stream) while the next block/phase is prefetched in the background
8. Swap Rotation: After a computation finishes, block/phase positions are rotated for continuous computation
9. Memory Management: When the CPU cache is full, the oldest weight blocks/phases are evicted automatically
```
**Key Features:**
- **Lazy Loading**: Model weights are loaded from disk on demand, avoiding loading the entire model at once
- **Intelligent Caching**: The CPU memory buffer uses a FIFO policy with a configurable size
- **Multi-threaded Prefetching**: Multiple disk worker threads load weights in parallel
- **Asynchronous Transfer**: CUDA streams overlap computation and data transfer
- **Swap Rotation**: Position rotation enables continuous computation, avoiding repeated loading/unloading
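Since this pipeline is the heart of lazy loading, a compact sketch may help make it concrete. This is an illustration only, not Lightx2v's actual implementation: the class name `DiskBlockLoader`, its methods, and the one-file-per-block layout are assumptions drawn from the work steps above, and a production version would also pin in-flight cache entries and use pinned host memory.
```python
import queue
import threading

import torch
from safetensors.torch import load_file  # assumes one .safetensors file per block


class DiskBlockLoader:
    """Sketch of the disk -> CPU -> GPU path (hypothetical names)."""

    def __init__(self, block_paths, max_cached_blocks=4, num_disk_workers=2):
        self.block_paths = block_paths            # block_idx -> file on SSD/NVMe
        self.max_cached_blocks = max_cached_blocks
        self.tasks = queue.PriorityQueue()        # (priority, block_idx)
        self.cpu_cache = {}                       # block_idx -> CPU state dict
        self.cache_order = []                     # FIFO eviction order
        self.ready = {}                           # block_idx -> threading.Event
        self.lock = threading.Lock()
        for _ in range(num_disk_workers):         # steps 2/3: worker thread pool
            threading.Thread(target=self._disk_worker, daemon=True).start()

    def _disk_worker(self):
        while True:
            _, block_idx = self.tasks.get()       # highest-priority task first
            state = load_file(self.block_paths[block_idx], device="cpu")
            with self.lock:
                self.cpu_cache[block_idx] = state
                self.cache_order.append(block_idx)
                while len(self.cache_order) > self.max_cached_blocks:
                    evicted = self.cache_order.pop(0)   # step 9: FIFO eviction
                    self.cpu_cache.pop(evicted, None)
            self.ready[block_idx].set()

    def request(self, block_idx, priority=0):
        """Step 2: queue a load; on a cache hit the disk is skipped (step 5)."""
        event = self.ready.setdefault(block_idx, threading.Event())
        with self.lock:
            if block_idx in self.cpu_cache:
                event.set()
                return
        self.tasks.put((priority, block_idx))

    def to_gpu(self, block_idx, load_stream):
        """Step 6: asynchronous host-to-device copy on the GPU load stream."""
        self.ready[block_idx].wait()
        with torch.cuda.stream(load_stream):
            return {k: v.to("cuda", non_blocking=True)
                    for k, v in self.cpu_cache[block_idx].items()}
```
A driver loop would call `request(n + 1)` while block `n` computes, mirroring the prefetch, compute, and offload columns in the GPU-memory diagram.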
## ⚙️ Configuration Parameters
### GPU-CPU Offloading Configuration
```python
config = {
"cpu_offload": True,
"offload_ratio": 1.0, # Offload ratio (0.0-1.0)
"offload_granularity": "block", # Offload granularity: "block" or "phase"
"lazy_load": False, # Disable lazy loading
}
```
### Disk-CPU-GPU Offloading Configuration
```python
config = {
"cpu_offload": True,
"lazy_load": True, # Enable lazy loading
"offload_ratio": 1.0, # Offload ratio
"offload_granularity": "phase", # Recommended to use phase granularity
"num_disk_workers": 2, # Number of disk worker threads
"offload_to_disk": True, # Enable disk offloading
"offload_path": ".", # Disk offload path
}
```
**Key Intelligent-Cache Parameters:**
- `max_memory`: Controls the CPU cache size; affects cache hit rate and memory usage
- `num_disk_workers`: Controls the number of disk loading threads; affects prefetch speed
- `offload_granularity`: Controls cache granularity (block or phase); affects cache efficiency
  - `"block"`: Cache is managed in units of complete Transformer layers
  - `"phase"`: Cache is managed in units of individual computational components
Detailed configuration files are available under [config](https://github.com/ModelTC/lightx2v/tree/main/configs/offload)
## 🎯 Usage Recommendations
```
╔═════════════════════════════════════════════════════════════════╗
║ 📋 Configuration Guide ║
╠═════════════════════════════════════════════════════════════════╣
║ ║
║ 🔄 GPU-CPU Block/Phase Offloading: ║
║ Suitable for insufficient GPU VRAM (RTX 3090/4090 24G) ║
║ but adequate system memory (>64/128G) ║
║ 💾 Disk-CPU-GPU Block/Phase Offloading: ║
║ Suitable for insufficient GPU VRAM (RTX 3060/4090 8G) ║
║ and system memory (16/32G) ║
║ 🚫 No Offload: Suitable for high-end hardware configurations, ║
║ pursuing optimal performance ║
║ ║
╚═════════════════════════════════════════════════════════════════╝
```
## 🔍 Troubleshooting
### Common Issues and Solutions
1. **Disk I/O Bottleneck**
```
Solution: Use NVMe SSD, increase num_disk_workers
```
2. **Memory Buffer Overflow**
```
Solution: Increase max_memory or decrease num_disk_workers
```
3. **Loading Timeout**
```
Solution: Check disk performance, optimize file system
```
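For the disk I/O case, a quick sequential-read check of one offloaded block file shows whether the prefetch threads are disk-bound. A rough sketch, with a placeholder path standing in for your `offload_path`:
```python
import time
from pathlib import Path

block_file = Path("./offload_blocks/block_0.safetensors")  # placeholder path
start = time.perf_counter()
num_bytes = len(block_file.read_bytes())
elapsed = time.perf_counter() - start
print(f"{num_bytes / 1024**2:.0f} MiB in {elapsed:.2f} s "
      f"({num_bytes / 1024**2 / elapsed:.0f} MiB/s)")
# Throughput far below the drive's rated speed suggests a filesystem or
# contention problem; otherwise raise num_disk_workers or move to NVMe.
```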
**Note**: This offloading mechanism is designed specifically for Lightx2v; by fully exploiting modern hardware's asynchronous compute capabilities, it significantly lowers the hardware bar for large-model inference.
...@@ -15,7 +15,7 @@ Lightx2v is a lightweight video inference and generation engine that provides a Gradio-based ...
- Python 3.10+ (recommended)
- CUDA 12.4+ (recommended)
- At least 8GB of GPU VRAM
- At least 16GB of system memory (ideally at least 32GB)
- At least 128GB on an SSD (**💾 Storing model files on an SSD is strongly recommended! With "lazy loading" enabled, it significantly improves model loading speed and inference performance**)
...@@ -83,8 +83,9 @@ vim run_gradio.sh
bash run_gradio.sh
# 3. Or launch with arguments (recommended)
bash run_gradio.sh --task i2v --lang zh --model_size 14b --port 8032
# bash run_gradio.sh --task i2v --lang zh --model_size 14b --port 8032
# bash run_gradio.sh --task i2v --lang zh --model_size 1.3b --port 8032
```
#### Method 2: Launch directly from the command line
...@@ -93,6 +94,7 @@ bash run_gradio.sh --task i2v --lang zh --port 8032
```bash
python gradio_demo_zh.py \
    --model_path /path/to/Wan2.1-I2V-14B-720P-Lightx2v \
    --model_size 14b \
    --task i2v \
    --server_name 0.0.0.0 \
    --server_port 7862
...@@ -102,6 +104,7 @@ python gradio_demo_zh.py \
```bash
python gradio_demo_zh.py \
    --model_path /path/to/Wan2.1-T2V-1.3B \
    --model_size 1.3b \
    --task t2v \
    --server_name 0.0.0.0 \
    --server_port 7862
...@@ -111,6 +114,7 @@ python gradio_demo_zh.py \
```bash
python gradio_demo.py \
    --model_path /path/to/model \
    --model_size 14b \
    --task i2v \
    --server_name 0.0.0.0 \
    --server_port 7862
...@@ -122,6 +126,7 @@ python gradio_demo.py \
|------|------|------|--------|------|
| `--model_path` | str | ✅ | - | Path to the model folder |
| `--model_cls` | str | ❌ | wan2.1 | Model class (currently only wan2.1 is supported) |
| `--model_size` | str | ✅ | - | Model size: `14b` (image-to-video or text-to-video) or `1.3b` (text-to-video) |
| `--task` | str | ✅ | - | Task type: `i2v` (image-to-video) or `t2v` (text-to-video) |
| `--server_port` | int | ❌ | 7862 | Server port |
| `--server_name` | str | ❌ | 0.0.0.0 | Server IP address |
...@@ -130,10 +135,6 @@ python gradio_demo.py \
### Basic Settings
#### Input Parameters
- **Prompt**: Describes the desired video content
- **Negative Prompt**: Specifies elements that should not appear
...@@ -221,7 +222,7 @@ lightx2v/app/
## 🎨 Interface Guide
### Basic Settings Tab
- **Input Parameters**: Basic settings such as the prompt and resolution
- **Video Parameters**: Video generation parameters such as FPS, frame count, and CFG
- **Output Settings**: Video save path configuration
......
# Lightx2v Parameter Offloading Documentation
## 📖 Overview
Lightx2v implements an advanced parameter offloading mechanism designed for large-model inference on limited hardware. By intelligently managing model weights across memory tiers, the system offers an excellent speed-memory balance.
**Core Features:**
- **Block/phase offloading**: Manages model weights efficiently at block or phase granularity for optimal memory usage
  - **Block**: The basic computational unit of a Transformer model, containing a complete Transformer layer (self-attention, cross-attention, feed-forward network, etc.); the coarser memory-management unit
  - **Phase**: A finer-grained computational stage inside a block, containing a single computational component (such as self-attention, cross-attention, or the feed-forward network); enables finer memory control
- **Multi-tier storage support**: GPU → CPU → disk hierarchy with intelligent caching
- **Asynchronous operation**: Uses CUDA streams to overlap computation and data transfer
- **Disk/NVMe serialization**: Supports secondary storage when memory is insufficient
## 🎯 Offloading Strategies
### Strategy 1: GPU-CPU Block/Phase Offloading
**Applicable Scenarios**: GPU VRAM is insufficient but system memory is ample
**Working Principle**: Manages model weights between GPU and CPU memory at block or phase granularity, using CUDA streams to overlap computation with data transfer. A block contains a complete Transformer layer; a phase is a single computational component inside a block.
**Block vs Phase**
- **Block granularity**: The coarser management unit, holding a complete Transformer layer (self-attention, cross-attention, feed-forward network, etc.); suits memory-rich setups and reduces management overhead
- **Phase granularity**: The finer management unit, holding a single computational component (self-attention, cross-attention, feed-forward network, etc.); suits memory-constrained setups and allows more flexible control
```
GPU-CPU Block/Phase Offloading Workflow:
╔═════════════════════════════════════════════════════════════════╗
║                          🎯 GPU Memory                           ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐ ║
║ │ 🔄 Current      │    │ ⏳ Prefetch     │    │ 📤 To Offload   │ ║
║ │ block/phase N   │◄──►│ block/phase N+1 │◄──►│ block/phase N-1 │ ║
║ └─────────────────┘    └─────────────────┘    └─────────────────┘ ║
║          │                      │                      │         ║
║          ▼                      ▼                      ▼         ║
║   ┌─────────────┐        ┌─────────────┐        ┌─────────────┐  ║
║   │  Compute    │        │  GPU Load   │        │  CPU Load   │  ║
║   │   Stream    │        │   Stream    │        │   Stream    │  ║
║   │(priority=-1)│        │(priority=0) │        │(priority=0) │  ║
║   └─────────────┘        └─────────────┘        └─────────────┘  ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                          💾 CPU Memory                           ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 📥 Cache    │ │ 📥 Cache    │ │ 📥 Cache    │ │ 📥 Cache    │ ║
║ │ block/phase │ │ block/phase │ │ block/phase │ │ block/phase │ ║
║ │     N-2     │ │     N-1     │ │      N      │ │     N+1     │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║        ▲               ▲               ▲               ▲        ║
║        │               │               │               │        ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │  CPU Load   │ │  CPU Load   │ │  CPU Load   │ │  CPU Load   │ ║
║ │   Stream    │ │   Stream    │ │   Stream    │ │   Stream    │ ║
║ │(priority=0) │ │(priority=0) │ │(priority=0) │ │(priority=0) │ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║                                                                  ║
║ 💡 CPU memory holds several blocks/phases, forming a cache pool  ║
║ 🔄 The GPU load stream prefetches from this cache; the CPU load  ║
║    stream offloads into it                                       ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                       🔄 Swap Operation Flow                     ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ Step 1: Parallel Execution                                       ║
║ ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐    ║
║ │ 🔄 Compute      │  │ ⏳ Prefetch     │  │ 📤 Offload      │    ║
║ │ block/phase N   │  │ block/phase N+1 │  │ block/phase N-1 │    ║
║ │ (compute stream)│  │(GPU load stream)│  │(CPU load stream)│    ║
║ └─────────────────┘  └─────────────────┘  └─────────────────┘    ║
║                                                                  ║
║ Step 2: Swap Rotation                                            ║
║ ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐    ║
║ │ 🔄 Compute      │  │ ⏳ Prefetch     │  │ 📤 Offload      │    ║
║ │ block/phase N+1 │  │ block/phase N+2 │  │ block/phase N   │    ║
║ │ (compute stream)│  │(GPU load stream)│  │(CPU load stream)│    ║
║ └─────────────────┘  └─────────────────┘  └─────────────────┘    ║
║                                                                  ║
║ Swap idea: rotate positions for continuous computation and       ║
║ avoid repeated loading/unloading                                 ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                        💡 Core Idea of Swap                      ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ 🔄 Traditional approach vs Swap approach:                        ║
║                                                                  ║
║ Traditional:                                                     ║
║ ┌───────────┐   ┌───────────┐   ┌───────────┐   ┌───────────┐    ║
║ │ Compute N │──►│ Offload N │──►│ Load N+1  │──►│Compute N+1│    ║
║ └───────────┘   └───────────┘   └───────────┘   └───────────┘    ║
║ ❌ Serial execution with idle waiting, inefficient               ║
║                                                                  ║
║ Swap:                                                            ║
║ ┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐    ║
║ │ Compute N       │  │ Prefetch N+1    │  │ Offload N-1     │    ║
║ │ (compute stream)│  │(GPU load stream)│  │(CPU load stream)│    ║
║ └─────────────────┘  └─────────────────┘  └─────────────────┘    ║
║ ✅ Parallel execution with no waiting, efficient                 ║
║                                                                  ║
║ 🎯 Swap advantages:                                              ║
║   • Avoids repeatedly loading/unloading the same data            ║
║   • Continuous computation through position rotation             ║
║   • Maximizes GPU utilization                                    ║
║   • Reduces memory fragmentation                                 ║
╚═════════════════════════════════════════════════════════════════╝
```
**Key Features:**
- **Asynchronous transfer**: Three CUDA streams with different priorities run computation and transfer in parallel
  - Compute stream (priority=-1): high priority, runs the current computation
  - GPU load stream (priority=0): medium priority, prefetches from CPU to GPU
  - CPU load stream (priority=0): medium priority, offloads from GPU to CPU
- **Prefetching**: Loads the next block/phase onto the GPU ahead of time
- **Intelligent caching**: Maintains a weight cache in CPU memory
- **Stream synchronization**: Guarantees the correctness of transfers and computation
- **Swap operation**: Rotates block/phase positions after each computation for continuous execution
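The three-stream schedule above can be sketched in a few lines of PyTorch. This is illustrative pseudo-structure rather than Lightx2v's code: `run_block` and the plain dict-of-tensors block format are assumptions, and the CPU tensors should be pinned for `non_blocking=True` copies to actually overlap.
```python
import torch

compute_stream = torch.cuda.Stream(priority=-1)   # high priority: current block
gpu_load_stream = torch.cuda.Stream(priority=0)   # CPU -> GPU prefetch
cpu_load_stream = torch.cuda.Stream(priority=0)   # GPU -> CPU offload


def _to_gpu(block):
    return {k: v.to("cuda", non_blocking=True) for k, v in block.items()}


def run_blocks(blocks, x, run_block):
    """blocks: CPU-resident (pinned) weight dicts; run_block: per-block forward."""
    with torch.cuda.stream(gpu_load_stream):
        current = _to_gpu(blocks[0])                  # load block 0 up front
    prefetched = None
    for n in range(len(blocks)):
        compute_stream.wait_stream(gpu_load_stream)   # block N's weights are ready
        if n + 1 < len(blocks):                       # prefetch N+1 while N computes
            with torch.cuda.stream(gpu_load_stream):
                prefetched = _to_gpu(blocks[n + 1])
        with torch.cuda.stream(compute_stream):
            x = run_block(current, x)
        cpu_load_stream.wait_stream(compute_stream)   # N finished: safe to offload
        # The real manager copies block N back to pinned CPU buffers on
        # cpu_load_stream here, and must keep the GPU copy alive until its
        # kernels finish (e.g. via record_stream); the sketch just drops it.
        current, prefetched = prefetched, None        # swap rotation
    torch.cuda.synchronize()
    return x
```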
### Strategy 2: Disk-CPU-GPU Block/Phase Offloading (Lazy Loading)
**Applicable Scenarios**: Both GPU VRAM and system memory are insufficient
**Working Principle**: Builds on Strategy 1 by adding disk storage, forming a three-level hierarchy (Disk → CPU → GPU). The CPU still serves as a cache pool, but its size is configurable, making this strategy suitable for devices with limited CPU memory.
```
Disk-CPU-GPU Block/Phase Offloading Workflow:
╔═════════════════════════════════════════════════════════════════╗
║                        💿 SSD/NVMe Storage                       ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ 📁 block_0  │ │ 📁 block_1  │ │ 📁 block_2  │ │ 📁 block_N  │ ║
║ │ .safetensors│ │ .safetensors│ │ .safetensors│ │ .safetensors│ ║
║ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║        │               │               │               │        ║
║        ▼               ▼               ▼               ▼        ║
║ ┌─────────────────────────────────────────────────────────────┐ ║
║ │                  🎯 Disk Worker Thread Pool                  │ ║
║ │                                                             │ ║
║ │   ┌─────────────┐   ┌─────────────┐   ┌─────────────┐       │ ║
║ │   │ Disk Thread │   │ Disk Thread │   │ Disk Thread │       │ ║
║ │   │      1      │   │      2      │   │      N      │       │ ║
║ │   │(Async Load) │   │(Async Load) │   │(Async Load) │       │ ║
║ │   └─────────────┘   └─────────────┘   └─────────────┘       │ ║
║ │          │                 │                 │              │ ║
║ │          └─────────────────┼─────────────────┘              │ ║
║ │                            ▼                                │ ║
║ │ ┌─────────────────────────────────────────────────────────┐ │ ║
║ │ │                 📋 Priority Task Queue                  │ │ ║
║ │ │           (Schedules disk loading tasks)                │ │ ║
║ │ └─────────────────────────────────────────────────────────┘ │ ║
║ └─────────────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                       💾 CPU Memory Buffer                       ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ ┌─────────────────────────────────────────────────────────────┐ ║
║ │                  🎯 FIFO Intelligent Cache                   │ ║
║ │                                                             │ ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ │ 📥 Cache    │ │ 📥 Cache    │ │ 📥 Cache    │ │ 📥 Cache    │ ║
║ │ │ block/phase │ │ block/phase │ │ block/phase │ │ block/phase │ ║
║ │ │     N-2     │ │     N-1     │ │      N      │ │     N+1     │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ │        ▲               ▲               ▲               ▲     ║
║ │        │               │               │               │     ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ║
║ │ │  CPU Load   │ │  CPU Load   │ │  CPU Load   │ │  CPU Load   │ ║
║ │ │   Stream    │ │   Stream    │ │   Stream    │ │   Stream    │ ║
║ │ │(priority=0) │ │(priority=0) │ │(priority=0) │ │(priority=0) │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ ║
║ │                                                             │ ║
║ │ 💡 Configurable Size   🎯 FIFO Eviction   🔄 Cache Hit/Miss  │ ║
║ └─────────────────────────────────────────────────────────────┘ ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                           🎯 GPU Memory                          ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ ┌─────────────────┐    ┌─────────────────┐    ┌─────────────────┐ ║
║ │ 🔄 Current      │    │ ⏳ Prefetch     │    │ 📤 To Offload   │ ║
║ │ block/phase N   │◄──►│ block/phase N+1 │◄──►│ block/phase N-1 │ ║
║ └─────────────────┘    └─────────────────┘    └─────────────────┘ ║
║          │                      │                      │         ║
║          ▼                      ▼                      ▼         ║
║   ┌─────────────┐        ┌─────────────┐        ┌─────────────┐  ║
║   │  Compute    │        │  GPU Load   │        │  CPU Load   │  ║
║   │   Stream    │        │   Stream    │        │   Stream    │  ║
║   │(priority=-1)│        │(priority=0) │        │(priority=0) │  ║
║   └─────────────┘        └─────────────┘        └─────────────┘  ║
╚═════════════════════════════════════════════════════════════════╝
╔═════════════════════════════════════════════════════════════════╗
║                        🔄 Complete Workflow                      ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ Step 1: Cache Miss Handling                                      ║
║ ┌─────────────┐    ┌─────────────┐    ┌─────────────┐            ║
║ │ 💿 Disk     │───►│ 💾 CPU Cache│───►│ 🎯 GPU      │            ║
║ │ (On-demand  │    │ (FIFO       │    │ Memory      │            ║
║ │  loading)   │    │ Management) │    │ (Compute)   │            ║
║ └─────────────┘    └─────────────┘    └─────────────┘            ║
║                                                                  ║
║ Step 2: Cache Hit Handling                                       ║
║ ┌─────────────┐    ┌─────────────┐    ┌─────────────┐            ║
║ │ 💿 Disk     │    │ 💾 CPU Cache│───►│ 🎯 GPU      │            ║
║ │ (Skip       │    │ (Direct     │    │ Memory      │            ║
║ │  loading)   │    │  Access)    │    │ (Compute)   │            ║
║ └─────────────┘    └─────────────┘    └─────────────┘            ║
║                                                                  ║
║ Step 3: Memory Management                                        ║
║ ┌─────────────┐    ┌─────────────┐    ┌─────────────┐            ║
║ │ 💿 Disk     │    │ 💾 CPU Cache│    │ 🎯 GPU      │            ║
║ │ (Persistent │    │ (FIFO       │    │ Memory      │            ║
║ │  Storage)   │    │  Eviction)  │    │ (Swap       │            ║
║ └─────────────┘    └─────────────┘    │  Rotation)  │            ║
║                                       └─────────────┘            ║
╚═════════════════════════════════════════════════════════════════╝
Work Steps:
1. Disk Storage: Model weights are stored per block on SSD/NVMe, one .safetensors file per block
2. Task Scheduling: When a block/phase is needed, the priority task queue assigns a disk worker thread
3. Async Loading: Multiple disk threads read weight files from disk into the CPU memory buffer in parallel
4. Intelligent Caching: The CPU memory buffer manages its cache with a FIFO policy and a configurable size
5. Cache Hit: If the weights are already cached, they are transferred to the GPU directly, with no disk read
6. Prefetch Transfer: Cached weights are transferred asynchronously to GPU memory (on the GPU load stream)
7. Compute Execution: Weights on the GPU run their computation (on the compute stream) while the next block/phase is prefetched in the background
8. Swap Rotation: After a computation finishes, block/phase positions are rotated for continuous computation
9. Memory Management: When the CPU cache is full, the oldest weight blocks/phases are evicted automatically
```
**Key Features:**
- **Lazy loading**: Model weights are loaded from disk on demand, avoiding loading the entire model at once
- **Intelligent caching**: The CPU memory buffer uses a FIFO policy with a configurable size
- **Multi-threaded prefetching**: Multiple disk worker threads load weights in parallel
- **Asynchronous transfer**: CUDA streams overlap computation and data transfer
- **Swap rotation**: Position rotation enables continuous computation, avoiding repeated loading/unloading
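Step 1 of the workflow assumes weights are pre-split into one `.safetensors` file per block. A sketch of such a split, assuming a hypothetical `blocks.<idx>.` key prefix (real checkpoints may name parameters differently):
```python
from pathlib import Path

import torch
from safetensors.torch import save_file


def split_blocks_to_disk(state_dict, offload_path, num_blocks):
    """Write one .safetensors file per transformer block (illustrative only)."""
    out_dir = Path(offload_path) / "offload_blocks"
    out_dir.mkdir(parents=True, exist_ok=True)
    for idx in range(num_blocks):
        prefix = f"blocks.{idx}."           # assumed key layout
        block = {k: v.contiguous() for k, v in state_dict.items()
                 if k.startswith(prefix)}
        if block:
            save_file(block, str(out_dir / f"block_{idx}.safetensors"))


# Toy usage with a fake two-block state dict:
fake = {f"blocks.{i}.ffn.weight": torch.randn(8, 8) for i in range(2)}
split_blocks_to_disk(fake, ".", num_blocks=2)
```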
## ⚙️ Configuration Parameters
### GPU-CPU Offloading Configuration
```python
config = {
    "cpu_offload": True,
    "offload_ratio": 1.0,            # Offload ratio (0.0-1.0)
    "offload_granularity": "block",  # Offload granularity: "block" or "phase"
    "lazy_load": False,              # Lazy loading disabled
}
```
### Disk-CPU-GPU Offloading Configuration
```python
config = {
    "cpu_offload": True,
    "lazy_load": True,               # Enable lazy loading
    "offload_ratio": 1.0,            # Offload ratio
    "offload_granularity": "phase",  # phase granularity is recommended
    "num_disk_workers": 2,           # Number of disk worker threads
    "offload_to_disk": True,         # Enable disk offloading
    "offload_path": ".",             # Disk offload path
}
```
**Key Intelligent-Cache Parameters:**
- `max_memory`: Controls the CPU cache size; affects cache hit rate and memory usage
- `num_disk_workers`: Controls the number of disk loading threads; affects prefetch speed
- `offload_granularity`: Controls cache granularity (block or phase); affects cache efficiency
  - `"block"`: Cache is managed in units of complete Transformer layers
  - `"phase"`: Cache is managed in units of individual computational components
Detailed configuration files are available under [config](https://github.com/ModelTC/lightx2v/tree/main/configs/offload)
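To make the two granularities concrete, here is a toy enumeration of the units the cache would manage under each setting; the phase names are illustrative, not Lightx2v's actual phase split:
```python
# One transformer block decomposed at two granularities (illustrative names).
BLOCK_PHASES = ["self_attn", "cross_attn", "ffn"]


def offload_units(num_blocks, granularity):
    """Enumerate the units the cache manages under each granularity."""
    if granularity == "block":
        return [(b,) for b in range(num_blocks)]            # whole layers
    if granularity == "phase":
        return [(b, p) for b in range(num_blocks) for p in BLOCK_PHASES]
    raise ValueError(f"unknown granularity: {granularity!r}")


print(len(offload_units(40, "block")))  # 40 units to swap
print(len(offload_units(40, "phase")))  # 120 smaller units, finer control
```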
## 🎯 Usage Recommendations
```
╔═════════════════════════════════════════════════════════════════╗
║                       📋 Configuration Guide                     ║
╠═════════════════════════════════════════════════════════════════╣
║                                                                  ║
║ 🔄 GPU-CPU block/phase offloading:                               ║
║    For insufficient GPU VRAM (RTX 3090/4090 24G) with ample      ║
║    system memory (>64/128G)                                      ║
║ 💾 Disk-CPU-GPU block/phase offloading:                          ║
║    For insufficient GPU VRAM (RTX 3060/4090 8G) and limited      ║
║    system memory (16/32G)                                        ║
║ 🚫 No offload: for high-end hardware chasing peak performance    ║
║                                                                  ║
╚═════════════════════════════════════════════════════════════════╝
```
## 🔍 Troubleshooting
### Common Issues and Solutions
1. **Disk I/O bottleneck**
```
Solution: Use an NVMe SSD and increase num_disk_workers
```
2. **Memory buffer overflow**
```
Solution: Increase max_memory or decrease num_disk_workers
```
3. **Loading timeout**
```
Solution: Check disk performance and optimize the file system
```
**Note**: This offloading mechanism is designed specifically for Lightx2v; by fully exploiting modern hardware's asynchronous compute capabilities, it significantly lowers the hardware bar for large-model inference.
...@@ -15,6 +15,7 @@ class WeightAsyncStreamManager(object):
        self.cuda_load_stream = torch.cuda.Stream(priority=0)
        self.offload_block_num = int(offload_ratio * blocks_num)
        self.phases_num = phases_num
        self.block_nums = blocks_num  # total block count, used for the bounds checks below
        self.offload_phases_num = blocks_num * phases_num * offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
...@@ -128,6 +129,9 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
        if next_block_idx < 0:
            next_block_idx = 0
        if next_block_idx == self.block_nums:
            return  # never prefetch past the final block
        if self.offload_gra == "phase":
            for phase_idx in range(self.phases_num):
                obj_key = (next_block_idx, phase_idx)
...@@ -170,6 +174,8 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
            self.pin_memory_buffer.push(block_idx, block)
            block_idx += 1
            if block_idx == self.block_nums:
                break  # initial prefetch stops at the last block

    def prefetch_weights_from_disk(self, blocks):
        if self.initial_prefetch_done:
......
...@@ -56,3 +56,10 @@ class Conv2dWeight(Conv2dWeightTemplate):
        if self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

    def clear(self):
        attrs = ["weight", "bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
            setattr(self, attr, None)  # keep the attribute defined, now empty
...@@ -66,3 +66,10 @@ class Conv3dWeight(Conv3dWeightTemplate):
        if self.bias is not None:
            destination[self.bias_name] = self.bias.cpu().detach().clone()
        return destination

    def clear(self):
        attrs = ["weight", "bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
            setattr(self, attr, None)
...@@ -34,9 +34,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
        return self.weight.numel() * self.weight.element_size()

    def clear(self):
        attrs = ["weight", "bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
            setattr(self, attr, None)

    @abstractmethod
    def apply(self, input_tensor):
......
...@@ -23,7 +23,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
            self.pinned_weight = torch.empty(self.weight.shape, pin_memory=True, dtype=self.weight.dtype)

    def clear(self):
        attrs = ["weight"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
            setattr(self, attr, None)

    @abstractmethod
    def apply(self, input_tensor):
......
...@@ -22,7 +22,11 @@ class DefaultTensor:
            self.pinned_tensor = torch.empty(self.tensor.shape, pin_memory=True, dtype=self.tensor.dtype)

    def clear(self):
        attrs = ["tensor"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
            setattr(self, attr, None)

    def _calculate_size(self):
        return self.tensor.numel() * self.tensor.element_size()
......