Commit ea618db2 authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

Merge pull request #74 from ModelTC/dev_demo

Update demo
parents 6bd320af 941fa16f
...@@ -48,7 +48,7 @@ def run_inference( ...@@ -48,7 +48,7 @@ def run_inference(
if torch_compile: if torch_compile:
os.environ["ENABLE_GRAPH_MODE"] = "true" os.environ["ENABLE_GRAPH_MODE"] = "true"
os.environ["DTYPE"] = "BF16"
config = { config = {
"infer_steps": infer_steps, "infer_steps": infer_steps,
"target_video_length": num_frames, "target_video_length": num_frames,
...@@ -136,8 +136,8 @@ def run_inference( ...@@ -136,8 +136,8 @@ def run_inference(
asyncio.run(runner.run_pipeline()) asyncio.run(runner.run_pipeline())
del runner del runner
gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
return save_video_path return save_video_path
...@@ -185,6 +185,7 @@ with gr.Blocks( ...@@ -185,6 +185,7 @@ with gr.Blocks(
lines=3, lines=3,
placeholder="Unwanted content...", placeholder="Unwanted content...",
max_lines=5, max_lines=5,
value="镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
) )
with gr.Column(): with gr.Column():
tiny_vae_path = gr.Textbox( tiny_vae_path = gr.Textbox(
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import threading import threading
import queue import queue
import time import time
import gc
from loguru import logger from loguru import logger
from collections import OrderedDict from collections import OrderedDict
...@@ -182,6 +183,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager): ...@@ -182,6 +183,10 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
logger.info("All worker threads have been closed") logger.info("All worker threads have been closed")
def clear(self):
self.pin_memory_buffer.clear()
self.shutdown()
class MemoryBuffer: class MemoryBuffer:
def __init__(self, max_memory_bytes=8 * (1024**3)): def __init__(self, max_memory_bytes=8 * (1024**3)):
...@@ -245,3 +250,14 @@ class MemoryBuffer: ...@@ -245,3 +250,14 @@ class MemoryBuffer:
if not self.cache: if not self.cache:
return -1 return -1
return (list(self.cache.keys())[-1][0] + 1) % 40 return (list(self.cache.keys())[-1][0] + 1) % 40
def clear(self):
with self.lock:
for key in list(self.cache.keys()):
self._remove_key(key)
self.insertion_order = []
self.insertion_index = 0
self.used_mem = 0
torch.cuda.empty_cache()
gc.collect()
...@@ -94,7 +94,7 @@ class LNWeight(LNWeightTemplate): ...@@ -94,7 +94,7 @@ class LNWeight(LNWeightTemplate):
self.bias = None self.bias = None
def apply(self, input_tensor): def apply(self, input_tensor):
if self.weight is not None and self.weight.dtype == torch.bfloat16: if self.weight is None or self.weight.dtype == torch.bfloat16:
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps) input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
else: else:
input_tensor = torch.nn.functional.layer_norm( input_tensor = torch.nn.functional.layer_norm(
......
import numpy as np import numpy as np
from ..transformer_infer import WanTransformerInfer from ..transformer_infer import WanTransformerInfer
import torch import torch
import gc
class WanTransformerInferTeaCaching(WanTransformerInfer): class WanTransformerInferTeaCaching(WanTransformerInfer):
...@@ -10,7 +11,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer): ...@@ -10,7 +11,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context): def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
modulated_inp = embed0 if self.scheduler.use_ret_steps else embed modulated_inp = embed0 if self.scheduler.use_ret_steps else embed
# teacache
if self.scheduler.cnt % 2 == 0: # even -> conditon if self.scheduler.cnt % 2 == 0: # even -> conditon
self.scheduler.is_even = True self.scheduler.is_even = True
if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps: if self.scheduler.cnt < self.scheduler.ret_steps or self.scheduler.cnt >= self.scheduler.cutoff_steps:
...@@ -32,6 +32,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer): ...@@ -32,6 +32,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
modulated_inp = modulated_inp.cpu() modulated_inp = modulated_inp.cpu()
del modulated_inp del modulated_inp
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
else: # odd -> unconditon else: # odd -> unconditon
self.scheduler.is_even = False self.scheduler.is_even = False
...@@ -54,6 +55,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer): ...@@ -54,6 +55,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
modulated_inp = modulated_inp.cpu() modulated_inp = modulated_inp.cpu()
del modulated_inp del modulated_inp
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
if self.scheduler.is_even: if self.scheduler.is_even:
if not should_calc_even: if not should_calc_even:
...@@ -76,6 +78,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer): ...@@ -76,6 +78,7 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
ori_x = ori_x.to("cpu") ori_x = ori_x.to("cpu")
del ori_x del ori_x
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
else: else:
if not should_calc_odd: if not should_calc_odd:
x += self.scheduler.previous_residual_odd.cuda() x += self.scheduler.previous_residual_odd.cuda()
...@@ -97,4 +100,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer): ...@@ -97,4 +100,6 @@ class WanTransformerInferTeaCaching(WanTransformerInfer):
ori_x = ori_x.to("cpu") ori_x = ori_x.to("cpu")
del ori_x del ori_x
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect()
return x return x
...@@ -129,6 +129,7 @@ class DefaultRunner: ...@@ -129,6 +129,7 @@ class DefaultRunner:
self.model.scheduler.clear() self.model.scheduler.clear()
del self.inputs, self.model.scheduler del self.inputs, self.model.scheduler
if self.config.get("lazy_load", False): if self.config.get("lazy_load", False):
self.model.transformer_infer.weights_stream_mgr.clear()
del self.model del self.model
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
......
...@@ -35,6 +35,6 @@ python -m lightx2v.infer \ ...@@ -35,6 +35,6 @@ python -m lightx2v.infer \
--model_path $model_path \ --model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_i2v.json \ --config_json ${lightx2v_path}/configs/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \ --prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \ --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \ --image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment