Commit bfa3fb86 authored by dongchy920

dalle2_pytorch
"""
This script takes images, masks, and text, and in-paints the regions selected by each mask.
The script is basically the same as BasicInference, but it also takes images and masks.
"""
from dalle2_laion.scripts import InferenceScript
from typing import List, Dict
from PIL import Image as PILImage
import torch
class BasicInpainting(InferenceScript):
def run(
self,
images: List[PILImage.Image],
masks: List[torch.Tensor], # Boolean tensor of same size as image
text: List[str],
prior_cond_scale: float = None, # Use default cond scale from config by default
decoder_cond_scale: float = None,
sample_count: int = 1,
prior_batch_size: int = 100,
decoder_batch_size: int = 10
) -> Dict[int, List[PILImage.Image]]:
if isinstance(text, str):
text = [text]
self.print("Generating prior embeddings...")
image_embedding_map = self._sample_prior(text, cond_scale=prior_cond_scale, sample_count=sample_count, batch_size=prior_batch_size, num_samples_per_batch=2)
self.print("Finished generating prior embeddings.")
# image_embedding_map is a map between the text index and the generated image embeddings
image_embeddings: List[torch.Tensor] = []
decoder_text = [] # The decoder also needs the text; since each prompt produced multiple image embeddings, we repeat the text to match
for i, original_text in enumerate(text):
decoder_text.extend([original_text] * len(image_embedding_map[i]))
image_embeddings.extend(image_embedding_map[i])
# Build a map from each text index to the indices of its embeddings in the flattened list
image_embedding_index_reverse_map = {i: [] for i in range(len(text))}
current_count = 0
for i in range(len(text)):
for _ in range(len(image_embedding_map[i])):
image_embedding_index_reverse_map[i].append(current_count)
current_count += 1
# Now we can use the image embeddings to generate the images
self.print(f"Grouped {len(text)} texts into {len(image_embeddings)} embeddings.")
self.print("Sampling from decoder...")
# images = self._pil_to_torch(images, resize_for_clip=False)
image_map = self._sample_decoder(
text=decoder_text,
image_embed=image_embeddings,
cond_scale=decoder_cond_scale,
inpaint_images=images, inpaint_image_masks=masks,
sample_count=1, batch_size=decoder_batch_size
)
self.print("Finished sampling from decoder.")
# Now we reconstruct a map from each input index to the list of images generated for it
output_map: Dict[int, List[PILImage.Image]] = {}
for i in range(len(images)):
output_map[i] = []
embedding_indices = image_embedding_index_reverse_map[i]
for embedding_index in embedding_indices:
output_map[i].extend(image_map[embedding_index])
return output_map
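
# A minimal usage sketch (not part of the original commit). The image, mask, and prompt
# below are placeholders; the config path matches the default used by the inpaint CLI
# command further down in this commit.
if __name__ == "__main__":
    from dalle2_laion import utils
    inpainter = BasicInpainting.create("./configs/upsampler.example.json", verbose=True)
    source = PILImage.open("example.png")           # hypothetical input image
    mask_image = PILImage.open("example_mask.png")  # hypothetical mask image, same size as the input
    mask = utils.get_mask_from_image(mask_image)    # boolean tensor built from the mask image
    outputs = inpainter.run([source], [mask], text=["a very cute cat"], sample_count=1)
    for input_index, generations in outputs.items():
        for i, generated in enumerate(generations):
            generated.save(f"inpaint_{input_index}_{i}.png")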
"""
This script generates image embeddings directly with clip instead of using the prior.
Put image in, get image out... but different.
"""
from dalle2_laion.scripts import InferenceScript
from typing import List, Dict, Optional
from PIL import Image as PILImage
class ImageVariation(InferenceScript):
def run(
self,
images: List[PILImage.Image],
text: Optional[List[str]],
cond_scale: float = None, # Use defaults from config by default
sample_count: int = 1,
batch_size: int = 10
) -> Dict[int, List[PILImage.Image]]:
self.print("Running decoder...")
image_map = self._sample_decoder(images=images, text=text, cond_scale=cond_scale, sample_count=sample_count, batch_size=batch_size)
self.print("Finished running decoder.")
return image_map
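
# A minimal usage sketch (not part of the original commit). The image file and prompt are
# placeholders; the config path matches the default used by the variation CLI command in
# this commit. run() returns a dict mapping each input-image index to its generated images.
if __name__ == "__main__":
    variation_script = ImageVariation.create("./configs/variation.example.json", verbose=True)
    source = PILImage.open("example.png")  # hypothetical input image
    output_map = variation_script.run([source], text=["a very cute cat"], sample_count=2)
    for image_index, generations in output_map.items():
        for i, generated in enumerate(generations):
            generated.save(f"variation_{image_index}_{i}.png")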
"""
This module contains an abstract class for inference scripts.
"""
from typing import List, Tuple, Union, Dict, TypeVar
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_laion import DalleModelManager, ModelLoadConfig
from torchvision.transforms import ToPILImage, ToTensor
from PIL import Image as PILImage
import torch
from contextlib import contextmanager
ClassType = TypeVar('ClassType')
class InferenceScript:
def __init__(self, model_manager: DalleModelManager, verbose: bool = False):
self.model_manager = model_manager
self.verbose = verbose
self.device = model_manager.devices[0] if model_manager is not None else 'cpu'
@classmethod
def create(cls: ClassType, config: Union[str, ModelLoadConfig], *args, verbose: bool = False, check_updates: bool = True, **kwargs) -> ClassType:
"""
Creates an instance of the inference script directly from a config.
Useful if only one inference script will be run at a time.
"""
if isinstance(config, str):
config = ModelLoadConfig.from_json_path(config)
model_manager = DalleModelManager(config, check_updates=check_updates)
return cls(model_manager, *args, **kwargs, verbose=verbose)
def print(self, *args, **kwargs):
if self.verbose:
print(*args, **kwargs)
@contextmanager
def _clip_in_decoder(self):
assert self.model_manager.decoder_info is not None, "Cannot use the decoder without a decoder model."
decoder = self.model_manager.decoder_info.model
clip = self.model_manager.clip
decoder.clip = clip
yield decoder
decoder.clip = None
@contextmanager
def _clip_in_prior(self):
assert self.model_manager.prior_info is not None, "Cannot use the prior without a prior model."
prior = self.model_manager.prior_info.model
clip = self.model_manager.clip
prior.clip = clip
yield prior
prior.clip = None
@contextmanager
def _decoder_in_gpu(self):
# Moves the decoder to gpu and prior to cpu and removes both from gpu after the context is exited.
assert self.model_manager.decoder_info is not None, "Cannot use the decoder without a decoder model."
if self.model_manager.prior_info is not None:
prior = self.model_manager.prior_info.model
prior.to('cpu')
with self._clip_in_decoder() as decoder:
decoder.to(self.device)
yield decoder
decoder.to('cpu')
@contextmanager
def _prior_in_gpu(self):
# Moves the prior to gpu and decoder to cpu and removes both from gpu after the context is exited.
assert self.model_manager.prior_info is not None, "Cannot use the prior without a prior model."
if self.model_manager.decoder_info is not None:
decoder = self.model_manager.decoder_info.model
decoder.to('cpu')
with self._clip_in_prior() as prior:
prior.to(self.device)
yield prior
prior.to('cpu')
@contextmanager
def _clip_in_gpu(self):
# Moves the clip model to the gpu without touching the other models. If clip was not already on the target device, it is moved back to its original device afterwards.
clip = self.model_manager.clip
assert clip is not None, "Cannot use the clip without a clip model."
original_device = next(iter(clip.parameters())).device
not_on_device = original_device != self.device
if not_on_device:
clip.to(self.device)
yield clip
if not_on_device:
clip.to(original_device)
def _pil_to_torch(self, image: Union[PILImage.Image, List[PILImage.Image]], resize_for_clip: bool = True):
"""
Convert a PIL image into a torch tensor.
Tensor is of dimension 3 if one image is passed in, and of dimension 4 if a list of images is passed in.
"""
# If the image has an alpha channel, then we need to remove it.
def process_image(image: PILImage.Image) -> PILImage.Image:
if resize_for_clip:
clip_size = self.model_manager.clip.image_size
image = image.resize((clip_size, clip_size), resample=PILImage.LANCZOS)
if image.mode == 'RGBA':
return image.convert('RGB')
else:
return image
if isinstance(image, PILImage.Image):
return ToTensor()(process_image(image))
else:
return torch.stack([ToTensor()(process_image(image[i])) for i in range(len(image))])
def _torch_to_pil(self, image: torch.Tensor):
"""
Converts a torch tensor into a PIL image.
If the tensor is a batch of images, then we return a list of PIL images.
"""
if len(image.shape) == 4:
return [ToPILImage()(image[i]) for i in range(image.shape[0])]
else:
return ToPILImage()(image)
def _repeat_tensor_and_batch(self, tensor: Union[torch.Tensor, List[torch.Tensor]], repeat_num: int, batch_size: int) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Repeats each element of the first dimension of a tensor repeat_num times then batches the result into a list of tensors.
Also returns a map from the repeated tensor to the index of the original tensor.
"""
if isinstance(tensor, list):
tensor = torch.stack(tensor, dim=0)
batched_repeat = tensor.repeat_interleave(repeat_num, dim=0).split(batch_size, dim=0)
batched_map = torch.arange(0, tensor.shape[0]).repeat_interleave(repeat_num, dim=0).split(batch_size, dim=0)
return list(batched_repeat), [t.tolist() for t in batched_map]
def _embed_images(self, images: List[PILImage.Image]) -> torch.Tensor:
"""
Generates the clip embeddings for a list of images
"""
assert self.model_manager.clip is not None, "Cannot generate embeddings for this model."
images_tensor = self._pil_to_torch(images, resize_for_clip=True).to(self.device)
with self._clip_in_gpu() as clip:
image_embed = clip.embed_image(images_tensor)
return image_embed.image_embed
def _encode_text(self, text: List[str]) -> torch.Tensor:
"""
Generates the clip embeddings for a list of text
"""
assert self.model_manager.clip is not None, "Cannot generate embeddings for this model."
text_tokens = self._tokenize_text(text)
with self._clip_in_gpu() as clip:
text_embed = clip.embed_text(text_tokens.to(self.device))
return text_embed.text_encodings
def _tokenize_text(self, text: List[str]) -> torch.Tensor:
"""
Tokenizes a list of text
"""
return tokenizer.tokenize(text)
def _sample_decoder(
self,
images: List[PILImage.Image] = None, image_embed: List[torch.Tensor] = None,
text: List[str] = None, text_encoding: List[torch.Tensor] = None,
inpaint_images: List[PILImage.Image] = None, inpaint_image_masks: List[torch.Tensor] = None,
cond_scale: float = None, sample_count: int = 1, batch_size: int = 10,
):
"""
Samples images from the decoder
Capable of doing basic generation with a list of image embeddings (possibly also conditioned with a list of strings or text embeddings)
Also capable of two more advanced generation techniques:
1. Variation generation: If images are passed in the image embeddings will be generated based on those.
2. In-painting generation: If images and masks are passed in, the images will be in-painted using the masks and the image embeddings.
"""
if cond_scale is None:
# Then we use the default scale
load_config = self.model_manager.model_config.decoder
unet_configs = load_config.unet_sources
cond_scale = [1.0] * load_config.final_unet_number
for unet_config in unet_configs:
if unet_config.default_cond_scale is not None:
for unet_number, new_cond_scale in zip(unet_config.unet_numbers, unet_config.default_cond_scale):
cond_scale[unet_number - 1] = new_cond_scale
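# Worked example: with final_unet_number=2 and one unet source where unet_numbers=[1] and
# default_cond_scale=[3.0], the list becomes [3.0, 1.0] (unet 2 keeps the 1.0 fallback).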
self.print(f"Sampling decoder with cond_scale: {cond_scale}")
decoder_info = self.model_manager.decoder_info
assert decoder_info is not None, "No decoder loaded."
data_requirements = decoder_info.data_requirements
min_image_size = min(min(image.size) for image in images) if images is not None else None
is_valid, errors = data_requirements.is_valid(
has_image_emb=image_embed is not None, has_image=images is not None,
has_text_encoding=text_encoding is not None, has_text=text is not None,
image_size=min_image_size
)
assert is_valid, f"The data requirements for the decoder are not satisfied: {errors}"
# Prepare the data
image_embeddings = [] # The null case where nothing is done. This should never be used in actuality, but for stylistic consistency I'm keeping it.
if data_requirements.image_embedding:
if image_embed is None:
# Then we need to use clip to generate the image embedding
image_embed = self._embed_images(images)
# Then we need to group these tensors into batches of size batch_size such that the total number of samples is sample_count
image_embeddings, image_embeddings_map = self._repeat_tensor_and_batch(image_embed, repeat_num=sample_count, batch_size=batch_size)
self.print(f"Decoder batched inputs into {len(image_embeddings)} batches. Total number of samples: {sum(len(t) for t in image_embeddings)}.")
if data_requirements.text_encoding:
if text_encoding is None:
text_encoding = self._encode_text(text)
text_encodings, text_encodings_map = self._repeat_tensor_and_batch(text_encoding, repeat_num=sample_count, batch_size=batch_size)
assert len(image_embeddings) > 0, "No data provided for decoder inference."
output_image_map: Dict[int, List[PILImage.Image]] = {}
with self._decoder_in_gpu() as decoder:
for i in range(len(image_embeddings)):
args = {}
embeddings_map = []
if data_requirements.image_embedding:
args["image_embed"] = image_embeddings[i].to(self.device)
embeddings_map = image_embeddings_map[i]
if data_requirements.text_encoding:
args["text_encodings"] = text_encodings[i].to(self.device)
embeddings_map = text_encodings_map[i]
if inpaint_images is not None:
assert len(inpaint_images) == len(inpaint_image_masks), "Number of inpaint images and masks must match."
inpaint_image_tensors = self._pil_to_torch(inpaint_images, resize_for_clip=False)
args["inpaint_image"] = inpaint_image_tensors.to(self.device)
args["inpaint_mask"] = torch.stack(inpaint_image_masks).to(self.device)
self.print(f"image tensor shape: {args['inpaint_image'].shape}. mask shape: {args['inpaint_mask'].shape}")
output_images = decoder.sample(**args, cond_scale=cond_scale)
for output_image, input_embedding_number in zip(output_images, embeddings_map):
if input_embedding_number not in output_image_map:
output_image_map[input_embedding_number] = []
output_image_map[input_embedding_number].append(self._torch_to_pil(output_image))
return output_image_map
def _sample_prior(self, text_or_tokens: Union[torch.Tensor, List[str]], cond_scale: float = None, sample_count: int = 1, batch_size: int = 100, num_samples_per_batch: int = 2):
"""
Samples image embeddings from the prior
:param text_or_tokens: A list of strings to use as input to the prior or a tensor of tokens generated from strings.
:param cond_scale: The scale of the conditioning.
:param sample_count: The number of samples to generate for each input.
:param batch_size: The max number of samples to run in parallel.
:param num_samples_per_batch: The number of samples to rerank for each output sample.
"""
if cond_scale is None:
# Then we use the default scale
cond_scale = self.model_manager.model_config.prior.default_cond_scale
if cond_scale is None:
# Fallback
cond_scale = 1.0
self.print(f"Sampling prior with cond_scale: {cond_scale}")
assert self.model_manager.prior_info is not None
data_requirements = self.model_manager.prior_info.data_requirements
is_valid, errors = data_requirements.is_valid(
has_text_encoding=False, has_text=text_or_tokens is not None,
has_image_emb=False, has_image=False,
image_size=None
)
assert is_valid, f"The data requirements for the prior are not satisfied. {errors}"
if isinstance(text_or_tokens, list):
text_tokens = self._tokenize_text(text_or_tokens)
else:
text_tokens = text_or_tokens
text_batches, text_batches_map = self._repeat_tensor_and_batch(text_tokens, repeat_num=sample_count, batch_size=batch_size)
self.print(f"Prior batched inputs into {len(text_batches)} batches. Total number of samples: {sum(len(t) for t in text_batches)}.")
embedding_map: Dict[int, List[torch.Tensor]] = {}
# Oddly, the prior requires clip to be part of itself to work, so we insert it here
with self._prior_in_gpu() as prior:
for text_batch, batch_map in zip(text_batches, text_batches_map):
text_batch = text_batch.to(self.device)
embeddings = prior.sample(text_batch, cond_scale=cond_scale, num_samples_per_batch=num_samples_per_batch)
for embedding, embedding_number in zip(embeddings, batch_map):
if embedding_number not in embedding_map:
embedding_map[embedding_number] = []
embedding_map[embedding_number].append(embedding)
return embedding_map
class CliInferenceScript(InferenceScript):
def __init__(self, model_manager: DalleModelManager):
super().__init__(model_manager)
raise NotImplementedError("CliInferenceScript is not implemented cause I have no idea how I'm going to do it yet.")
from dalle2_laion.scripts.InferenceScript import InferenceScript, CliInferenceScript
from dalle2_laion.scripts.BasicInference import BasicInference
from dalle2_laion.scripts.ImageVariation import ImageVariation
from dalle2_laion.scripts.BasicInpainting import BasicInpainting
from pathlib import Path
from typing import List
from PIL import Image as PILImage
import torch
import re
import numpy as np
def is_image_file(file: Path) -> bool:
return file.suffix == '.png' or file.suffix == '.jpg' or file.suffix == '.jpeg'
def is_text_file(file: Path) -> bool:
return file.suffix == '.txt'
def is_json_file(file: Path) -> bool:
return file.suffix == '.json'
def get_images_in_dir(dir: Path) -> List[Path]:
assert dir.is_dir()
return [file for file in dir.iterdir() if is_image_file(file)]
def get_images_from_paths(paths: List[Path]) -> List[PILImage.Image]:
return [PILImage.open(path) for path in paths]
def get_prompt_from_filestem(filestem: str) -> str:
"""
Converts the filename to a prompt with the first letter capitalized and spaces between words.
We assume the stem is either in snake case or camel case.
"""
# First, we replace all "_" with " "
prompt = filestem.replace("_", " ")
# Then we insert a space before every capital letter that is not already preceded by a space, and strip any leading space this introduces
prompt = re.sub(r'(?<! )([A-Z])', r' \1', prompt).strip()
# Then we capitalize the first letter
prompt = prompt[0].upper() + prompt[1:]
return prompt
def get_mask_from_image(image: PILImage.Image) -> torch.Tensor:
"""
Returns a boolean tensor of the same size as the image.
Where the red channel of the image is greater than 128, the mask is True.
"""
mask = torch.zeros(list(reversed(image.size)), dtype=torch.bool)
# mask[np.array(image.getchannel('R')) > 128] = True
mask[np.array(image) < 128] = True
return mask
def center_crop_to_square(image: PILImage.Image) -> PILImage.Image:
"""
Crops the PIL image into a square while keeping the center in the same location
"""
width, height = image.size
if width > height:
left = (width - height) // 2
right = left + height
return image.crop((left, 0, right, height))
else:
top = (height - width) // 2
bottom = top + width
return image.crop((0, top, width, bottom))
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
RUN source /opt/dtk/env.sh
from dalle2_laion import DalleModelManager, ModelLoadConfig, utils
from dalle2_laion.scripts import BasicInference, ImageVariation, BasicInpainting
from typing import List
import os
import click
from pathlib import Path
import json
import torch
import pdb
@click.group()
@click.option('--verbose', '-v', is_flag=True, default=False, help='Print verbose output.')
@click.option('--suppress-updates', '-s', is_flag=True, default=False, help='Suppress updating models if checksums do not match.')
@click.pass_context
def inference(ctx, verbose, suppress_updates):
ctx.obj['verbose'] = verbose
ctx.obj['suppress_updates'] = suppress_updates
@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.pass_context
def test(ctx, model_config):
model_config = ModelLoadConfig.from_json_path(model_config)
if model_config.decoder is not None:
for unet_source in model_config.decoder.unet_sources:
print('Checksum:', unet_source.load_model_from.checksum_file_path)
if model_config.prior is not None:
print('Checksum:', model_config.prior.load_model_from.checksum_file_path)
model_manager = DalleModelManager(model_config, check_updates=not ctx.obj['suppress_updates'])
@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/basic/', help='Path to output directory')
@click.option('--decoder-batch-size', default=10, help='Batch size for decoder')
@click.pass_context
def dream(ctx, model_config: str, output_path: str, decoder_batch_size: int):
verbose = ctx.obj['verbose']
prompts = []
print("Enter your prompts one by one. Enter an empty prompt to finish.")
while True:
prompt = click.prompt(f'Prompt {len(prompts)+1}', default='', type=str, show_default=False)
if prompt == '':
break
prompt_file = Path(prompt)
if utils.is_text_file(prompt_file):
# Then we can read the prompts line by line
with open(prompt_file, 'r') as f:
for line in f:
prompts.append(line.strip())
elif utils.is_json_file(prompt_file):
# Then we assume this is an array of prompts
with open(prompt_file, 'r') as f:
prompts.extend(json.load(f))
else:
prompts.append(prompt)
num_prior_samples = click.prompt('How many samples would you like to generate for each prompt?', default=1, type=int)
dreamer: BasicInference = BasicInference.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,],record_shapes=True, profile_memory=False, with_stack=False) as prof:
# output_map = dreamer.run(prompts, prior_sample_count=num_prior_samples, decoder_batch_size=decoder_batch_size)
# # import pdb
# # pdb.set_trace()
# prof.step()
# print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
# prof.export_chrome_trace('./dalle2_prof_dcu.json')
output_map = dreamer.run(prompts, prior_sample_count=num_prior_samples, decoder_batch_size=decoder_batch_size)
os.makedirs(output_path, exist_ok=True)
for text in output_map:
for embedding_index in output_map[text]:
for image in output_map[text][embedding_index]:
image.save(os.path.join(output_path, f"{text}_{embedding_index}.png"))
@inference.command()
@click.option('--model-config', default='./configs/variation.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/variations/', help='Path to output directory')
@click.option('--decoder-batch-size', default=10, help='Batch size for decoder')
@click.pass_context
def variation(ctx, model_config: str, output_path: str, decoder_batch_size: int):
verbose = ctx.obj['verbose']
variation: ImageVariation = ImageVariation.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
decoder_data_requirements = variation.model_manager.decoder_info.data_requirements
image_filepaths: List[Path] = []
text_prompts: List[str] = [] if decoder_data_requirements.text_encoding else None
print("Enter paths to your images. If you specify a directory all images within will be added. Enter an empty line to finish.")
if decoder_data_requirements.text_encoding:
print("This decoder was also conditioned on text. You will need to enter a prompt for each image you use.")
while True:
image_filepath: Path = click.prompt(f'File {len(image_filepaths)+1}', default=Path(), type=Path, show_default=False)
if image_filepath == Path():
break
if image_filepath.is_dir():
new_image_paths = utils.get_images_in_dir(image_filepath)
elif utils.is_image_file(image_filepath):
new_image_paths = [image_filepath]
else:
print(f"{image_filepath} is not a valid image file.")
continue
if decoder_data_requirements.text_encoding:
for image_path in new_image_paths:
text_prompt = click.prompt(f'Prompt for {image_path.name}', default=utils.get_prompt_from_filestem(image_path.stem), type=str, show_default=True)
text_prompts.append(text_prompt)
image_filepaths.extend(new_image_paths)
print(f"Found {len(image_filepaths)} images.")
images = utils.get_images_from_paths(image_filepaths)
num_samples = click.prompt('How many samples would you like to generate for each image?', default=1, type=int)
output_map = variation.run(images, text=text_prompts, sample_count=num_samples, batch_size=decoder_batch_size)
os.makedirs(output_path, exist_ok=True)
for file_index, generation_list in output_map.items():
file = image_filepaths[file_index].stem
for i, image in enumerate(generation_list):
image.save(os.path.join(output_path, f"{file}_{i}.png"))
@inference.command()
@click.option('--model-config', default='./configs/upsampler.example.json', help='Path to model config file')
@click.option('--output-path', default='./output/inpaint/', help='Path to output directory')
@click.pass_context
def inpaint(ctx, model_config: str, output_path: str):
verbose = ctx.obj['verbose']
inpainting: BasicInpainting = BasicInpainting.create(model_config, verbose=verbose, check_updates=not ctx.obj['suppress_updates'])
image_filepaths: List[Path] = []
mask_filepaths: List[Path] = []
text_prompts: List[str] = []
print("You will be entering the paths to your images and masks one at a time. Enter an empty image path to continue")
while True:
image_filepath: Path = click.prompt(f'File {len(image_filepaths)+1}', default=Path(), type=Path, show_default=False)
if image_filepath == Path():
break
if not utils.is_image_file(image_filepath):
print(f"{image_filepath} is not a valid image file.")
continue
mask_filepath: Path = click.prompt(f'Mask for {image_filepath.name}', default=Path(), type=Path, show_default=False)
if not utils.is_image_file(mask_filepath):
print(f"{mask_filepath} is not a valid image file.")
continue
text_prompt = click.prompt(f'Prompt for {image_filepath.name}', default=utils.get_prompt_from_filestem(image_filepath.stem), type=str, show_default=True)
image_filepaths.append(image_filepath)
mask_filepaths.append(mask_filepath)
text_prompts.append(text_prompt)
print(f"Found {len(image_filepaths)} images.")
images = utils.get_images_from_paths(image_filepaths)
mask_images = utils.get_images_from_paths(mask_filepaths)
min_image_size = float('inf')
for i, image, mask_image, filepath in zip(range(len(images)), images, mask_images, image_filepaths):
assert image.size == mask_image.size, f"Image {filepath.name} has different dimensions than mask {mask_filepaths[i].name}"
if min(image.size) < min_image_size:
min_image_size = min(image.size)
if image.size[1] != image.size[0]:
print(f"{filepath.name} is not a square image. It will be center cropped into a square.")
images[i] = utils.center_crop_to_square(image)
mask_images[i] = utils.center_crop_to_square(mask_image)
print(f"Minimum image size is {min_image_size}. All images will be resized to this size for inference.")
images = [image.resize((min_image_size, min_image_size)) for image in images]
mask_images = [mask_image.resize((min_image_size, min_image_size)) for mask_image in mask_images]
masks = [utils.get_mask_from_image(mask_image) for mask_image in mask_images]
num_samples = click.prompt('How many samples would you like to generate for each image?', default=1, type=int)
output_map = inpainting.run(images, masks, text=text_prompts, sample_count=num_samples)
os.makedirs(output_path, exist_ok=True)
for file_index, generation_list in output_map.items():
file = image_filepaths[file_index].stem
for i, image in enumerate(generation_list):
image.save(os.path.join(output_path, f"{file}_{i}.png"))
if __name__ == "__main__":
inference(obj={})
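# Example invocations (the script's file name is not shown in this commit view; "inference.py"
# below is a placeholder):
#   python inference.py --verbose dream --model-config ./configs/upsampler.example.json
#   python inference.py variation --model-config ./configs/variation.example.json --output-path ./output/variations/
#   python inference.py inpaint --model-config ./configs/upsampler.example.json --output-path ./output/inpaint/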
a t-shirt of an avocado
a rainbow hat
a beautiful sunset at a beach with a shell on the shore
a farmhouse surrounded by flowers
a graphite sketch of a gothic cathedral
still life in the style of Picasso
a portrait of a nightmare creature
a very cute cat
a painting of a capybara sitting on a mountain during fall in surrealist style
a red cube on top of a blue cube
an illustration of a baby daikon radish in a tutu walking a dog
a human face
a child eating a birthday cake near some balloons
the representation of infinity
the end of the world
an armchair in the shape of an avocado
an astronaut riding a horse in a photorealistic style
a propaganda poster for transhumanism