wan_audio_runner.py 35.4 KB
Newer Older
wangshankun's avatar
wangshankun committed
1
import gc
PengGao's avatar
PengGao committed
2
3
4
5
import os
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
6
from typing import Dict, List, Optional, Tuple
PengGao's avatar
PengGao committed
7

wangshankun's avatar
wangshankun committed
8
9
import numpy as np
import torch
10
import torch.distributed as dist
gushiqiao's avatar
gushiqiao committed
11
import torchaudio as ta
wangshankun's avatar
wangshankun committed
12
from PIL import Image
gushiqiao's avatar
gushiqiao committed
13
from einops import rearrange
PengGao's avatar
PengGao committed
14
from loguru import logger
gushiqiao's avatar
gushiqiao committed
15
16
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
PengGao's avatar
PengGao committed
17
from transformers import AutoFeatureExtractor
18

wangshankun's avatar
wangshankun committed
19
from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
PengGao's avatar
PengGao committed
20
21
22
from lightx2v.models.networks.wan.audio_model import Wan22MoeAudioModel, WanAudioModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.runners.wan.wan_runner import MultiModelStruct, WanRunner
wangshankun's avatar
wangshankun committed
23
from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
24
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
25
from lightx2v.utils.envs import *
PengGao's avatar
PengGao committed
26
27
from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
from lightx2v.utils.registry_factory import RUNNER_REGISTER
28
29
from lightx2v.utils.utils import find_torch_model_path, save_to_video, vae_to_comfyui_image

wangshankun's avatar
wangshankun committed
30

31
32
33
34
35
36
37
38
39
40
41
@contextmanager
def memory_efficient_inference():
    """Yield control, then free cached CUDA memory and run the garbage collector.

    The cleanup sits in ``finally`` so it also runs when the ``with`` body raises.
    The CUDA cache is only touched when CUDA is available.
    """
    try:
        yield
    finally:
        cuda_present = torch.cuda.is_available()
        if cuda_present:
            torch.cuda.empty_cache()
        gc.collect()


42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
    """Adjust a patch grid so its patch count splits evenly across ``sp_size`` ranks.

    ``sp_size`` is consumed one factor of two at a time. For each factor the height is
    halved if it is even, otherwise the width if it is even. If both are odd, the larger
    side is halved with floor division, dropping one row/column. The per-side ratios are
    multiplied back at the end. As a result, the returned height * width equals
    (reduced h * reduced w) * sp_size.

    Args:
        patched_h: Grid height in patches.
        patched_w: Grid width in patches.
        sp_size: Sequence-parallel world size; must be a positive power of two.

    Returns:
        ``(height, width)`` of the adjusted patch grid.

    Raises:
        ValueError: If ``sp_size`` is not a positive power of two.
    """
    # Explicit check rather than ``assert``: under ``python -O`` the assert is stripped,
    # and a non-positive ``sp_size`` would then make the loop below spin forever.
    if sp_size <= 0 or sp_size & (sp_size - 1):
        raise ValueError("sp_size must be a power of 2")

    h_ratio, w_ratio = 1, 1
    while sp_size != 1:
        sp_size //= 2
        if patched_h % 2 == 0:
            patched_h //= 2
            h_ratio *= 2
        elif patched_w % 2 == 0:
            patched_w //= 2
            w_ratio *= 2
        elif patched_h > patched_w:
            # Both sides odd: halve the larger one (floor division loses one row).
            patched_h //= 2
            h_ratio *= 2
        else:
            patched_w //= 2
            w_ratio *= 2
    return patched_h * h_ratio, patched_w * w_ratio
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89


def get_crop_bbox(ori_h, ori_w, tgt_h, tgt_w):
    """Compute a centred crop ``(y0, y1, x0, x1)`` giving the target aspect ratio.

    Aspect ratios are height / width. When source and target agree within 0.01, the
    whole frame is kept. A relatively taller source is trimmed top and bottom; any
    other source is trimmed left and right. Crop sizes are truncated with ``int`` and
    the crop is centred using floor division.
    """
    target_ratio = tgt_h / tgt_w
    source_ratio = ori_h / ori_w

    if abs(source_ratio - target_ratio) < 0.01:
        return 0, ori_h, 0, ori_w

    if source_ratio > target_ratio:
        # Too tall: keep the full width, trim rows evenly from top and bottom.
        keep_h = int(target_ratio * ori_w)
        top = (ori_h - keep_h) // 2
        return top, top + keep_h, 0, ori_w

    # Too wide: keep the full height, trim columns evenly from both sides.
    keep_w = int(ori_h / target_ratio)
    left = (ori_w - keep_w) // 2
    return 0, ori_h, left, left + keep_w


def isotropic_crop_resize(frames: torch.Tensor, size: tuple):
    """Centre-crop ``frames`` to the aspect ratio of ``size``, then resize to ``size``.

    Args:
        frames: 4-D tensor laid out as (T, C, H, W).
        size: Target ``(H, W)``.

    Returns:
        The cropped frames resized to ``size`` with antialiased bicubic interpolation.
    """
    ori_h, ori_w = frames.shape[2:]
    h, w = size
    y0, y1, x0, x1 = get_crop_bbox(ori_h, ori_w, h, w)
    cropped_frames = frames[:, :, y0:y1, x0:x1]
    resized_frames = resize(cropped_frames, [h, w], InterpolationMode.BICUBIC, antialias=True)
    return resized_frames


def adaptive_resize(img):
    """Choose a bucketed target resolution for ``img`` and crop-resize it to that size.

    The image's height/width ratio is matched to the nearest bucket aspect ratio
    (0.667, 1.0 or 1.5). Within that bucket, the largest resolution whose pixel area
    does not exceed the image's area is used. If the image is smaller than every
    resolution in the bucket, the bucket's smallest resolution is used instead.

    Args:
        img: Tensor whose last two dims are (H, W); forwarded to ``isotropic_crop_resize``.

    Returns:
        ``(cropped_img, target_h, target_w)`` with ``target_h`` / ``target_w`` as Python ints.
    """
    # Each entry: (resolutions as (H, W) rows in ascending area, per-resolution weights).
    # The weights are not used by this function.
    bucket_config = {
        0.667: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64), np.array([0.2, 0.5, 0.3])),
        1.0: (np.array([[480, 480], [576, 576], [704, 704], [960, 960]], dtype=np.int64), np.array([0.1, 0.1, 0.5, 0.3])),
        1.5: (np.array([[480, 832], [544, 960], [720, 1280]], dtype=np.int64)[:, ::-1], np.array([0.2, 0.5, 0.3])),
    }
    ori_height = img.shape[-2]
    ori_width = img.shape[-1]
    ori_ratio = ori_height / ori_width

    aspect_ratios = np.array(list(bucket_config.keys()))
    closest_ratio = aspect_ratios[np.argmin(np.abs(aspect_ratios - ori_ratio))]
    resolutions = bucket_config[closest_ratio][0]

    # Fall back to the smallest resolution of the *matched* bucket. The previous fallback
    # compared ori_ratio with 1.0 exactly, so a near-square image matched to the square
    # bucket could fall back to a 480x832 / 832x480 target instead.
    target_h, target_w = resolutions[0]
    for res_h, res_w in resolutions:
        if ori_height * ori_width >= res_h * res_w:
            target_h, target_w = res_h, res_w
    target_h, target_w = int(target_h), int(target_w)

    cropped_img = isotropic_crop_resize(img, (target_h, target_w))
    return cropped_img, target_h, target_w


119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
@dataclass
class AudioSegment:
    """One chunk of driving audio together with the video-frame span it conditions.

    Instances are produced by ``AudioProcessor.segment_audio``. That method zero-pads
    some chunks and then records the unpadded sample count in ``useful_length``.
    """

    audio_array: np.ndarray  # samples for this chunk, possibly zero-padded at the end
    start_frame: int  # first video frame index covered by the chunk
    end_frame: int  # video frame index at which the chunk ends
    is_last: bool = False  # set for the trailing remainder segment
    useful_length: Optional[int] = None  # sample count before padding; None when not padded


class FramePreprocessor:
    """Perturb conditioning frames with random noise and a random spatial mask.

    Noise is scaled per sample by ``exp(N(noise_mean, noise_std))``. The mask zeroes each
    (H, W) position with probability ``mask_rate``. Setting the relevant parameter(s) to
    ``None`` disables a step.
    """

    def __init__(self, noise_mean: float = -3.0, noise_std: float = 0.5, mask_rate: float = 0.1):
        self.noise_mean = noise_mean
        self.noise_std = noise_std
        self.mask_rate = mask_rate

    def add_noise(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Return ``frames`` plus Gaussian noise with a per-sample log-normal scale.

        A 4-D array is treated as one sample; otherwise dim 0 is the batch dimension.
        When ``rnd_state`` is None, a fresh unseeded ``RandomState`` is used.
        """
        if self.noise_mean is None or self.noise_std is None:
            return frames

        if rnd_state is None:
            rnd_state = np.random.RandomState()

        shape = frames.shape
        bs = 1 if len(shape) == 4 else shape[0]
        sigma = rnd_state.normal(loc=self.noise_mean, scale=self.noise_std, size=(bs,))
        sigma = np.exp(sigma)
        # Give sigma singleton trailing dims so it broadcasts over every non-batch axis.
        sigma = np.expand_dims(sigma, axis=tuple(range(1, len(shape))))
        noise = rnd_state.randn(*shape) * sigma
        return frames + noise

    def add_mask(self, frames: np.ndarray, rnd_state: Optional[np.random.RandomState] = None) -> np.ndarray:
        """Zero each (H, W) position with probability ``mask_rate``.

        One mask over the last two dims is broadcast across all leading dims.
        """
        if self.mask_rate is None:
            return frames

        if rnd_state is None:
            rnd_state = np.random.RandomState()

        h, w = frames.shape[-2:]
        mask = rnd_state.rand(h, w) > self.mask_rate
        return frames * mask

    def process_prev_frames(self, frames: torch.Tensor, rnd_state: Optional[np.random.RandomState] = None) -> torch.Tensor:
        """Apply ``add_noise`` then ``add_mask`` to a tensor and restore its dtype/device.

        Args:
            frames: Frames to perturb (any device; any dtype NumPy can represent, or bfloat16).
            rnd_state: Optional RNG used by both steps. Pass a seeded ``RandomState`` for
                reproducible output. With ``None``, each step draws from its own fresh,
                unseeded state, which is the previous behavior.

        Returns:
            Perturbed frames with the same dtype and device as ``frames``.
        """
        frames_cpu = frames.detach().cpu()
        # NumPy has no bfloat16 dtype, so ``.numpy()`` raises on bf16 tensors; upcast first.
        if frames_cpu.dtype == torch.bfloat16:
            frames_cpu = frames_cpu.float()
        frames_np = self.add_noise(frames_cpu.numpy(), rnd_state)
        frames_np = self.add_mask(frames_np, rnd_state)
        return torch.from_numpy(frames_np).to(dtype=frames.dtype, device=frames.device)


class AudioProcessor:
    """Load driving audio and cut it into chunks aligned with video-frame segments.

    One video frame corresponds to ``audio_sr / target_fps`` audio samples.
    """

    def __init__(self, audio_sr: int = 16000, target_fps: int = 16):
        self.audio_sr = audio_sr
        self.target_fps = target_fps

    def load_audio(self, audio_path: str) -> np.ndarray:
        """Load ``audio_path``, average all channels to mono and resample to ``audio_sr``."""
        audio_array, ori_sr = ta.load(audio_path)
        audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=self.audio_sr)
        return audio_array.numpy()

    def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
        """Return the ``[start, end)`` sample slice for frames ``start_frame``..``end_frame``.

        The end is computed at ``end_frame + 1``, so the slice extends one frame's worth
        of samples past ``end_frame``.
        """
        audio_frame_rate = self.audio_sr / self.target_fps
        return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate)

    def _pad_to_full_segment(self, segment_audio: np.ndarray, max_num_frames: int) -> Tuple[np.ndarray, int]:
        """Zero-pad ``segment_audio`` to a full segment's sample count; return (padded, useful_length)."""
        useful_length = segment_audio.shape[0]
        max_num_audio_length = int((max_num_frames + 1) / self.target_fps * self.audio_sr)
        padded = np.concatenate((segment_audio, np.zeros(max_num_audio_length - useful_length)), axis=0)
        return padded, useful_length

    def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
        """Split ``audio_array`` into overlapping segments covering ``expected_frames`` frames.

        Each segment spans ``max_num_frames`` frames. Consecutive segments overlap by
        ``prev_frame_length`` frames, so every segment after the first adds
        ``max_num_frames - prev_frame_length`` new frames. A remainder that does not fill a
        whole segment becomes a final zero-padded segment with ``is_last=True``. A single
        segment shorter than ``max_num_frames`` is also zero-padded and records
        ``useful_length``.

        Raises:
            ValueError: If more than one segment is needed and ``prev_frame_length`` is not
                smaller than ``max_num_frames``.
        """
        segments = []
        stride = max_num_frames - prev_frame_length  # new frames per segment after the first
        res_frame_num = 0

        if expected_frames <= max_num_frames:
            interval_num = 1
        else:
            if stride <= 0:
                raise ValueError("prev_frame_length must be smaller than max_num_frames")
            interval_num = max(int((expected_frames - max_num_frames) / stride) + 1, 1)
            res_frame_num = expected_frames - interval_num * stride
            # ``interval_num`` full segments cover interval_num * stride + prev_frame_length
            # frames; anything beyond that needs one more, shorter segment. (This used to be
            # compared with a hard-coded 5, which is only right for the default overlap.)
            if res_frame_num > prev_frame_length:
                interval_num += 1
        has_tail = res_frame_num > prev_frame_length

        for idx in range(interval_num):
            start_frame = idx * stride
            if idx == 0:
                if expected_frames < max_num_frames:
                    # Short video: take audio only for the requested frames, then pad. Slicing
                    # up to max_num_frames let the cut audio outlast the video, and with a
                    # non-integer samples-per-frame could make the padding length negative.
                    audio_start, audio_end = self.get_audio_range(0, expected_frames)
                    segment_audio, useful_length = self._pad_to_full_segment(audio_array[audio_start:audio_end], max_num_frames)
                else:
                    audio_start, audio_end = self.get_audio_range(0, max_num_frames)
                    segment_audio, useful_length = audio_array[audio_start:audio_end], None
                segments.append(AudioSegment(segment_audio, 0, max_num_frames, False, useful_length))
            elif has_tail and idx == interval_num - 1:
                # Last, possibly shorter segment ending exactly at expected_frames.
                audio_start, audio_end = self.get_audio_range(start_frame, expected_frames)
                segment_audio, useful_length = self._pad_to_full_segment(audio_array[audio_start:audio_end], max_num_frames)
                segments.append(AudioSegment(segment_audio, start_frame, expected_frames, True, useful_length))
            else:
                end_frame = start_frame + max_num_frames
                audio_start, audio_end = self.get_audio_range(start_frame, end_frame)
                segments.append(AudioSegment(audio_array[audio_start:audio_end], start_frame, end_frame, False))

        return segments


class VideoGenerator:
    """Run the denoising loop for one audio segment and decode the result.

    Segments after the first are conditioned on the trailing frames of the previously
    generated segment (see ``prepare_prev_latents``).
    """

    def __init__(self, model, vae_encoder, vae_decoder, config, progress_callback=None):
        self.model = model
        self.vae_encoder = vae_encoder
        self.vae_decoder = vae_decoder
        self.config = config
        self.frame_preprocessor = FramePreprocessor()
        # Optional callable(progress, total); called with an overall 0-100 percentage.
        self.progress_callback = progress_callback
        # Overwritten by the caller so progress spans all segments.
        self.total_segments = 1

    def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
        """Encode conditioning frames from the previous segment and build the matching mask.

        Builds a zero video of ``config.target_video_length`` frames at
        ``(config.tgt_h, config.tgt_w)``. When ``prev_video`` is given, its last
        ``prev_frame_length`` frames are copied to the front, after noise and masking by
        ``FramePreprocessor``. The video is then VAE-encoded.

        Returns:
            ``{"prev_latents": ..., "prev_mask": ...}``.
        """
        device = torch.device("cuda")
        dtype = GET_DTYPE()
        vae_dtype = torch.float

        tgt_h, tgt_w = self.config.tgt_h, self.config.tgt_w
        prev_frames = torch.zeros((1, 3, self.config.target_video_length, tgt_h, tgt_w), device=device)

        if prev_video is not None:
            # Tail of the previous segment along dim 2 (the frame axis of prev_frames).
            last_frames = prev_video[:, :, -prev_frame_length:].clone().to(device)
            last_frames = self.frame_preprocessor.process_prev_frames(last_frames)
            prev_frames[:, :, :prev_frame_length] = last_frames

        _, nframe, height, width = self.model.scheduler.latents.shape

        if self.config.model_cls == "wan2.2_audio":
            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config).to(dtype)
            # NOTE(review): prev_length is taken from prev_latents.shape[1] rather than from
            # prev_frame_length -- confirm this covers only the intended conditioning frames.
            _, prev_mask = self._wan22_masks_like([self.model.scheduler.latents], zero=True, prev_length=prev_latents.shape[1])
        else:
            prev_latents = self.vae_encoder.encode(prev_frames.to(vae_dtype), self.config)[0].to(dtype)

            if prev_video is not None:
                # Snap the conditioned frame count to the 1 + 4k pattern used by frames_n below.
                prev_token_length = (prev_frame_length - 1) // 4 + 1
                prev_frame_len = max((prev_token_length - 1) * 4 + 1, 0)
            else:
                prev_frame_len = 0

            frames_n = (nframe - 1) * 4 + 1
            prev_mask = torch.ones((1, frames_n, height, width), device=device, dtype=dtype)
            prev_mask[:, prev_frame_len:] = 0
            prev_mask = self._wan_mask_rearrange(prev_mask).unsqueeze(0)

        if prev_latents.shape[-2:] != (height, width):
            logger.warning(f"Size mismatch: prev_latents {prev_latents.shape} vs scheduler latents (H={height}, W={width}). Config tgt_h={self.config.tgt_h}, tgt_w={self.config.tgt_w}")
            prev_latents = torch.nn.functional.interpolate(prev_latents, size=(height, width), mode="bilinear", align_corners=False)

        return {"prev_latents": prev_latents, "prev_mask": prev_mask}

    def _wan22_masks_like(self, tensor, zero=False, generator=None, p=0.2, prev_length=1):
        """Build two all-ones masks shaped like each tensor, optionally zeroing leading frames.

        With ``zero=True``, the first ``prev_length`` entries along dim 1 are zeroed in both
        masks. If a ``generator`` is given, this happens per tensor only with probability
        ``p``. In that case the first mask gets ``exp(N(-3.5, 0.5))`` there instead of 0.

        Returns:
            ``(out1, out2)``, two lists of masks.
        """
        assert isinstance(tensor, list)
        out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
        out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]

        if prev_length == 0 or not zero:
            return out1, out2

        for u, v in zip(out1, out2):
            if generator is not None:
                random_num = torch.rand(1, generator=generator, device=generator.device).item()
                if random_num >= p:
                    continue  # leave the leading frames at 1 in both masks
                u[:, :prev_length] = torch.normal(mean=-3.5, std=0.5, size=(1,), device=u.device, generator=generator).expand_as(u[:, :prev_length]).exp()
            else:
                u[:, :prev_length] = torch.zeros_like(u[:, :prev_length])
            v[:, :prev_length] = torch.zeros_like(v[:, :prev_length])

        return out1, out2

    def _wan_mask_rearrange(self, mask: torch.Tensor) -> torch.Tensor:
        """Fold a per-frame mask of shape (1, t, H, W) into (4, (t + 3) // 4, H, W).

        The first frame is repeated 4 times so the frame count becomes a multiple of 4.
        Consecutive groups of 4 frames are then placed along the new leading axis.
        Requires ``t == 1 + 4k``. The batch dim must be 1, because the ``view`` drops it.
        """
        if mask.ndim == 3:
            mask = mask[None]
        assert mask.ndim == 4
        _, t, h, w = mask.shape
        assert t == ((t - 1) // 4 * 4 + 1)
        mask_first_frame = torch.repeat_interleave(mask[:, 0:1], repeats=4, dim=1)
        mask = torch.concat([mask_first_frame, mask[:, 1:]], dim=1)
        mask = mask.view(mask.shape[1] // 4, 4, h, w)
        return mask.transpose(0, 1)

    @torch.no_grad()
    def generate_segment(self, inputs, audio_features, prev_video=None, prev_frame_length=5, segment_idx=0, total_steps=None):
        """Denoise and decode one video segment.

        Args:
            inputs: Encoder outputs. Mutated in place: ``audio_encoder_output`` and
                ``previmg_encoder_output`` are set, so callers should pass a copy.
            audio_features: Audio features for this segment.
            prev_video: Previously generated segment used for conditioning, or ``None``.
            prev_frame_length: Number of trailing frames of ``prev_video`` to condition on.
            segment_idx: Segment index; the scheduler is reset for every index > 0.
            total_steps: Number of denoising steps; defaults to ``scheduler.infer_steps``.

        Returns:
            Decoded video clamped to [-1, 1], as float32.
        """
        inputs["audio_encoder_output"] = audio_features

        # Reset scheduler for non-first segments
        if segment_idx > 0:
            self.model.scheduler.reset()

        inputs["previmg_encoder_output"] = self.prepare_prev_latents(prev_video, prev_frame_length)

        if total_steps is None:
            total_steps = self.model.scheduler.infer_steps

        for step_index in range(total_steps):
            logger.info(f"==> Segment {segment_idx}, Step {step_index}/{total_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("🚀 infer_main"):
                self.model.infer(inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()
                if self.config.model_cls == "wan2.2_audio":
                    # Blend: where the mask is 0, take the conditioning latents; where 1, keep the denoised ones.
                    prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
                    prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
                    self.model.scheduler.latents = (1.0 - prev_mask[0]) * prev_latents + prev_mask[0] * self.model.scheduler.latents

            if self.progress_callback:
                segment_progress = (segment_idx * total_steps + step_index + 1) / (self.total_segments * total_steps)
                self.progress_callback(int(segment_progress * 100), 100)

        latents = self.model.scheduler.latents
        generator = self.model.scheduler.generator
        gen_video = self.vae_decoder.decode(latents, generator=generator, config=self.config)
        gen_video = torch.clamp(gen_video, -1, 1).to(torch.float)

        return gen_video
wangshankun's avatar
wangshankun committed
377
378


379
@RUNNER_REGISTER("wan2.1_audio")
380
class WanAudioRunner(WanRunner):  # type:ignore
381
382
383
384
385
386
    def __init__(self, config):
        """Initialize the base runner and placeholders for the audio-specific components."""
        super().__init__(config)
        # Filled in later: the adapter pipe on first use (load_audio_adapter_lazy), the
        # processor in initialize(), the generator per run, the feature extractor in
        # load_transformer().
        self._audio_adapter_pipe = None
        self._audio_processor = None
        self._video_generator = None
        self._audio_preprocess = None

        # Without a sequence-parallel group everything runs on a single rank.
        if self.seq_p_group is None:
            self.sp_size = 1
        else:
            self.sp_size = dist.get_world_size(self.seq_p_group)

393
    def initialize(self):
        """Create the audio processor from config and install a fresh scheduler.

        Reads ``audio_sr`` (default 16000) and ``target_fps`` (default 16) from the config.
        """
        audio_sr = self.config.get("audio_sr", 16000)
        target_fps = self.config.get("target_fps", 16)
        self._audio_processor = AudioProcessor(audio_sr, target_fps)

        self.init_scheduler()
wangshankun's avatar
wangshankun committed
403

wangshankun's avatar
wangshankun committed
404
    def init_scheduler(self):
        """Attach a new ``ConsistencyModelScheduler`` built from the config to the model."""
        scheduler = ConsistencyModelScheduler(self.config)
        self.model.set_scheduler(scheduler)

409
410
411
412
    def load_audio_adapter_lazy(self):
        """Build the audio adapter pipeline on the first call and cache it.

        Loads ``audio_adapter.safetensors`` and the ``audio_encoder`` directory from
        ``config["model_path"]``. The pipe is created on CPU when ``cpu_offload`` is set,
        otherwise on CUDA.

        Returns:
            The cached ``AudioAdapterPipe``.
        """
        if self._audio_adapter_pipe is not None:
            return self._audio_adapter_pipe

        model_path = self.config["model_path"]
        audio_adapter_path = os.path.join(model_path, "audio_adapter.safetensors")
        audio_encoder_repo = os.path.join(model_path, "audio_encoder")

        audio_adapter = AudioAdapter.from_transformer(
            self.model,
            audio_feature_dim=1024,
            interval=1,
            time_freq_dim=256,
            projection_transformer_layers=4,
        )

        cpu_offload = self.config.get("cpu_offload", False)
        device = torch.device("cpu" if cpu_offload else "cuda")

        # None when sequence parallelism is not in use.
        seq_p_group = self.model.transformer_infer.seq_p_group

        audio_adapter = rank0_load_state_dict_from_path(audio_adapter, audio_adapter_path, strict=False)

        self._audio_adapter_pipe = AudioAdapterPipe(
            audio_adapter,
            audio_encoder_repo=audio_encoder_repo,
            dtype=GET_DTYPE(),
            device=device,
            weight=1.0,
            cpu_offload=cpu_offload,
            seq_p_group=seq_p_group,
        )
        return self._audio_adapter_pipe

    def prepare_inputs(self):
        """Run the image and text encoders and bundle their outputs with the audio adapter.

        Returns:
            Dict with ``text_encoder_output``, ``image_encoder_output`` (``None`` when
            ``config.image_path`` is not a file) and ``audio_adapter_pipe``.
        """
        image_encoder_output = None

        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                vae_encoder_out, clip_encoder_out = self.run_image_encoder(self.config, self.vae_encoder)
                image_encoder_output = {
                    "clip_encoder_out": clip_encoder_out,
                    "vae_encoder_out": vae_encoder_out,
                }

        with ProfilingContext("Run Text Encoder"):
            # The text encoder also receives the reference image; close the file once decoded.
            with Image.open(self.config["image_path"]) as raw_img:
                img = raw_img.convert("RGB")
            text_encoder_output = self.run_text_encoder(self.config["prompt"], img)

        self.set_target_shape()

        return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output, "audio_adapter_pipe": self.load_audio_adapter_lazy()}

    def run_pipeline(self, save_video=True):
        """Generate an audio-driven video for ``config.audio_path``, segment by segment.

        Steps:
        1. Encode the prompt and image.
        2. Split the audio into overlapping segments.
        3. Generate each segment conditioned on the tail of the previous one.
        4. Stitch frames and audio, optionally interpolate frames, and save.

        Segment ``idx`` is seeded with ``config.seed + idx``. ``config.seed`` is restored
        afterwards, so repeated runs start from the same seed. Previously the seed was
        accumulated in place (s, s+1, s+3, ...) and left modified after the run.

        Args:
            save_video: When True and ``save_video_path`` is configured, write an mp4 with
                audio (only on rank 0 when a ``device_mesh`` is configured).

        Returns:
            ``(comfyui_images, comfyui_audio)``.
        """
        prev_frame_length = 5  # frames shared between consecutive segments
        base_seed = self.config.seed

        try:
            self.initialize()

            assert self._audio_processor is not None
            assert self._audio_preprocess is not None

            with memory_efficient_inference():
                if self.config.get("use_prompt_enhancer", False):
                    self.config["prompt_enhanced"] = self.post_prompt_enhancer()

                self.inputs = self.prepare_inputs()
                # Re-initialize scheduler after image encoding sets correct dimensions
                self.init_scheduler()
                self.model.scheduler.prepare(self.inputs["image_encoder_output"])

            # Created only after the scheduler has been re-initialized.
            self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

            # Process audio
            audio_array = self._audio_processor.load_audio(self.config["audio_path"])
            video_duration = self.config.get("video_duration", 5)
            target_fps = self.config.get("target_fps", 16)
            max_num_frames = self.config.get("target_video_length", 81)
            audio_sr = self._audio_processor.audio_sr

            audio_len = int(audio_array.shape[0] / audio_sr * target_fps)
            expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)

            audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames, prev_frame_length)
            self._video_generator.total_segments = len(audio_segments)

            gen_video_list = []
            cut_audio_list = []
            prev_video = None
            # Later segments drop the overlapping video frames. Their audio skips one frame
            # more, because each segment's audio already extends one frame past its last
            # video frame (see AudioProcessor.get_audio_range).
            overlap_audio_samples = int((prev_frame_length + 1) * audio_sr / target_fps)

            for idx, segment in enumerate(audio_segments):
                self.config.seed = base_seed + idx
                torch.manual_seed(self.config.seed)
                logger.info(f"Processing segment {idx + 1}/{len(audio_segments)}, seed: {self.config.seed}")

                audio_features = self._audio_preprocess(segment.audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)

                with memory_efficient_inference():
                    gen_video = self._video_generator.generate_segment(
                        self.inputs.copy(),  # shallow copy: generate_segment adds keys to it
                        audio_features,
                        prev_video=prev_video,
                        prev_frame_length=prev_frame_length,
                        segment_idx=idx,
                    )

                # Keep only the frames/samples this segment newly contributes.
                start_frame = 0 if idx == 0 else prev_frame_length
                start_audio_frame = 0 if idx == 0 else overlap_audio_samples

                if segment.is_last and segment.useful_length:
                    end_frame = segment.end_frame - segment.start_frame
                    gen_video_list.append(gen_video[:, :, start_frame:end_frame].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
                elif segment.useful_length and expected_frames < max_num_frames:
                    gen_video_list.append(gen_video[:, :, start_frame:expected_frames].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame : segment.useful_length])
                else:
                    gen_video_list.append(gen_video[:, :, start_frame:].cpu())
                    cut_audio_list.append(segment.audio_array[start_audio_frame:])

                prev_video = gen_video
                torch.cuda.empty_cache()

            # Merge results
            with memory_efficient_inference():
                gen_lvideo = torch.cat(gen_video_list, dim=2).float()
                merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
                comfyui_images = vae_to_comfyui_image(gen_lvideo)

            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                interpolation_target_fps = self.config["video_frame_interpolation"]["target_fps"]
                logger.info(f"Interpolating frames from {target_fps} to {interpolation_target_fps}")
                comfyui_images = self.vfi_model.interpolate_frames(
                    comfyui_images,
                    source_fps=target_fps,
                    target_fps=interpolation_target_fps,
                )
                target_fps = interpolation_target_fps

            # Convert audio to ComfyUI format
            audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
            comfyui_audio = {"waveform": audio_waveform, "sample_rate": audio_sr}

            is_main_rank = self.config.get("device_mesh") is None or dist.get_rank() == 0
            if is_main_rank and save_video and self.config.get("save_video_path", None):
                self._save_video_with_audio(comfyui_images, merge_audio, target_fps)

            self.end_run()

            return comfyui_images, comfyui_audio

        finally:
            self.config.seed = base_seed
            self._video_generator = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
583
584
585
586
587
588
589
590
591
592
593
594
595

    def _save_video_with_audio(self, images, audio_array, fps):
        """Mux ``images`` and ``audio_array`` into ``config.save_video_path`` using ffmpeg.

        Video and audio are written to temporary .mp4/.wav files and combined by ffmpeg.
        The temporary files are always removed. A non-zero ffmpeg exit status is logged as
        an error rather than reported as success.
        """
        import shutil
        import tempfile

        video_path = audio_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as video_tmp:
                video_path = video_tmp.name
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as audio_tmp:
                audio_path = audio_tmp.name

            save_to_video(images, video_path, fps)
            ta.save(audio_path, torch.tensor(audio_array[None]), sample_rate=self._audio_processor.audio_sr)  # type: ignore

            output_path = self.config.get("save_video_path")
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            # Prefer ffmpeg from PATH; fall back to the previously hard-coded location.
            ffmpeg_bin = shutil.which("ffmpeg") or "/usr/bin/ffmpeg"
            result = subprocess.run([ffmpeg_bin, "-y", "-i", video_path, "-i", audio_path, output_path])
            if result.returncode != 0:
                logger.error(f"ffmpeg exited with code {result.returncode}; failed to save video to: {output_path}")
            else:
                logger.info(f"Saved video with audio to: {output_path}")

        finally:
            # Clean up temp files
            for tmp_path in (video_path, audio_path):
                if tmp_path and os.path.exists(tmp_path):
                    os.remove(tmp_path)
wangshankun's avatar
wangshankun committed
613
614

    def load_transformer(self):
        """Build the audio transformer, apply configured LoRAs, and load the audio feature extractor.

        Side effect: sets ``self._audio_preprocess``, which ``run_pipeline`` requires.

        Returns:
            The (optionally LoRA-patched) ``WanAudioModel``.
        """
        base_model = WanAudioModel(self.config.model_path, self.config, self.init_device, self.seq_p_group)

        lora_configs = self.config.get("lora_configs")
        if lora_configs:
            # LoRA needs either an unquantized DiT or weight_auto_quant enabled.
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)
            lora_wrapper = WanLoraWrapper(base_model)
            for lora_config in lora_configs:
                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")

        # XXX: trick -- the audio feature extractor is loaded here, alongside the transformer.
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")

        return base_model

    def run_image_encoder(self, config, vae_model):
        """Encode the reference image with the CLIP image encoder and the VAE.

        Also derives the generation resolution from the reference image and writes
        it into ``config``: ``lat_h``/``lat_w`` (latent size) and ``tgt_h``/``tgt_w``
        (pixel size, ``lat_* * vae_stride``).

        Args:
            config: Run configuration; reads ``image_path``, ``adaptive_resize``,
                ``target_height``/``target_width``, ``vae_stride`` and ``patch_size``.
            vae_model: VAE whose ``encode(frames, config)`` produces the conditioning
                latents.

        Returns:
            ``(vae_encoder_out, clip_encoder_out)``; ``clip_encoder_out`` is ``None``
            when ``use_image_encoder`` is disabled.
        """
        # Force 3-channel RGB: grayscale ("L") images would otherwise produce a 2-D
        # array that breaks the "H W C" rearrange below, and palette/CMYK images would
        # carry non-RGB values. RGBA inputs simply lose alpha, so no further channel
        # slicing is needed. The context manager closes the underlying file.
        with Image.open(config.image_path) as img:
            ref_img = np.array(img.convert("RGB")).astype(np.float32)

        # Scale [0, 255] -> [-1, 1] and lay out as a single NCHW image on the GPU.
        ref_img = (ref_img - 127.5) / 127.5
        ref_img = torch.from_numpy(ref_img).cuda()
        ref_img = rearrange(ref_img, "H W C -> 1 C H W")

        adaptive = config.get("adaptive_resize", False)

        if adaptive:
            # Use adaptive_resize to modify aspect ratio
            ref_img, h, w = adaptive_resize(ref_img)
            # Grid size in patches: pixels / VAE spatial stride / patch size.
            patched_h = h // self.config.vae_stride[1] // self.config.patch_size[1]
            patched_w = w // self.config.vae_stride[2] // self.config.patch_size[2]
        else:
            # Keep the reference aspect ratio while targeting an area of
            # target_height * target_width pixels.
            h, w = ref_img.shape[2:]
            aspect_ratio = h / w
            max_area = config.target_height * config.target_width
            patched_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1])
            patched_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2])

        # NOTE(review): helper defined elsewhere in this file; presumably snaps the
        # patch grid to sizes valid for sequence parallelism (called with 1) -- confirm.
        patched_h, patched_w = get_optimal_patched_size_with_sp(patched_h, patched_w, 1)

        config.lat_h = patched_h * self.config.patch_size[1]
        config.lat_w = patched_w * self.config.patch_size[2]

        config.tgt_h = config.lat_h * self.config.vae_stride[1]
        config.tgt_w = config.lat_w * self.config.vae_stride[2]

        logger.info(f"[wan_audio] adaptive_resize: {adaptive}, tgt_h: {config.tgt_h}, tgt_w: {config.tgt_w}, lat_h: {config.lat_h}, lat_w: {config.lat_w}")

        cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")

        # clip encoder (skipped when use_image_encoder is False)
        clip_encoder_out = self.image_encoder.visual([cond_frms]).squeeze(0).to(GET_DTYPE()) if self.config.get("use_image_encoder", True) else None

        # vae encode: insert a singleton time axis -> (1, C, 1, H, W)
        cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
        vae_encoder_out = vae_model.encode(cond_frms.to(torch.float), config)

        if self.config.model_cls == "wan2.2_audio":
            vae_encoder_out = vae_encoder_out.unsqueeze(0).to(GET_DTYPE())
        elif isinstance(vae_encoder_out, list):
            vae_encoder_out = torch.stack(vae_encoder_out, dim=0).to(GET_DTYPE())

        return vae_encoder_out, clip_encoder_out
wangshankun's avatar
wangshankun committed
684
685

    def set_target_shape(self):
        """Compute the latent shape to sample and store it on ``self.config``.

        Sets ``self.config.target_shape`` to ``(channels, latent_frames, lat_h, lat_w)``
        where ``latent_frames = (target_video_length - 1) // vae_stride[0] + 1``.
        ``channels`` is 16, or ``config.num_channels_latents`` for ``wan2.2_audio``.

        Returns:
            dict with ``lat_h``, ``lat_w`` and ``target_shape``.

        Raises:
            NotImplementedError: if ``config.task`` is not ``"i2v"``.
        """
        task = self.config.task
        if task != "i2v":
            # Explicit raise: an ``assert False`` is stripped under ``python -O`` and
            # would fall through to a stale or missing ``config.target_shape``.
            raise NotImplementedError(f"{task} task is not supported in WanAudioRunner")

        num_channels_latents = self.config.num_channels_latents if self.config.model_cls == "wan2.2_audio" else 16

        self.config.target_shape = (
            num_channels_latents,
            (self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
            self.config.lat_h,
            self.config.lat_w,
        )
        return {
            "lat_h": self.config.lat_h,
            "lat_w": self.config.lat_w,
            "target_shape": self.config.target_shape,
        }
wangshankun's avatar
wangshankun committed
707

708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
    def run_step(self):
        """Run the pipeline on the first audio segment only.

        Initializes the runner, prepares inputs and the scheduler, segments the
        configured audio, then generates segment 0 (``total_steps=1``) and calls
        ``end_run``.

        Raises:
            ValueError: if segmenting the audio yields no segments to generate from.
        """
        self.initialize()

        # _audio_preprocess is created in load_transformer; _audio_processor is
        # presumably set up during initialize() -- both must exist before use.
        assert self._audio_processor is not None
        assert self._audio_preprocess is not None

        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

        with memory_efficient_inference():
            # .get(): a config without the flag simply skips prompt enhancement,
            # consistent with how the other optional keys below are read.
            if self.config.get("use_prompt_enhancer", False):
                self.config["prompt_enhanced"] = self.post_prompt_enhancer()

            self.inputs = self.prepare_inputs()
            # Re-initialize scheduler after image encoding sets correct dimensions
            self.init_scheduler()
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])

        # Re-create video generator with updated model/scheduler
        self._video_generator = VideoGenerator(self.model, self.vae_encoder, self.vae_decoder, self.config, self.progress_callback)

        # Process audio
        audio_array = self._audio_processor.load_audio(self.config["audio_path"])
        video_duration = self.config.get("video_duration", 5)
        target_fps = self.config.get("target_fps", 16)
        max_num_frames = self.config.get("target_video_length", 81)

        # Frames covered by the audio at target_fps; the clip is capped by both
        # the requested duration and the audio length.
        audio_len = int(audio_array.shape[0] / self._audio_processor.audio_sr * target_fps)
        expected_frames = min(max(1, int(video_duration * target_fps)), audio_len)

        # Segment audio
        audio_segments = self._audio_processor.segment_audio(audio_array, expected_frames, max_num_frames)
        if not audio_segments:
            raise ValueError(f"No audio segments produced from {self.config['audio_path']!r} ({expected_frames} expected frames at {target_fps} fps); is the audio too short?")

        self._video_generator.total_segments = len(audio_segments)

        prev_video = None

        torch.manual_seed(self.config.seed)
        # Process audio features for the first segment only
        audio_features = self._audio_preprocess(audio_segments[0].audio_array, sampling_rate=self._audio_processor.audio_sr, return_tensors="pt").input_values.squeeze(0).to(self.model.device)

        # Generate video segment
        with memory_efficient_inference():
            self._video_generator.generate_segment(
                self.inputs.copy(),  # Copy to avoid modifying original
                audio_features,
                prev_video=prev_video,
                prev_frame_length=5,
                segment_idx=0,
                total_steps=1,
            )
            # Final cleanup
            self.end_run()

wangshankun's avatar
wangshankun committed
764

wangshankun's avatar
wangshankun committed
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
@RUNNER_REGISTER("wan2.2_audio")
class Wan22AudioRunner(WanAudioRunner):
    """Audio runner for ``wan2.2_audio``: same pipeline, but with the Wan2.2 VAE."""

    def _wan22_vae_config(self):
        """Build the keyword arguments shared by the Wan2.2 VAE encoder and decoder.

        ``vae_cpu_offload`` takes precedence; when that key is absent the generic
        ``cpu_offload`` setting is used. Offloaded VAEs are placed on the CPU,
        otherwise on CUDA.
        """
        vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
        return {
            "vae_pth": find_torch_model_path(self.config, "vae_pth", "Wan2.2_VAE.pth"),
            "device": torch.device("cpu" if vae_offload else "cuda"),
            "cpu_offload": vae_offload,
            "offload_cache": self.config.get("vae_offload_cache", False),
        }

    def load_vae_decoder(self):
        """Load the Wan2.2 VAE instance used as the decoder."""
        return Wan2_2_VAE(**self._wan22_vae_config())

    def load_vae_encoder(self):
        """Load the Wan2.2 VAE instance used as the encoder.

        Returns ``None`` unless the task is ``i2v``; in that case the checkpoint
        path is not even resolved.
        """
        if self.config.task != "i2v":
            return None
        return Wan2_2_VAE(**self._wan22_vae_config())

    def load_vae(self):
        """Return ``(encoder, decoder)``; each is loaded as a separate VAE instance."""
        vae_encoder = self.load_vae_encoder()
        vae_decoder = self.load_vae_decoder()
        return vae_encoder, vae_decoder

809

wangshankun's avatar
wangshankun committed
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
@RUNNER_REGISTER("wan2.2_moe_audio")
class Wan22MoeAudioRunner(WanAudioRunner):
    """Audio runner for the Wan2.2 MoE model with separate high-/low-noise experts."""

    def load_transformer(self):
        """Load both noise experts, apply per-expert LoRAs, and wrap them in a MultiModelStruct.

        Each ``lora_configs`` entry targets one expert via its ``name`` field
        (``"high_noise_model"`` or ``"low_noise_model"``) and supplies a ``path`` and
        an optional ``strength`` (default 1.0). Entries with any other name are
        skipped with a warning.

        Side effect: loads the audio feature extractor into ``self._audio_preprocess``.

        Returns:
            ``MultiModelStruct([high_noise_model, low_noise_model], config, config.boundary)``.
        """
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
        experts = {
            name: Wan22MoeAudioModel(
                os.path.join(self.config.model_path, name),
                self.config,
                self.init_device,
            )
            for name in ("high_noise_model", "low_noise_model")
        }

        lora_configs = self.config.get("lora_configs")
        if lora_configs:
            assert not self.config.get("dit_quantized", False) or self.config.mm_config.get("weight_auto_quant", False)

            # One wrapper per expert, reused across that expert's LoRAs -- the same
            # pattern the dense WanAudioRunner uses for its single model.
            wrappers = {}
            for lora_config in lora_configs:
                # .get() works for plain dicts as well as attribute-style configs.
                target = lora_config.get("name")
                if target not in experts:
                    logger.warning(f"Skipping LoRA {lora_config.get('path')!r}: unknown target model {target!r}")
                    continue

                lora_path = lora_config["path"]
                strength = lora_config.get("strength", 1.0)
                if target not in wrappers:
                    wrappers[target] = WanLoraWrapper(experts[target])
                lora_wrapper = wrappers[target]
                lora_name = lora_wrapper.load_lora(lora_path)
                lora_wrapper.apply_lora(lora_name, strength)
                logger.info(f"{target} Loaded LoRA: {lora_name} with strength: {strength}")

        # XXX: trick -- the audio feature extractor is loaded alongside the transformer.
        self._audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")

        return MultiModelStruct([experts["high_noise_model"], experts["low_noise_model"]], self.config, self.config.boundary)