Commit 7fc021e2 authored by helloyongyang's avatar helloyongyang
Browse files

support runners & torch.compile

parent cbf7820f
......@@ -10,6 +10,7 @@ import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
from lightx2v.utils.envs import *
from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.set_config import set_config
......@@ -32,6 +33,9 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.runners.graph_runner import GraphRunner
from lightx2v.common.ops import *
......@@ -271,22 +275,6 @@ def init_scheduler(config, image_encoder_output):
return scheduler
def run_main_inference(model, inputs):
    """Drive the scheduler through all denoising steps on *model*.

    Each iteration performs the scheduler pre-step, one model forward
    pass, and the scheduler post-step, each wrapped in a debug-profiling
    context. Returns the scheduler's final latents and generator.
    """
    idx = 0
    while idx < model.scheduler.infer_steps:
        print(f"==> step_index: {idx + 1} / {model.scheduler.infer_steps}")
        with ProfilingContext4Debug("step_pre"):
            model.scheduler.step_pre(step_index=idx)
        with ProfilingContext4Debug("infer"):
            model.infer(inputs)
        with ProfilingContext4Debug("step_post"):
            model.scheduler.step_post()
        idx += 1
    return model.scheduler.latents, model.scheduler.generator
def run_vae(latents, generator, config):
    """Decode latents into image frames with the VAE.

    NOTE(review): depends on a module-global ``vae_model`` that is not a
    parameter — presumably initialized during model loading; confirm.
    """
    return vae_model.decode(latents, generator=generator, config=config)
......@@ -358,7 +346,14 @@ if __name__ == "__main__":
gc.collect()
torch.cuda.empty_cache()
latents, generator = run_main_inference(model, inputs)
if ENABLE_GRAPH_MODE:
default_runner = DefaultRunner(model, inputs)
runner = GraphRunner(default_runner)
else:
runner = DefaultRunner(model, inputs)
latents, generator = runner.run()
if config.cpu_offload:
scheduler.clear()
......
......@@ -3,6 +3,7 @@ from einops import rearrange
from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb
from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import *
class HunyuanTransformerInfer:
......@@ -25,6 +26,7 @@ class HunyuanTransformerInfer:
    def set_scheduler(self, scheduler):
        """Attach the scheduler object used by subsequent infer() calls."""
        self.scheduler = scheduler
    @torch.compile(disable=not ENABLE_GRAPH_MODE)
    def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
        """Dispatch one transformer forward pass to ``self.infer_func``.

        Compiled with ``torch.compile`` when the ENABLE_GRAPH_MODE env flag
        is set; otherwise the decorator is disabled and this runs eagerly.
        NOTE(review): ``frist_frame_token_num`` is a typo for
        "first_frame_token_num", but renaming would break keyword callers.
        """
        return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
......
......@@ -2,6 +2,7 @@ import torch
from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
from lightx2v.attentions import attention
from lightx2v.common.offload.manager import WeightStreamManager
from lightx2v.utils.envs import *
class WanTransformerInfer:
......@@ -34,6 +35,7 @@ class WanTransformerInfer:
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
return cu_seqlens_q, cu_seqlens_k, lq, lk
    @torch.compile(disable=not ENABLE_GRAPH_MODE)
    def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        """Dispatch one transformer forward pass to ``self.infer_func``.

        Compiled with ``torch.compile`` when the ENABLE_GRAPH_MODE env flag
        is set; otherwise the decorator is disabled and this runs eagerly.
        """
        return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
......
from lightx2v.utils.profiler import ProfilingContext4Debug
class DefaultRunner:
    """Eager-mode inference runner that owns a model and its prepared inputs.

    The model is expected to carry a ``scheduler`` attribute exposing
    ``infer_steps``, ``step_pre``, ``step_post``, ``latents`` and
    ``generator``.
    """

    def __init__(self, model, inputs):
        self.model = model
        self.inputs = inputs

    def run(self):
        """Run the full denoising loop and return (latents, generator)."""
        model = self.model
        for i in range(model.scheduler.infer_steps):
            print(f"==> step_index: {i + 1} / {model.scheduler.infer_steps}")
            with ProfilingContext4Debug("step_pre"):
                model.scheduler.step_pre(step_index=i)
            with ProfilingContext4Debug("infer"):
                model.infer(self.inputs)
            with ProfilingContext4Debug("step_post"):
                model.scheduler.step_post()
        return model.scheduler.latents, model.scheduler.generator

    def run_step(self, step_index=0):
        """Execute a single unprofiled step (used as a compile warm-up)."""
        sched = self.model.scheduler
        sched.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        sched.step_post()
import copy
from lightx2v.utils.profiler import ProfilingContext4Debug
class GraphRunner:
    """Decorates a runner with a one-off torch.compile warm-up.

    On construction it executes a single throwaway step through the wrapped
    runner so that compiled graphs are traced, then restores the scheduler
    state and inputs so the real run starts fresh.
    """

    def __init__(self, runner):
        self.runner = runner
        self.compile()

    def compile(self):
        # Snapshot scheduler and inputs before the warm-up step mutates them.
        saved_scheduler = copy.deepcopy(self.runner.model.scheduler)
        saved_inputs = copy.deepcopy(self.runner.inputs)
        print("start compile...")
        with ProfilingContext4Debug("compile"):
            self.runner.run_step()
        print("end compile...")
        # Roll back so run() observes the pre-warm-up state.
        self.runner.model.set_scheduler(saved_scheduler)
        self.runner.inputs = saved_inputs

    def run(self):
        """Delegate the full inference loop to the wrapped runner."""
        return self.runner.run()
import os

# Feature flags, read once from the environment at import time.
# Both default to False; set the variable to the literal string "true"
# (case-insensitive) to enable.
#
# NOTE: the previous `global NAME` statements at module level were no-ops —
# `global` only has effect inside a function body — so they are removed.

# Enables verbose per-step profiling contexts.
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"

# Enables torch.compile / graph-mode execution paths.
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
import time
import os
import torch
from contextlib import ContextDecorator
# NOTE(review): this local ENABLE_PROFILING_DEBUG duplicates the flag defined
# in lightx2v.utils.envs; the wildcard import below is evaluated afterwards and
# rebinds the name, so this line appears to be dead — confirm and remove.
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
from lightx2v.utils.envs import *
class _ProfilingContext(ContextDecorator):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.