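"""Streaming shot inference demo.

Alternates between an s2v clip generator (strong identity consistency, weaker
motion response) and an f2v clip generator (weaker consistency, stronger
motion response), stitching the clips together with an overlapping tail buffer.

Example invocation (paths are illustrative):

    python stream_infer.py --config_json configs/shot_stream.json \
        --image_path ref.png --audio_path speech.wav \
        --save_result_path result.mp4
"""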
import argparse
import os

import numpy as np
import torch
import torch.distributed as dist
import torchaudio as ta
from loguru import logger

from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs
from lightx2v.shot_runner.utils import SlidingWindowReader, save_audio, save_to_video
from lightx2v.utils.profiler import ProfilingContext4DebugL1
from lightx2v.utils.utils import seed_all


class ShotStreamPipeline(ShotPipeline):  # type:ignore
    def __init__(self, config):
        super().__init__(config)

    @torch.no_grad()
    def generate(self):
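        """Generate one long video by streaming clips from alternating pipes.

        Each clip is conditioned on the tail frames of the previous clip so
        consecutive clips stay temporally coherent; overlapping frames are
        trimmed before the clips are concatenated.
        """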
        s2v = self.clip_generators["s2v_clip"]  # s2v: strong identity consistency, weaker motion response
        f2v = self.clip_generators["f2v_clip"]  # f2v: weaker consistency, stronger motion response
        # Initialize the tail buffer from the longest overlap_len across the pipes
        self.max_tail_len = max(s2v.prev_frame_length, f2v.prev_frame_length)
        self.global_tail_video = None

        gen_video_list = []
        cut_audio_list = []

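        # Load the audio track, downmix to mono, and resample to 16 kHz.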
        audio_array, ori_sr = ta.load(self.shot_cfg.audio_path)
        audio_array = audio_array.mean(0)
        if ori_sr != 16000:
            audio_array = ta.functional.resample(audio_array, ori_sr, 16000)
        audio_reader = SlidingWindowReader(audio_array, frame_len=33)

        # Demo: alternate between the two pipes, one clip at a time
        i = 0
        overlap = 0
        while True:
            audio_clip = audio_reader.next_frame(overlap=overlap)
            if audio_clip is None:
                break

            if i % 2 == 0:
                pipe = s2v
                inputs = self.clip_inputs["s2v_clip"]
            else:
                pipe = f2v
                inputs = self.clip_inputs["f2v_clip"]
                inputs.prompt = "A man speaks to the camera with a slightly furrowed brow and focused gaze. He raises both hands upward in powerful, emphatic gestures. "  # add a motion-heavy prompt for the f2v pipe

            inputs.seed = inputs.seed + i  # use a different random seed for each clip
            inputs.audio_clip = audio_clip
            i = i + 1

            if self.global_tail_video is not None:  # crop the tail to the overlap_len the current pipe needs
                inputs.overlap_frame = self.global_tail_video[:, :, -pipe.prev_frame_length :]

            gen_clip_video, audio_clip, _ = pipe.run_clip_pipeline(inputs)

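            # Drop the frames that overlap with the previous clip so nothing is
            # duplicated in the final video, and cut the matching span of audio.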
            aligned_len = gen_clip_video.shape[2] - overlap
            gen_video_list.append(gen_clip_video[:, :, :aligned_len])
            cut_audio_list.append(audio_clip[: aligned_len * audio_reader.audio_per_frame])

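            # Record how much the next clip will overlap, and keep the longest
            # tail any pipe may need as conditioning for the next iteration.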
            overlap = pipe.prev_frame_length
            self.global_tail_video = gen_clip_video[:, :, -self.max_tail_len :]

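        # Stitch the clips along the temporal dimension and clamp to the valid pixel range.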
        gen_lvideo = torch.cat(gen_video_list, dim=2).float()
        gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
        merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
        out_path = os.path.join("./", "video_merge.mp4")
        audio_file = os.path.join("./", "audio_merge.wav")

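        # Write the merged video and audio, mux them into the final output file,
        # then clean up the intermediate files.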
        save_to_video(gen_lvideo, out_path, 16)
        save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path)
        os.remove(out_path)
        os.remove(audio_file)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
    parser.add_argument("--config_json", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="", help="The input prompt for text-to-video generation")
    parser.add_argument("--negative_prompt", type=str, default="")
    parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
    parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task")
    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
    parser.add_argument("--target_shape", nargs="+", default=[], help="Set return video or image shape")
    args = parser.parse_args()

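    # Seed all RNGs for reproducibility.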
    seed_all(args.seed)

    clip_configs = load_clip_configs(args.config_json)

    shot_cfg = ShotConfig(
        seed=args.seed,
        image_path=args.image_path,
        audio_path=args.audio_path,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        save_result_path=args.save_result_path,
        clip_configs=clip_configs,
        target_shape=args.target_shape,
    )

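    # Build the streaming pipeline and profile the end-to-end run.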
    with ProfilingContext4DebugL1("Total Cost"):
        shot_stream_pipe = ShotStreamPipeline(shot_cfg)
        shot_stream_pipe.generate()

    # Clean up distributed process group
    if dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Distributed process group cleaned up")


if __name__ == "__main__":
    main()