Commit ea618db2 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Merge pull request #74 from ModelTC/dev_demo

Update demo
parents 6bd320af 941fa16f
......@@ -48,7 +48,7 @@ def run_inference(
if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true"
os.environ["DTYPE"] = "BF16"
config = {
"infer_steps": infer_steps,
"target_video_length": num_frames,
......@@ -136,8 +136,8 @@ def run_inference(
asyncio.run(runner.run_pipeline())
del runner
gc.collect()
torch.cuda.empty_cache()
gc.collect()
return save_video_path
......@@ -185,6 +185,7 @@ with gr.Blocks(
lines=3,
placeholder="Unwanted content...",
max_lines=5,
value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
)
with gr.Column():
tiny_vae_path = gr.Textbox(
......
......@@ -2,6 +2,7 @@ import torch
import threading
import queue
import time
import gc
from loguru import logger
from collections import OrderedDict
......@@ -182,6 +183,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
logger.info("All worker threads have been closed")
def clear(self):
self.pin_memory_buffer.clear()
self.shutdown()
class MemoryBuffer:
def __init__(self, max_memory_bytes=8 * (1024**3)):
......@@ -245,3 +250,14 @@ class MemoryBuffer:
if not self.cache:
return -1
return (list(self.cache.keys())[-1][0] + 1) % 40
def clear(self):
with self.lock:
for key in list(self.cache.keys()):
self._remove_key(key)
self.insertion_order = []
self.insertion_index = 0
self.used_mem = 0
torch.cuda.empty_cache()
gc.collect()
......@@ -94,7 +94,7 @@ class LNWeight(LNWeightTemplate):
self.bias = None
def apply(self, input_tensor):
if self.weight is not None and self.weight.dtype == torch.bfloat16:
if self.weight is None or self.weight.dtype == torch.bfloat16:
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
else:
input_tensor = torch.nn.functional.layer_norm(
......
import numpy as np
from ..transformer_infer import WanTransformerInfer
import torch
import gc
class WanTransformerInferTeaCaching(WanTransformerInfer):
......@@ -10,7 +11,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
modulated_inp = embed0 if self.scheduler.use_ret_steps else embed
# teacache
if self.scheduler.cnt % 2 == 0: # even -> conditon
self.scheduler.is_even = True
if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
......@@ -32,6 +32,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
modulated_inp = modulated_inp.cpu()
del modulated_inp
torch.cuda.empty_cache()
gc.collect()
else: # odd -> unconditon
self.scheduler.is_even = False
......@@ -54,6 +55,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
modulated_inp = modulated_inp.cpu()
del modulated_inp
torch.cuda.empty_cache()
gc.collect()
if self.scheduler.is_even:
if not should_calc_even:
......@@ -76,6 +78,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
ori_x = ori_x.to("cpu")
del ori_x
torch.cuda.empty_cache()
gc.collect()
else:
if not should_calc_odd:
x += self.scheduler.previous_residual_odd.cuda()
......@@ -97,4 +100,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
ori_x = ori_x.to("cpu")
del ori_x
torch.cuda.empty_cache()
gc.collect()
return x
......@@ -129,6 +129,7 @@ class DefaultRunner:
self.model.scheduler.clear()
del self.inputs, self.model.scheduler
if self.config.get("lazy_load", False):
self.model.transformer_infer.weights_stream_mgr.clear()
del self.model
torch.cuda.empty_cache()
gc.collect()
......
......@@ -35,6 +35,6 @@ python -m lightx2v.infer \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment