inference_engine.py 7.65 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# type: ignore

# SPDX-License-Identifier: Apache-2.0
"""
Inference module for diffusion models.

This module provides classes and functions for running inference with diffusion models.
"""

import time
from typing import Any, Dict

import torch

from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines import (ComposedPipelineBase, ForwardBatch,
                                    build_pipeline)
# TODO(will): remove, check if this is hunyuan specific
from fastvideo.v1.utils import align_to

logger = init_logger(__name__)


class InferenceEngine:
    """
    Engine for running inference with diffusion models.
    """

    def __init__(
        self,
        pipeline: ComposedPipelineBase,
        fastvideo_args: FastVideoArgs,
    ):
        """
        Initialize the inference engine.
        
        Args:
            pipeline: The pipeline to use for inference.
            fastvideo_args: The inference arguments.
            default_negative_prompt: The default negative prompt to use.
        """
        self.pipeline = pipeline
        self.fastvideo_args = fastvideo_args

    @classmethod
    def create_engine(
        cls,
        fastvideo_args: FastVideoArgs,
    ) -> "InferenceEngine":
        """
        Create an inference engine with the specified arguments.
        
        Args:
            fastvideo_args: The inference arguments.
            model_loader_cls: The model loader class to use. If None, it will be
                determined from the model type.
            pipeline_type: The type of pipeline to create. If None, it will be
                determined from the model type.
                
        Returns:
            The created inference engine.
            
        Raises:
            ValueError: If the model type is not recognized or if the pipeline type
                is not recognized.
        """

        logger.info("Building pipeline...")

        # TODO(will): I don't really like this api.
        # it should be something closer to pipeline_cls.from_pretrained(...)
        # this way for training we can just do pipeline_cls.from_pretrained(
        # checkpoint_path) and have it handle everything.
        # TODO(Peiyuan): Then maybe we should only pass in model path and device, not the entire inference args?
        pipeline = build_pipeline(fastvideo_args)
        logger.info("Pipeline Ready")

        # Create the inference engine
        return cls(pipeline, fastvideo_args)

    def run(
        self,
        prompt: str,
        fastvideo_args: FastVideoArgs,
    ) -> Dict[str, Any]:
        """
        Run inference with the pipeline.
        
        Args:
            prompt: The prompt to use for generation.
            negative_prompt: The negative prompt to use. If None, the default will be used.
            seed: The random seed to use. If None, a random seed will be used.
            **kwargs: Additional arguments to pass to the pipeline.
            
        Returns:
            A dictionary containing the generated videos and metadata.
        """
        out_dict: Dict[str, Any] = dict()

        num_videos_per_prompt = fastvideo_args.num_videos
        seed = fastvideo_args.seed
        height = fastvideo_args.height
        width = fastvideo_args.width
        video_length = fastvideo_args.num_frames
        negative_prompt = fastvideo_args.neg_prompt
        infer_steps = fastvideo_args.num_inference_steps
        guidance_scale = fastvideo_args.guidance_scale
        flow_shift = fastvideo_args.flow_shift
        embedded_guidance_scale = fastvideo_args.embedded_cfg_scale
        image_path = fastvideo_args.image_path

        # ========================================================================
        # Arguments: target_width, target_height, target_video_length
        # ========================================================================
        if width <= 0 or height <= 0 or video_length <= 0:
            raise ValueError(
                f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
            )
        if (video_length - 1) % 4 != 0:
            raise ValueError(
                f"`video_length-1` must be a multiple of 4, got {video_length}")

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_video_length = video_length

        out_dict["size"] = (target_height, target_width, target_video_length)

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(prompt, str):
            raise TypeError(
                f"`prompt` must be a string, but got {type(prompt)}")
        prompt = prompt.strip()

        # negative prompt
        if negative_prompt is not None:
            negative_prompt = negative_prompt.strip()

        # TODO(PY): move to hunyuan stage
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        n_tokens = latents_size[0] * latents_size[1] * latents_size[2]

        # ========================================================================
        # Print infer args
        # ========================================================================
        debug_str = f"""
                        height: {target_height}
                         width: {target_width}
                  video_length: {target_video_length}
                        prompt: {prompt}
                    neg_prompt: {negative_prompt}
                          seed: {seed}
                   infer_steps: {infer_steps}
         num_videos_per_prompt: {num_videos_per_prompt}
                guidance_scale: {guidance_scale}
                      n_tokens: {n_tokens}
                    flow_shift: {flow_shift}
       embedded_guidance_scale: {embedded_guidance_scale}"""
        logger.info(debug_str)
        # return
        # sp_group = get_sp_group()
        # local_rank = sp_group.rank
        device = torch.device(fastvideo_args.device_str)
        batch = ForwardBatch(
            image_path=image_path,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            height=fastvideo_args.height,
            width=fastvideo_args.width,
            num_frames=fastvideo_args.num_frames,
            num_inference_steps=fastvideo_args.num_inference_steps,
            guidance_scale=fastvideo_args.guidance_scale,
            # generator=generator,
            eta=0.0,
            n_tokens=n_tokens,
            data_type="video" if fastvideo_args.num_frames > 1 else "image",
            device=device,
            extra={},  # Any additional parameters
        )

        print('===============================================')
        print(batch)
        print('===============================================')
        print('===============================================')
        print(fastvideo_args)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        start_time = time.time()
        samples = self.pipeline.forward(
            batch=batch,
            fastvideo_args=fastvideo_args,
        ).output
        # TODO(will): fix and move to hunyuan stage
        # out_dict["seeds"] = batch.seeds
        out_dict["samples"] = samples
        out_dict["prompts"] = prompt

        gen_time = time.time() - start_time
        logger.info("Success, time: %s", gen_time)

        return out_dict