# wan_sf_runner.py
import gc

import torch
from loguru import logger

6
from lightx2v.deploy.common.video_recorder import VideoRecorder
7
8
9
10
11
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.self_forcing.scheduler import WanSFScheduler
from lightx2v.models.video_encoders.hf.wan.vae_sf import WanSFVAE
12
from lightx2v.server.metrics import monitor_cli
13
from lightx2v.utils.envs import *
14
from lightx2v.utils.memory_profiler import peak_memory_decorator
15
16
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
17
from lightx2v.utils.utils import vae_to_comfyui_image_inplace
18
19
20
21
22
23
24


@RUNNER_REGISTER("wan2.1_sf")
class WanSFRunner(WanRunner):
    """Runner for the Wan2.1 self-forcing (``wan2.1_sf``) pipeline.

    Video is produced segment by segment, one segment per scheduler block
    (``scheduler.num_blocks``). For each segment the runner:

    1. runs every denoising step of the scheduler;
    2. decodes the latents with ``use_cache=True``;
    3. re-runs the transformer once on the final step with ``is_rerun=True``.

    When ``config["is_live"]`` is set, ``run_main`` is rebound to
    :meth:`run_main_live`, which streams each decoded segment to a
    :class:`VideoRecorder`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.vae_cls = WanSFVAE
        self.is_live = config.get("is_live", False)
        if self.is_live:
            # Live streaming uses the fixed output resolution from the config.
            self.width = self.config["target_width"]
            self.height = self.config["target_height"]
            # Rebind the entry point on this instance so run_main() streams.
            self.run_main = self.run_main_live

    def load_transformer(self):
        """Build the self-forcing DiT and merge any configured LoRAs into it.

        Returns:
            The constructed ``WanSFModel``.

        Raises:
            AssertionError: if ``lora_configs`` is set while ``dit_quantized`` is true.
        """
        # NOTE(review): the config is passed as both the first and second
        # argument; confirm WanSFModel's first parameter is meant to be the
        # config rather than a model path.
        model = WanSFModel(
            self.config,
            self.config,
            self.init_device,
        )
        lora_configs = self.config.get("lora_configs")
        if lora_configs:
            assert not self.config.get("dit_quantized", False), "LoRA cannot be applied when dit_quantized is enabled"
            lora_wrapper = WanLoraWrapper(model)
            for lora_config in lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
        return model

    def init_scheduler(self):
        """Install the self-forcing scheduler."""
        self.scheduler = WanSFScheduler(self.config)

    def set_target_shape(self):
        """Set a fixed latent target shape of 21 output frames."""
        self.num_output_frames = 21
        # Presumably [latent channels, frames, latent height, latent width].
        # NOTE(review): 60x104 is hard-coded; confirm it matches the intended
        # pixel resolution before changing target_height/target_width.
        self.config.target_shape = [16, self.num_output_frames, 60, 104]

    def get_video_segment_num(self):
        """Use one video segment per scheduler block."""
        self.video_segment_num = self.scheduler.num_blocks

    @ProfilingContext4DebugL1("Run VAE Decoder")
    def run_vae_decoder(self, latents):
        """Decode ``latents`` to images with the cached SF VAE decoder.

        With ``lazy_load`` or ``unload_modules`` enabled, the decoder is loaded
        for this call only and released afterwards, even if decoding fails.

        Args:
            latents: latent tensor, cast to ``GET_DTYPE()`` before decoding.

        Returns:
            Whatever ``self.vae_decoder.decode`` returns.
        """
        transient_decoder = self.config.get("lazy_load", False) or self.config.get("unload_modules", False)
        if transient_decoder:
            self.vae_decoder = self.load_vae_decoder()
        # NOTE(review): when the decoder is reloaded per call, any state kept
        # by use_cache=True across segments is presumably lost — confirm.
        try:
            images = self.vae_decoder.decode(latents.to(GET_DTYPE()), use_cache=True)
        finally:
            if transient_decoder:
                del self.vae_decoder
                torch.cuda.empty_cache()
                gc.collect()
        return images

    @peak_memory_decorator
    def run_segment(self, segment_idx=0):
        """Run all denoising steps of the scheduler for one segment.

        Args:
            segment_idx: index of the segment being generated.

        Returns:
            ``self.model.scheduler.stream_output`` after the last step.
        """
        infer_steps = self.model.scheduler.infer_steps
        # Progress is reported as a percentage of all steps over all segments.
        total_all_steps = self.video_segment_num * infer_steps
        for step_index in range(infer_steps):
            # only for single segment, check stop signal every step
            if self.video_segment_num == 1:
                self.check_stop()
            logger.info(f"==> step_index: {step_index + 1} / {infer_steps}")

            with ProfilingContext4DebugL1("step_pre"):
                self.model.scheduler.step_pre(seg_index=segment_idx, step_index=step_index, is_rerun=False)

            with ProfilingContext4DebugL1("🚀 infer_main"):
                self.model.infer(self.inputs)

            with ProfilingContext4DebugL1("step_post"):
                self.model.scheduler.step_post()

            if self.progress_callback:
                current_step = segment_idx * infer_steps + step_index + 1
                self.progress_callback((current_step / total_all_steps) * 100, 100)

        return self.model.scheduler.stream_output

    def get_rank_and_world_size(self):
        """Return ``(rank, world_size)``; ``(0, 1)`` when torch.distributed is not in use."""
        # Import explicitly rather than relying on a wildcard import to supply `dist`.
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        return 0, 1

    def init_video_recorder(self):
        """Create ``self.video_recorder`` on the last rank when an output path is set.

        ``self.video_recorder`` is always reset to ``None`` first, so callers can
        safely inspect it even if reading the output path fails.
        """
        self.video_recorder = None
        output_video_path = self.input_info.save_result_path
        if isinstance(output_video_path, dict):
            output_video_path = output_video_path["data"]
        logger.info(f"init video_recorder with output_video_path: {output_video_path}")
        rank, world_size = self.get_rank_and_world_size()
        if not output_video_path or rank != world_size - 1:
            return

        record_fps = self.config.get("target_fps", 16)
        # When a frame-interpolation model is loaded, record at its configured fps.
        if "video_frame_interpolation" in self.config and getattr(self, "vfi_model", None) is not None:
            record_fps = self.config["video_frame_interpolation"]["target_fps"]

        self.video_recorder = VideoRecorder(
            livestream_url=output_video_path,
            fps=record_fps,
        )

    @ProfilingContext4DebugL1("End run segment")
    def end_run_segment(self, segment_idx=None):
        """Finish a segment: rerun the last step, accumulate video, and stream it.

        Runs ``step_pre(..., is_rerun=True)`` on the final step index followed
        by one more ``model.infer`` call. It then appends ``self.gen_video`` to
        ``self.gen_video_final`` along dim 0. In live mode, it also publishes
        the segment to the video recorder.
        """
        with ProfilingContext4DebugL1("step_pre_in_rerun"):
            self.model.scheduler.step_pre(seg_index=segment_idx, step_index=self.model.scheduler.infer_steps - 1, is_rerun=True)
        with ProfilingContext4DebugL1("🚀 infer_main_in_rerun"):
            self.model.infer(self.inputs)

        if self.gen_video_final is None:
            # NOTE(review): on the first segment gen_video_final aliases
            # gen_video; if vae_to_comfyui_image_inplace mutates its argument
            # (as its name suggests), it would also alter gen_video_final — confirm.
            self.gen_video_final = self.gen_video
        else:
            self.gen_video_final = torch.cat([self.gen_video_final, self.gen_video], dim=0)

        if self.is_live and self.video_recorder:
            stream_video = vae_to_comfyui_image_inplace(self.gen_video)
            self.video_recorder.pub_video(stream_video)

        torch.cuda.empty_cache()

    @ProfilingContext4DebugL2("Run DiT")
    def run_main_live(self, total_steps=None):
        """Generate all segments while streaming each one to the video recorder.

        The last rank must have a recorder (i.e. a non-empty
        ``save_result_path``). When more than one rank is present, all ranks
        synchronise on a barrier before generation. The recorder is always
        stopped on exit, even if ``end_run`` raises.

        Args:
            total_steps: unused; kept for interface compatibility with ``run_main``.
        """
        import torch.distributed as dist

        try:
            self.init_video_recorder()
            logger.info(f"init video_recorder: {self.video_recorder}")
            rank, world_size = self.get_rank_and_world_size()
            if rank == world_size - 1:
                assert self.video_recorder is not None, "video_recorder is required on the last rank for live streaming (check save_result_path)"
                self.video_recorder.start(self.width, self.height)
            if world_size > 1:
                dist.barrier()
            self.init_run()
            if self.config.get("compile", False):
                self.model.select_graph_for_compile(self.input_info)

            for segment_idx in range(self.video_segment_num):
                logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
                with ProfilingContext4DebugL1(
                    f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
                    recorder_mode=GET_RECORDER_MODE(),
                    metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
                    metrics_labels=["DefaultRunner"],
                ):
                    self.check_stop()
                    # 1. default do nothing
                    self.init_run_segment(segment_idx)
                    # 2. main inference loop
                    latents = self.run_segment(segment_idx)
                    # 3. vae decoder
                    self.gen_video = self.run_vae_decoder(latents)
                    # 4. rerun last step, accumulate and stream the segment
                    self.end_run_segment(segment_idx)
        finally:
            try:
                if hasattr(self.model, "inputs"):
                    self.end_run()
            finally:
                # getattr: init_video_recorder may have failed before creating it.
                recorder = getattr(self, "video_recorder", None)
                if recorder:
                    recorder.stop()
                    self.video_recorder = None