Commit 878f5a48 authored by helloyongyang

Fix torch compile

parent 88b7a2dd
@@ -27,8 +27,8 @@ def init_runner(config):
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
+        default_runner.init_modules()
         runner = GraphRunner(default_runner)
-        runner.runner.init_modules()
     else:
         runner = RUNNER_REGISTER[config.model_cls](config)
         runner.init_modules()
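
The reordering matters because the wrapper is handed a runner that should already be fully built: initializing modules after construction (via runner.runner.init_modules()) means the wrapper could observe an uninitialized runner. A minimal sketch of the wrapping pattern, assuming GraphRunner simply delegates to the inner runner (its real internals live in lightx2v and may differ):

    # Hypothetical sketch, not the repo's implementation: GraphRunner is
    # assumed to delegate to a fully initialized inner runner, which is
    # why the diff moves init_modules() before construction.
    class GraphRunner:
        def __init__(self, runner):
            self.runner = runner  # must already have its modules built

        def run(self, *args, **kwargs):
            # graph capture / compilation details omitted
            return self.runner.run(*args, **kwargs)
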
@@ -12,6 +12,7 @@ class WanPostInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
 
+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, x, e, grid_sizes):
         if e.dim() == 2:
             modulation = weights.head_modulation.tensor  # 1, 2, dim
@@ -28,6 +28,7 @@ class WanPreInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
 
+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
         x = self.scheduler.latents
@@ -6,7 +6,7 @@ from lightx2v.utils.envs import *
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
@@ -22,7 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     f = f + 1  ##for r2v add 1 channel
     seq_len = f * h * w
     freqs_i = torch.cat(
@@ -39,7 +39,7 @@ def compute_freqs_audio(c, grid_sizes, freqs):
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
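
Dropping .tolist() is the core torch.compile fix in these three functions: .tolist() copies the sizes to the host, which forces a device sync and triggers a graph break inside a compiled region, whereas unpacking the tensor keeps the values traceable. A small sketch of the difference (the grid values are made up):

    import torch

    grid_sizes = torch.tensor([[21, 45, 80]])  # made-up (f, h, w) values

    # Before: .tolist() moves the sizes to the host, which forces a
    # sync and a graph break under torch.compile.
    f, h, w = grid_sizes[0].tolist()

    # After: unpacking yields 0-dim tensors that stay traceable.
    f, h, w = grid_sizes[0]
    seq_len = f * h * w  # still a tensor; fine for downstream tensor math
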
@@ -107,8 +107,9 @@ class DefaultRunner(BaseRunner):
     def set_progress_callback(self, callback):
         self.progress_callback = callback
 
-    def run(self):
-        total_steps = self.model.scheduler.infer_steps
+    def run(self, total_steps=None):
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
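
Making total_steps optional keeps every existing run() caller on the full schedule while letting the new single-step path below request fewer iterations. Hypothetical usage, assuming an already-initialized runner instance:

    # Hypothetical calls; runner construction is assumed.
    latents, generator = runner.run()               # full schedule (scheduler.infer_steps)
    latents, generator = runner.run(total_steps=1)  # only the first denoising step
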
@@ -126,13 +127,10 @@ class DefaultRunner(BaseRunner):
         return self.model.scheduler.latents, self.model.scheduler.generator
 
-    def run_step(self, step_index=0):
-        self.init_scheduler()
+    def run_step(self):
         self.inputs = self.run_input_encoder()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        self.model.scheduler.step_pre(step_index=step_index)
-        self.model.infer(self.inputs)
-        self.model.scheduler.step_post()
+        self.set_target_shape()
+        self.run_dit(total_steps=1)
 
     def end_run(self):
         self.model.scheduler.clear()
@@ -171,14 +169,14 @@ class DefaultRunner(BaseRunner):
         }
 
     @ProfilingContext("Run DiT")
-    def _run_dit_local(self):
+    def _run_dit_local(self, total_steps=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None
-        latents, generator = self.run()
+        latents, generator = self.run(total_steps)
         self.end_run()
         return latents, generator
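
Taken together, run_step stops duplicating the scheduler bookkeeping (init_scheduler, prepare, step_pre/step_post) and instead drives one step through the same path as a full generation, which keeps the compiled graph-mode code path identical in both cases. A runnable sketch of the delegation with stand-in bodies; run_dit acting as a dispatcher is an assumption (the diff only shows the local variant), and 50 stands in for scheduler.infer_steps:

    # Sketch only; the real methods live on DefaultRunner.
    class RunnerSketch:
        def run_step(self):
            # one denoising step through the same path as a full run
            self.run_dit(total_steps=1)

        def run_dit(self, total_steps=None):
            # assumed dispatcher to the local variant shown in the diff
            return self._run_dit_local(total_steps)

        def _run_dit_local(self, total_steps=None):
            # scheduler setup elided; forwards total_steps to run()
            return self.run(total_steps)

        def run(self, total_steps=None):
            if total_steps is None:
                total_steps = 50  # stand-in for scheduler.infer_steps
            for step_index in range(total_steps):
                print(f"==> step_index: {step_index + 1} / {total_steps}")

    RunnerSketch().run_step()  # prints a single step
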