# gradio_web_demo.py
import argparse
import os
import tempfile

import gradio as gr
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video

from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.models.mochi_hf.pipeline_mochi import MochiPipeline


def init_args():
    """Define and parse the command-line options for the demo.

    Returns:
        argparse.Namespace with model/scheduler paths, sampling settings
        (frames, resolution, steps, guidance, seed) and offload flags.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--prompts", nargs="+", default=[])
    p.add_argument("--num_frames", type=int, default=25)
    p.add_argument("--height", type=int, default=480)
    p.add_argument("--width", type=int, default=848)
    p.add_argument("--num_inference_steps", type=int, default=8)
    p.add_argument("--guidance_scale", type=float, default=4.5)
    p.add_argument("--model_path", type=str, default="data/mochi")
    p.add_argument("--seed", type=int, default=12345)
    p.add_argument("--transformer_path", type=str, default=None)
    p.add_argument("--scheduler_type", type=str, default="pcm_linear_quadratic")
    p.add_argument("--lora_checkpoint_dir", type=str, default=None)
    p.add_argument("--shift", type=float, default=8.0)
    p.add_argument("--num_euler_timesteps", type=int, default=50)
    p.add_argument("--linear_threshold", type=float, default=0.1)
    p.add_argument("--linear_range", type=float, default=0.75)
    p.add_argument("--cpu_offload", action="store_true")
    return p.parse_args()


def load_model(args):
    """Build the Mochi video-generation pipeline described by *args*.

    Args:
        args: parsed CLI namespace (see ``init_args``). Reads
            ``scheduler_type``, ``shift``, ``num_euler_timesteps``,
            ``linear_threshold``, ``linear_range``, ``transformer_path``,
            ``model_path`` and ``cpu_offload``.

    Returns:
        A ``MochiPipeline`` ready for inference.
    """
    if args.scheduler_type == "euler":
        scheduler = FlowMatchEulerDiscreteScheduler()
    else:
        # PCM flow-matching scheduler; a "linear_quadratic" scheduler_type
        # selects the linear-quadratic timestep schedule.
        linear_quadratic = "linear_quadratic" in args.scheduler_type
        scheduler = PCMFMScheduler(
            1000,
            args.shift,
            args.num_euler_timesteps,
            linear_quadratic,
            args.linear_threshold,
            args.linear_range,
        )

    # Load a standalone transformer checkpoint if one was given; otherwise
    # take the transformer subfolder from the main model directory.
    if args.transformer_path:
        transformer = MochiTransformer3DModel.from_pretrained(args.transformer_path)
    else:
        transformer = MochiTransformer3DModel.from_pretrained(args.model_path, subfolder="transformer/")

    pipe = MochiPipeline.from_pretrained(args.model_path, transformer=transformer, scheduler=scheduler)
    pipe.enable_vae_tiling()
    # Fix: honor --cpu_offload. Previously the flag was parsed but ignored —
    # the conditional was commented out and sequential CPU offload was always
    # enabled, silently preventing full-GPU execution.
    if args.cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.to("cuda")
    return pipe


def generate_video(
    prompt,
    negative_prompt,
    use_negative_prompt,
    seed,
    guidance_scale,
    num_frames,
    height,
    width,
    num_inference_steps,
    randomize_seed=False,
):
    """Run the module-level ``pipe`` on *prompt* and export an MP4.

    Returns:
        Tuple ``(output_path, seed)`` — the path of the rendered video and
        the seed that was actually used (fresh one if *randomize_seed*).
    """
    # Draw a fresh seed per click when requested, so repeats differ.
    if randomize_seed:
        seed = torch.randint(0, 1000000, (1, )).item()

    rng = torch.Generator(device="cuda").manual_seed(seed)

    # The negative prompt is only honored while its checkbox is ticked.
    effective_negative = negative_prompt if use_negative_prompt else None

    # bfloat16 autocast for the diffusion forward pass.
    with torch.autocast("cuda", dtype=torch.bfloat16):
        frames = pipe(
            prompt=[prompt],
            negative_prompt=effective_negative,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=rng,
        ).frames[0]

    # Write into a fresh temp dir so concurrent requests never collide.
    video_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
    export_to_video(frames, video_path, fps=30)
    return video_path, seed


# Sample prompts surfaced in the UI via gr.Examples below.
examples = [
    "A hand enters the frame, pulling a sheet of plastic wrap over three balls of dough placed on a wooden surface. The plastic wrap is stretched to cover the dough more securely. The hand adjusts the wrap, ensuring that it is tight and smooth over the dough. The scene focuses on the hand’s movements as it secures the edges of the plastic wrap. No new objects appear, and the camera remains stationary, focusing on the action of covering the dough.",
    "A vintage train snakes through the mountains, its plume of white steam rising dramatically against the jagged peaks. The cars glint in the late afternoon sun, their deep crimson and gold accents lending a touch of elegance. The tracks carve a precarious path along the cliffside, revealing glimpses of a roaring river far below. Inside, passengers peer out the large windows, their faces lit with awe as the landscape unfolds.",
    "A crowded rooftop bar buzzes with energy, the city skyline twinkling like a field of stars in the background. Strings of fairy lights hang above, casting a warm, golden glow over the scene. Groups of people gather around high tables, their laughter blending with the soft rhythm of live jazz. The aroma of freshly mixed cocktails and charred appetizers wafts through the air, mingling with the cool night breeze.",
]

# Module-level setup: parse CLI args and load the pipeline once at import
# time; `pipe` is the global read by generate_video above.
args = init_args()
pipe = load_model(args)
print("load model successfully")
with gr.Blocks() as demo:
    gr.Markdown("# Fastvideo Mochi Video Generation Demo")

    # Main row: prompt textbox + Run button, with the video result below.
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Video(label="Result", show_label=False)

    # Collapsed panel of sampling controls, seeded from the CLI defaults.
    with gr.Accordion("Advanced options", open=False):
        with gr.Group():
            with gr.Row():
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=1024,
                    step=32,
                    value=args.height,
                )
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=args.width)

            with gr.Row():
                num_frames = gr.Slider(
                    label="Number of Frames",
                    minimum=21,
                    maximum=163,
                    value=args.num_frames,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=12,
                    value=args.guidance_scale,
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=4,
                    maximum=100,
                    value=args.num_inference_steps,
                )

            with gr.Row():
                use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
            # Hidden until the checkbox above is ticked (see .change handler).
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(label="Seed", minimum=0, maximum=1000000, step=1, value=args.seed)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            seed_output = gr.Number(label="Used Seed")

    gr.Examples(examples=examples, inputs=prompt)

    # Toggle the negative-prompt textbox's visibility with its checkbox.
    use_negative_prompt.change(
        fn=lambda x: gr.update(visible=x),
        inputs=use_negative_prompt,
        outputs=negative_prompt,
    )

    # Wire the Run button: the inputs list maps positionally onto
    # generate_video's parameters; outputs are the video and the used seed.
    run_button.click(
        fn=generate_video,
        inputs=[
            prompt,
            negative_prompt,
            use_negative_prompt,
            seed,
            guidance_scale,
            num_frames,
            height,
            width,
            num_inference_steps,
            randomize_seed,
        ],
        outputs=[result, seed_output],
    )

if __name__ == "__main__":
    # queue() bounds concurrent requests; serve on all interfaces, port 7860.
    demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)