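"""End-to-end demo for vLLM-Omni.

Runs text2img, img2img, img2text, or text2text generation against a merged
model (default: ByteDance-Seed/BAGEL-7B-MoT) and saves any generated images
to the working directory as output_*.png.

Example invocations (illustrative; model path, prompts, and flags may need
adjusting for your environment):

    python end2end.py --modality text2img --prompts "A cute cat" --steps 50
    python end2end.py --modality img2img --image-path input.png --prompts "Edit prompt"
    python end2end.py --modality img2text --image-path input.png --prompts "Describe this image"
"""
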
import argparse
import os
from typing import cast

from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="ByteDance-Seed/BAGEL-7B-MoT",
        help="Path to merged model directory.",
    )
    parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
    parser.add_argument(
        "--txt-prompts",
        type=str,
        default=None,
        help="Path to a .txt file with one prompt per line (preferred).",
    )
    parser.add_argument("--prompt_type", default="text", choices=["text"])

    parser.add_argument(
        "--modality",
        default="text2img",
        choices=["text2img", "img2img", "img2text", "text2text"],
        help="Modality mode to control stage execution.",
    )

    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="Path to input image for img2img.",
    )

    # OmniLLM init args
    parser.add_argument("--enable-stats", action="store_true", default=False)
    parser.add_argument("--init-sleep-seconds", type=int, default=20)
    parser.add_argument("--batch-timeout", type=int, default=5)
    parser.add_argument("--init-timeout", type=int, default=300)
    parser.add_argument("--shm-threshold-bytes", type=int, default=65536)
    parser.add_argument("--worker-backend", type=str, default="process", choices=["process", "ray"])
    parser.add_argument("--ray-address", type=str, default=None)
    parser.add_argument("--stage-configs-path", type=str, default=None)
    parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    model_name = args.model
    prompts: list[OmniPromptType] = []
    try:
        # Preferred: load from txt file (one prompt per line)
        if args.txt_prompts and args.prompt_type == "text":
            with open(args.txt_prompts, encoding="utf-8") as f:
                lines = [ln.strip() for ln in f]
            prompts = [ln for ln in lines if ln]
            print(f"[Info] Loaded {len(prompts)} prompts from {args.txt_prompts}")
        else:
            prompts = args.prompts or []
    except Exception as e:
        print(f"[Error] Failed to load prompts: {e}")
        raise

    if not prompts:
        # Default prompt for text2img test if none provided
        prompts = ["<|im_start|>A cute cat<|im_end|>"]
        print(f"[Info] No prompts provided, using default: {prompts}")
    omni_outputs = []

    from PIL import Image

    if args.modality == "img2img":
        from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion

        print("[Info] Running in img2img mode (Stage 1 only)")
        client = OmniDiffusion(model=model_name)

        if args.image_path:
            if os.path.exists(args.image_path):
                loaded_image = Image.open(args.image_path).convert("RGB")
                prompts = [
                    {
                        "prompt": cast(str, p),
                        "multi_modal_data": {"image": loaded_image},
                    }
                    for p in prompts
                ]
            else:
                print(f"[Warning] Image path {args.image_path} does not exist.")

        result = client.generate(
            prompts,
            OmniDiffusionSamplingParams(
                seed=52,
                need_kv_receive=False,
                num_inference_steps=args.steps,
            ),
        )

        # Ensure result is a list for iteration
        if not isinstance(result, list):
            omni_outputs = [result]
        else:
            omni_outputs = result

    else:
        from vllm_omni.entrypoints.omni import Omni

        omni_kwargs = {}
        if args.stage_configs_path:
            omni_kwargs["stage_configs_path"] = args.stage_configs_path

        omni_kwargs.update(
            {
                "log_stats": args.enable_stats,
                "init_sleep_seconds": args.init_sleep_seconds,
                "batch_timeout": args.batch_timeout,
                "init_timeout": args.init_timeout,
                "shm_threshold_bytes": args.shm_threshold_bytes,
                "worker_backend": args.worker_backend,
                "ray_address": args.ray_address,
            }
        )

        omni = Omni(model=model_name, **omni_kwargs)

        formatted_prompts = []
        for p in prompts:
            if args.modality == "img2text":
                if args.image_path:
                    loaded_image = Image.open(args.image_path).convert("RGB")
                    final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
                    prompt_dict = {
                        "prompt": final_prompt_text,
                        "multi_modal_data": {"image": loaded_image},
                        "modalities": ["text"],
                    }
                    formatted_prompts.append(prompt_dict)
                else:
                    print(f"[Warning] img2text requires --image-path; skipping prompt: {p!r}")
            elif args.modality == "text2text":
                final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
                prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
                formatted_prompts.append(prompt_dict)
            else:
                # text2img
                final_prompt_text = f"<|im_start|>{p}<|im_end|>"
                prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
                formatted_prompts.append(prompt_dict)

        params_list = omni.default_sampling_params_list
        if args.modality == "text2img":
            params_list[0].max_tokens = 1  # type: ignore # The first stage is a SamplingParam (vllm)
            if len(params_list) > 1:
                params_list[1].num_inference_steps = args.steps  # type: ignore # The second stage is an OmniDiffusionSamplingParam

        omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

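    # Save any generated images. Depending on the entrypoint, a request output
    # may expose images directly via `.images`, as a raw `.output` payload, or
    # nested per-stage under `.request_output`; the loop below handles all
    # three shapes.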
    for i, req_output in enumerate(omni_outputs):
        images = getattr(req_output, "images", None)
        if not images and hasattr(req_output, "output"):
            if isinstance(req_output.output, list):
                images = req_output.output
            else:
                images = [req_output.output]

        if images:
            for j, img in enumerate(images):
                save_path = f"output_{i}_{j}.png"
                img.save(save_path)
                print(f"[Info] Saved image to {save_path}")

        if hasattr(req_output, "request_output") and req_output.request_output:
            for stage_out in req_output.request_output:
                if hasattr(stage_out, "images") and stage_out.images:
                    for k, img in enumerate(stage_out.images):
                        save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png"
                        img.save(save_path)
                        print(f"[Info] Saved stage output image to {save_path}")

    print(omni_outputs)


if __name__ == "__main__":
    main()