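"""Example inference CLI for dalle2_laion.

Provides click commands for text-to-image generation (dream), image
variations (variation), inpainting (inpaint), and a model-loading smoke
test (test). Usage sketch (config paths are illustrative):

    python example_inference.py dream --model-config ./configs/upsampler.example.json
"""
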
from dalle2_laion import DalleModelManager, ModelLoadConfig, utils
from dalle2_laion.scripts import BasicInference, ImageVariation, BasicInpainting
from typing import List, Optional
import os
import click
from pathlib import Path
import json
import torch  # only needed if the commented-out profiler block below is enabled

@click.group()
@click.option('--verbose', '-v', is_flag=True, default=False, help='Print verbose output.')
@click.option('--suppress-updates', '-s', is_flag=True, default=False, help='Suppress updating models if checksums do not match.')
@click.pass_context
def inference(ctx, verbose, suppress_updates):
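    """DALLE2-LAION inference CLI; global flags are stored on the click context."""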
    ctx.obj['verbose'] = verbose
    ctx.obj['suppress_updates'] = suppress_updates

@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.pass_context
def test(ctx, model_config):
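    """Smoke-test model loading: print checksum file paths and construct the model manager."""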
    model_config = ModelLoadConfig.from_json_path(model_config)
    if model_config.decoder is not None:
        for unet_source in model_config.decoder.unet_sources:
            print('Checksum:', unet_source.load_model_from.checksum_file_path)
    if model_config.prior is not None:
        print('Checksum:', model_config.prior.load_model_from.checksum_file_path)
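    # Constructing the manager loads the models (and optionally checks for updates); the instance itself is not used further here.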
    model_manager = DalleModelManager(model_config, check_updates=not ctx.obj['suppress_updates'])

@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/basic/', help='Path to output directory')
@click.option('--decoder-batch-size', default=10, help='Batch size for decoder')
@click.pass_context
def dream(ctx, model_config: str, output_path: str, decoder_batch_size: int):
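    """Generate images from text prompts.

    Prompts are read interactively; a prompt that is a path to a text file is
    read line by line, and a path to a JSON file is parsed as an array of prompts.
    """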
    verbose = ctx.obj['verbose']
    prompts = []
    print("Enter your prompts one by one. Enter an empty prompt to finish.")
    while True:
        prompt = click.prompt(f'Prompt {len(prompts)+1}', default='', type=str, show_default=False)
        if prompt == '':
            break
        prompt_file = Path(prompt)
        if utils.is_text_file(prompt_file):
            # Then we can read the prompts line by line
            with open(prompt_file, 'r') as f:
                for line in f:
                    prompts.append(line.strip())
        elif utils.is_json_file(prompt_file):
            # Then we assume this is an array of prompts
            with open(prompt_file, 'r') as f:
                prompts.extend(json.load(f))
        else:
            prompts.append(prompt)
    num_prior_samples = click.prompt('How many samples would you like to generate for each prompt?', default=1, type=int)

    dreamer: BasicInference = BasicInference.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
    
    # Optional profiling: uncomment to trace a run with torch.profiler.
    # with torch.profiler.profile(
    #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    #     record_shapes=True, profile_memory=False, with_stack=False,
    # ) as prof:
    #     output_map = dreamer.run(prompts, prior_sample_count=num_prior_samples, decoder_batch_size=decoder_batch_size)
    #     prof.step()
    # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
    # prof.export_chrome_trace('./dalle2_prof_dcu.json')

    output_map = dreamer.run(prompts, prior_sample_count=num_prior_samples, decoder_batch_size=decoder_batch_size)
    os.makedirs(output_path, exist_ok=True)
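    # output_map is keyed by prompt text, then by prior embedding index, with a list of decoded images per embedding.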
    for text in output_map:
        for embedding_index in output_map[text]:
            for image_index, image in enumerate(output_map[text][embedding_index]):
                # Include the image index in the filename so multiple samples per embedding don't overwrite each other.
                image.save(os.path.join(output_path, f"{text}_{embedding_index}_{image_index}.png"))

@inference.command()
@click.option('--model-config', default='./configs/variation.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/variations/', help='Path to output directory')
@click.option('--decoder-batch-size', default=10, help='Batch size for decoder')
@click.pass_context
def variation(ctx, model_config: str, output_path: str, decoder_batch_size: int):
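    """Generate variations of existing images, optionally conditioned on a text prompt per image."""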
    verbose = ctx.obj['verbose']
    variation: ImageVariation = ImageVariation.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
    decoder_data_requirements = variation.model_manager.decoder_info.data_requirements
    image_filepaths: List[Path] = []
    text_prompts: Optional[List[str]] = [] if decoder_data_requirements.text_encoding else None

    print("Enter paths to your images. If you specify a directory all images within will be added. Enter an empty line to finish.")
    if decoder_data_requirements.text_encoding:
        print("This decoder was also conditioned on text. You will need to enter a prompt for each image you use.")

    while True:
        image_filepath: Path = click.prompt(f'File {len(image_filepaths)+1}', default=Path(), type=Path, show_default=False)
        if image_filepath == Path():
            break
        if image_filepath.is_dir():
            new_image_paths = utils.get_images_in_dir(image_filepath)
        elif utils.is_image_file(image_filepath):
            new_image_paths = [image_filepath]
        else:
            print(f"{image_filepath} is not a valid image file.")
            continue

        if decoder_data_requirements.text_encoding:
            for image_path in new_image_paths:
                text_prompt = click.prompt(f'Prompt for {image_path.name}', default=utils.get_prompt_from_filestem(image_path.stem), type=str, show_default=True)
                text_prompts.append(text_prompt)
        image_filepaths.extend(new_image_paths)

    print(f"Found {len(image_filepaths)} images.")
    images = utils.get_images_from_paths(image_filepaths)
    num_samples = click.prompt('How many samples would you like to generate for each image?', default=1, type=int)

    output_map = variation.run(images, text=text_prompts, sample_count=num_samples, batch_size=decoder_batch_size)
    os.makedirs(output_path, exist_ok=True)
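    # output_map is keyed by the index of the input image, with a list of generated variations per image.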
    for file_index, generation_list in output_map.items():
        file = image_filepaths[file_index].stem
        for i, image in enumerate(generation_list):
            image.save(os.path.join(output_path, f"{file}_{i}.png"))

@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/inpaint/', help='Path to output directory')
@click.pass_context
def inpaint(ctx, model_config: str, output_path: str):
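    """Inpaint masked regions of images, guided by a text prompt per image."""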
    verbose = ctx.obj['verbose']
    inpainting: BasicInpainting = BasicInpainting.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
    image_filepaths: List[Path] = []
    mask_filepaths: List[Path] = []
    text_prompts: List[str] = []
    print("Enter the paths to your images and masks one at a time. Enter an empty image path to finish.")
    while True:
        image_filepath: Path = click.prompt(f'File {len(image_filepaths)+1}', default=Path(), type=Path, show_default=False)
        if image_filepath == Path():
            break
        if not utils.is_image_file(image_filepath):
            print(f"{image_filepath} is not a valid image file.")
            continue
        mask_filepath: Path = click.prompt(f'Mask for {image_filepath.name}', default=Path(), type=Path, show_default=False)
        if not utils.is_image_file(mask_filepath):
            print(f"{mask_filepath} is not a valid image file.")
            continue
        text_prompt = click.prompt(f'Prompt for {image_filepath.name}', default=utils.get_prompt_from_filestem(image_filepath.stem), type=str, show_default=True)

        image_filepaths.append(image_filepath)
        mask_filepaths.append(mask_filepath)
        text_prompts.append(text_prompt)
            
    print(f"Found {len(image_filepaths)} images.")
    images = utils.get_images_from_paths(image_filepaths)
    mask_images = utils.get_images_from_paths(mask_filepaths)
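    # Validate each image/mask pair, center-crop non-square inputs to squares, and track the smallest dimension so all images can be resized to a common size.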
    min_image_size = float('inf')
    for i, (image, mask_image, filepath) in enumerate(zip(images, mask_images, image_filepaths)):
        assert image.size == mask_image.size, f"Image {filepath.name} has different dimensions than mask {mask_filepaths[i].name}"
        if min(image.size) < min_image_size:
            min_image_size = min(image.size)
        if image.size[1] != image.size[0]:
            print(f"{filepath.name} is not a square image. It will be center cropped into a square.")
            images[i] = utils.center_crop_to_square(image)
            mask_images[i] = utils.center_crop_to_square(mask_image)
    print(f"Minimum image size is {min_image_size}. All images will be resized to this size for inference.")
    images = [image.resize((min_image_size, min_image_size)) for image in images]
    mask_images = [mask_image.resize((min_image_size, min_image_size)) for mask_image in mask_images]

    masks = [utils.get_mask_from_image(mask_image) for mask_image in mask_images]
    num_samples = click.prompt('How many samples would you like to generate for each image?', default=1, type=int)
    output_map = inpainting.run(images, masks, text=text_prompts, sample_count=num_samples)
    os.makedirs(output_path, exist_ok=True)
    for file_index, generation_list in output_map.items():
        file = image_filepaths[file_index].stem
        for i, image in enumerate(generation_list):
            image.save(os.path.join(output_path, f"{file}_{i}.png"))


if __name__ == "__main__":
    inference(obj={})