Commit 878f5a48 authored by helloyongyang's avatar helloyongyang
Browse files

Fix torch compile

parent 88b7a2dd
...@@ -27,8 +27,8 @@ def init_runner(config): ...@@ -27,8 +27,8 @@ def init_runner(config):
if CHECK_ENABLE_GRAPH_MODE(): if CHECK_ENABLE_GRAPH_MODE():
default_runner = RUNNER_REGISTER[config.model_cls](config) default_runner = RUNNER_REGISTER[config.model_cls](config)
default_runner.init_modules()
runner = GraphRunner(default_runner) runner = GraphRunner(default_runner)
runner.runner.init_modules()
else: else:
runner = RUNNER_REGISTER[config.model_cls](config) runner = RUNNER_REGISTER[config.model_cls](config)
runner.init_modules() runner.init_modules()
......
...@@ -12,6 +12,7 @@ class WanPostInfer: ...@@ -12,6 +12,7 @@ class WanPostInfer:
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, x, e, grid_sizes): def infer(self, weights, x, e, grid_sizes):
if e.dim() == 2: if e.dim() == 2:
modulation = weights.head_modulation.tensor # 1, 2, dim modulation = weights.head_modulation.tensor # 1, 2, dim
......
...@@ -28,6 +28,7 @@ class WanPreInfer: ...@@ -28,6 +28,7 @@ class WanPreInfer:
def set_scheduler(self, scheduler): def set_scheduler(self, scheduler):
self.scheduler = scheduler self.scheduler = scheduler
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
def infer(self, weights, inputs, positive, kv_start=0, kv_end=0): def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
x = self.scheduler.latents x = self.scheduler.latents
......
...@@ -6,7 +6,7 @@ from lightx2v.utils.envs import * ...@@ -6,7 +6,7 @@ from lightx2v.utils.envs import *
def compute_freqs(c, grid_sizes, freqs): def compute_freqs(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0]
seq_len = f * h * w seq_len = f * h * w
freqs_i = torch.cat( freqs_i = torch.cat(
[ [
...@@ -22,7 +22,7 @@ def compute_freqs(c, grid_sizes, freqs): ...@@ -22,7 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
def compute_freqs_audio(c, grid_sizes, freqs): def compute_freqs_audio(c, grid_sizes, freqs):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0]
f = f + 1 ##for r2v add 1 channel f = f + 1 ##for r2v add 1 channel
seq_len = f * h * w seq_len = f * h * w
freqs_i = torch.cat( freqs_i = torch.cat(
...@@ -39,7 +39,7 @@ def compute_freqs_audio(c, grid_sizes, freqs): ...@@ -39,7 +39,7 @@ def compute_freqs_audio(c, grid_sizes, freqs):
def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0): def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
f, h, w = grid_sizes[0].tolist() f, h, w = grid_sizes[0]
seq_len = f * h * w seq_len = f * h * w
freqs_i = torch.cat( freqs_i = torch.cat(
[ [
......
...@@ -107,8 +107,9 @@ class DefaultRunner(BaseRunner): ...@@ -107,8 +107,9 @@ class DefaultRunner(BaseRunner):
def set_progress_callback(self, callback): def set_progress_callback(self, callback):
self.progress_callback = callback self.progress_callback = callback
def run(self): def run(self, total_steps=None):
total_steps = self.model.scheduler.infer_steps if total_steps is None:
total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps): for step_index in range(total_steps):
logger.info(f"==> step_index: {step_index + 1} / {total_steps}") logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
...@@ -126,13 +127,10 @@ class DefaultRunner(BaseRunner): ...@@ -126,13 +127,10 @@ class DefaultRunner(BaseRunner):
return self.model.scheduler.latents, self.model.scheduler.generator return self.model.scheduler.latents, self.model.scheduler.generator
def run_step(self, step_index=0): def run_step(self):
self.init_scheduler()
self.inputs = self.run_input_encoder() self.inputs = self.run_input_encoder()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.set_target_shape()
self.model.scheduler.step_pre(step_index=step_index) self.run_dit(total_steps=1)
self.model.infer(self.inputs)
self.model.scheduler.step_post()
def end_run(self): def end_run(self):
self.model.scheduler.clear() self.model.scheduler.clear()
...@@ -171,14 +169,14 @@ class DefaultRunner(BaseRunner): ...@@ -171,14 +169,14 @@ class DefaultRunner(BaseRunner):
} }
@ProfilingContext("Run DiT") @ProfilingContext("Run DiT")
def _run_dit_local(self): def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.model = self.load_transformer() self.model = self.load_transformer()
self.init_scheduler() self.init_scheduler()
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) self.model.scheduler.prepare(self.inputs["image_encoder_output"])
if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v": if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
self.inputs["image_encoder_output"]["vae_encoder_out"] = None self.inputs["image_encoder_output"]["vae_encoder_out"] = None
latents, generator = self.run() latents, generator = self.run(total_steps)
self.end_run() self.end_run()
return latents, generator return latents, generator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment