infer_dreambooth.py 2.34 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline
from modelscope import snapshot_download


def parse_args(argv=None):
    """Parse the command-line options for DreamBooth inference.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model_path``, ``prompt``,
        ``image_save_path``, ``torch_dtype``, ``num_inference_steps`` and
        ``guidance_scale``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value is invalid.
    """
    parser = argparse.ArgumentParser(description='Simple example of a dreambooth inference.')
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help='Path to trained model.',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        required=True,
        help='The prompt to guide image generation.',
    )
    parser.add_argument(
        '--image_save_path',
        type=str,
        required=True,
        help='The path to save generated image',
    )
    parser.add_argument(
        '--torch_dtype',
        type=str,
        default=None,
        choices=['no', 'fp16', 'bf16'],
        # The caller maps anything other than fp16/bf16 to float32, so the
        # help text states that instead of referring to the training script.
        help=('Precision used to load the model weights: fp16 or bf16 (bfloat16).'
              ' Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU.'
              ' `no`, or omitting this option, loads the weights in float32.'),
    )
    parser.add_argument(
        '--num_inference_steps',
        type=int,
        default=50,
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the help string.
        help=('The number of denoising steps. More denoising steps usually lead to a higher quality image at the'
              ' expense of slower inference.'),
    )
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help=('A higher guidance scale value encourages the model to generate images closely linked to the text'
              ' `prompt` at the expense of lower image quality. Guidance scale is enabled when'
              ' `guidance_scale > 1`.'),
    )

    return parser.parse_args(argv)


def main():
    """Generate one image from a trained pipeline and save it to disk.

    Reads the options from the command line via ``parse_args()``, loads a
    ``StableDiffusionPipeline`` from ``--model_path`` in the requested
    precision, runs it on ``--prompt`` and writes the first returned image
    to ``--image_save_path``.
    """
    args = parse_args()

    # Create the output directory before the expensive model load and
    # denoising loop, so a missing directory cannot make the final save fail
    # after all the work has been done.
    save_dir = os.path.dirname(args.image_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # 'no' and an omitted flag (None) both fall through to float32.
    dtype_by_flag = {'fp16': torch.float16, 'bf16': torch.bfloat16}
    torch_dtype = dtype_by_flag.get(args.torch_dtype, torch.float32)

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    # NOTE(review): half-precision inference on CPU may be slow or
    # unsupported depending on the installed PyTorch build - confirm.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pipe = StableDiffusionPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to(device)

    result = pipe(
        args.prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
    )
    # Only the first image of the pipeline output is kept.
    image = result.images[0]

    image.save(args.image_save_path)