Commit bfa3fb86 authored by dongchy920

dalle2_pytorch
try:
    import gradio as gr
except ImportError:
    print("Please install gradio: `pip install gradio`")
    exit(1)

from pathlib import Path
from typing import Dict, List
from PIL import Image as PILImage
from dalle2_laion import ModelLoadConfig, DalleModelManager, utils
from dalle2_laion.scripts import BasicInference, ImageVariation, BasicInpainting

config_path = Path(__file__).parent / 'configs/gradio.example.json'
model_config = ModelLoadConfig.from_json_path(config_path)
model_manager = DalleModelManager(model_config)
output_path = Path(__file__).parent / 'output/gradio'
output_path.mkdir(parents=True, exist_ok=True)

# One cond-scale slider for the prior plus one per decoder unet.
cond_scale_sliders = [gr.Slider(minimum=0.5, maximum=5, step=0.05, label="Prior Cond Scale", value=1)]
for i in range(model_manager.model_config.decoder.final_unet_number):
    cond_scale_sliders.append(gr.Slider(minimum=0.5, maximum=5, step=0.05, label=f"Decoder {i+1} Cond Scale", value=1))

def dream(text: str, samples_per_prompt: int, prior_cond_scale: float, *decoder_cond_scales: float):
    # Only the first 8 prompts (one per line) are used; any more are ignored.
    prompts = text.split('\n')[:8]
    script = BasicInference(model_manager, verbose=True)
    output = script.run(prompts, prior_sample_count=samples_per_prompt, decoder_batch_size=40, prior_cond_scale=prior_cond_scale, decoder_cond_scale=decoder_cond_scales)
    all_outputs = []
    for prompt, embedding_outputs in output.items():
        for index, embedding_output in embedding_outputs.items():
            all_outputs.extend(embedding_output)
    return all_outputs
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat...", lines=8),
        gr.Slider(minimum=1, maximum=4, step=1, label="Samples per prompt", value=1),
        *cond_scale_sliders
    ],
    outputs=[
        gr.Gallery()
    ],
    title="Dalle2 Dream",
    description="Generate images from text. You can give a maximum of 8 prompts at a time. Any more will be ignored. Generation takes around 5 minutes so be patient.",
)
def variation(image: PILImage.Image, text: str, num_generations: int, *decoder_cond_scales: float):
    print("Variation using text:", text)
    img = utils.center_crop_to_square(image)
    script = ImageVariation(model_manager, verbose=True)
    output = script.run([img], [text], sample_count=num_generations, cond_scale=decoder_cond_scales)
    all_outputs = []
    for index, embedding_output in output.items():
        all_outputs.extend(embedding_output)
    return all_outputs
variation_interface = gr.Interface(
    variation,
    inputs=[
        gr.Image(value="https://www.thefarmersdog.com/digest/wp-content/uploads/2021/12/corgi-top-1400x871.jpg", source="upload", interactive=True, type="pil"),
        gr.Text(),
        gr.Slider(minimum=1, maximum=6, label="Number to generate", value=2, step=1),
        *cond_scale_sliders[1:]
    ],
    outputs=[
        gr.Gallery()
    ],
    title="Dalle2 Variation",
    description="Generates images similar to the input image.\nGeneration takes around 5 minutes so be patient.",
)
def inpaint(image: Dict[str, PILImage.Image], text: str, num_generations: int, prior_cond_scale: float, *decoder_cond_scales: float):
    print("Inpainting using text:", text)
    img, mask = image['image'], image['mask']
    # Remove alpha from img
    img = img.convert('RGB')
    img = utils.center_crop_to_square(img)
    mask = utils.center_crop_to_square(mask)
    script = BasicInpainting(model_manager, verbose=True)
    # Invert the mask from the sketch tool so it matches the convention BasicInpainting expects.
    mask = ~utils.get_mask_from_image(mask)
    output = script.run(images=[img], masks=[mask], text=[text], sample_count=num_generations, prior_cond_scale=prior_cond_scale, decoder_cond_scale=decoder_cond_scales)
    all_outputs = []
    for index, embedding_output in output.items():
        all_outputs.extend(embedding_output)
    return all_outputs
inpaint_interface = gr.Interface(
inpaint,
inputs=[
gr.Image(value="https://www.thefarmersdog.com/digest/wp-content/uploads/2021/12/corgi-top-1400x871.jpg", source="upload", tool="sketch", interactive=True, type="pil"),
gr.Text(),
gr.Slider(minimum=1, maximum=6, label="Number to generate", value=2, step=1),
*cond_scale_sliders
],
outputs=[
gr.Gallery()
],
title="Dalle2 Inpainting",
description="Fills in the details of areas you mask out.\nGeneration takes around 5 minutes so be patient.",
)
demo = gr.TabbedInterface(interface_list=[dream_interface, variation_interface, inpaint_interface], tab_names=["Dream", "Variation", "Inpaint"])
demo.launch(share=True, enable_queue=True)
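For reference, the same pipeline can also be driven without the UI. A minimal sketch reusing the objects defined above; the prompt, the two-unet cond-scale tuple, and saving via PIL are illustrative assumptions, not part of the original script:

# Sketch only; mirrors the dream() call above.
script = BasicInference(model_manager, verbose=True)
outputs = script.run(["a corgi wearing a top hat"], prior_sample_count=1, decoder_batch_size=40, prior_cond_scale=1.0, decoder_cond_scale=(1.0, 1.0))
for prompt, embeddings in outputs.items():
    for index, images in embeddings.items():
        for i, img in enumerate(images):
            img.save(output_path / f"dream_{index}_{i}.png")  # assumes PIL images, as gr.Gallery above implies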
icon.png (68.4 KB, binary image added in this commit)
from setuptools import setup, find_packages

setup(
    name = "dalle2-laion",
    version = "0.0.1",
    packages = find_packages(exclude=[]),
    include_package_data = True,
    install_requires = [
        "packaging>=21.0",
        "pydantic>=1.9.0",
        "torch>=1.10",
        "Pillow>=9.0.0",
        "numpy>=1.20.0",
        "click>=8.0.0",
        "dalle2-pytorch"
    ]
)
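With this setup.py, the package can be installed for development by running `pip install -e .` from the repository root, which also pulls in dalle2-pytorch and the other pinned requirements.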
print("---------------- Train_2.py --------------------")
import torch
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch import OpenAIClipAdapter, Unet, Decoder, DecoderTrainer
import torchvision.transforms as T
from torchvision.utils import save_image
from PIL import Image
import os
import torch.utils.data as data
import pickle
from datasets import load_dataset, concatenate_datasets
""" got the base fot this here: https://github.com/lucidrains/DALLE2-pytorch/issues/279"""
import pdb
import time
# Parameters
image_size = 256 # Image dimension
batch_size = 2 # Batch size for training, adjust based on GPU memory
learning_rate = 1e-4 # Learning rate for the optimizer
num_epochs = 3 # Number of epochs for training
log_image_interval = 1000 # Interval for logging images
checkp_interval = 2500
log_idx = 100
save_dir = "./T2_log_images" # Directory to save log images
weight_dir = "./T2_weight_log"
os.makedirs(save_dir, exist_ok=True) # Create save directory if it doesn't exist
os.makedirs(weight_dir, exist_ok=True)
def xosc2ImageDataset():
    with open('../Dataset_dictionary.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    dset = loaded_dict
    return dset
class ImgTextDataset(data.Dataset):
    def __init__(self, data):
        self.img_paths = data["train"]['image_path']
        self.captions = data["train"]['text']
        # Apply required image transforms. For my model I needed images with 256 x 256 dimensions.
        self.image_transform = T.Compose([
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        caption = self.captions[idx]
        image = Image.open(image_path)
        image_pt = self.image_transform(image)
        if image_pt.shape[0] == 1:
            # Grayscale image: repeat the single channel to get 3-channel input.
            image_pt = image_pt.repeat(3, 1, 1)
        return image_pt, caption
# Setup device
device = torch.device("cuda")
# Define your image-text dataset
data = load_dataset('json', data_files='data.json')
dataset = ImgTextDataset(data)
num_workers = 0
dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=True)
# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()
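# OpenAIClipAdapter wraps OpenAI's ViT-B/32 CLIP by default; its 512-dim embeddings match image_embed_dim/text_embed_dim below.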
# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True  # set to True for any unets that need to be conditioned on text encodings
).cuda()

unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cuda()

decoder = Decoder(
    unet=(unet1, unet2),      # insert both unets in order of lowest to highest resolution (you can have as many stages as you want here)
    image_sizes=(256, 512),   # resolutions: 256 for the first unet, 512 for the second; must be unique and in ascending order (matching the unets passed in)
    clip=clip,
    timesteps=100
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=3e-4,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()
# Use built-in tokenizer. You can use others like GPT2, YTTM etc.
t = SimpleTokenizer()
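# For reference: t.tokenize(["a red car"], context_length=1024) returns a LongTensor
# of shape (1, 1024), zero-padded after the end-of-text token (the prompt is illustrative).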
# Training loop.
# Iterate over the dataloader and pass image tensors and tokenized text to the training wrapper.
# Repeat process N times.
if os.listdir(weight_dir):
    # Load the latest checkpoint with the highest epoch number.
    last_epoch = max([int(name.split('_')[-2]) for name in os.listdir(weight_dir)])
    last_batch_idx = max([int(name.split('_')[-1].split('.')[0]) for name in os.listdir(weight_dir) if name.startswith(f'model_decoder_{last_epoch}')])
    checkpoint_path = os.path.join(weight_dir, f'model_decoder_{last_epoch}_{last_batch_idx}.pt')
    print("Loading from:", checkpoint_path)
    checkpoint = torch.load(checkpoint_path)
    decoder_trainer.load_state_dict(checkpoint)
    start_epoch = last_epoch + 1  # Resume from the epoch after the loaded one.
    print("Checkpoint loaded")
else:
    start_epoch = 0  # Start from the first epoch.
    print("starting from zero")
for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        images_copy = images.clone().detach()
        text = t.tokenize(texts, context_length=1024)
        text_copy = text.clone().detach()
        ep = epoch + start_epoch
        t0 = time.time()
        # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], record_shapes=True, profile_memory=False, with_stack=False) as prof:
        loss = decoder_trainer(
            images_copy.cuda(),
            text=text_copy.cuda(),
            unet_number=1,
            max_batch_size=4
        )
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        # Export a more detailed profiling trace as JSON:
        # prof.export_chrome_trace('./model_prof_nv.json')
        # pdb.set_trace()
        # print(time.time() - t0)
        decoder_trainer.update(1)
        if batch_idx % 10 == 0:
            print(f"epoch {epoch}, step {batch_idx}, loss {loss}")
        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            # embed_image returns (embedding, encodings); index 0 is the image embedding.
            image_embed = clip.embed_image(images.cuda())
            sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
            save_image(sample, f'./T2_log_images/{ep}_{batch_idx}.png')
        if batch_idx % checkp_interval == 0:  # Periodically save the model.
            decoder_trainer.save(f'./T2_weight_log/model_decoder_{ep}_{batch_idx}.pt')
            # torch.save(decoder_trainer.state_dict(), f'./T2_weight_log/model_decoder_{ep}_{batch_idx}.pt')
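Note that the loop above only trains the first unet (unet_number=1). A minimal sketch, following the dalle2_pytorch README pattern and reusing the objects above, of the analogous pass for the second (super-resolution) unet; not part of the original script:

# Sketch only: train the second unet of the decoder the same way.
for batch_idx, (images, texts) in enumerate(dataloader):
    tokens = t.tokenize(texts, context_length=1024)
    loss = decoder_trainer(
        images.cuda(),
        text=tokens.cuda(),
        unet_number=2,   # second stage: 256 -> 512 in this config
        max_batch_size=4
    )
    decoder_trainer.update(2)  # step the optimizer/EMA for unet 2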
print("---------------- Train_2_Priortrainer.py --------------------")
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter, DiffusionPriorTrainer
from dalle2_pytorch.tokenizer import SimpleTokenizer
import torchvision.transforms as T
from PIL import Image
import pickle
import os
import torch.utils.data as data
from torch.nn.utils.rnn import pad_sequence
from datasets import load_dataset, concatenate_datasets
from accelerate import Accelerator
import pdb
device = torch.device("cuda")
weight_dir = "./Priortrainer_weight_log"
os.makedirs(weight_dir, exist_ok=True)
num_epochs = 3
#batch_idx = 196491
checkp_interval = 35000
batch_size = 2
clip = OpenAIClipAdapter()
def xosc2ImageDataset():
    with open('../Dataset_dictionary.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    dset = loaded_dict
    return dset
class TextDataset:
    def __init__(self, texts, batch_size=4, max_length=4500):
        self.texts = texts
        self.batch_size = batch_size
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __iter__(self):
        self.current_index = 0  # Reset the index at the start of iteration.
        return self

    def __next__(self):
        batch_texts = []
        for _ in range(self.batch_size):
            if self.current_index >= len(self.texts):
                raise StopIteration
            text = self.texts[self.current_index]
            #print(text[:90])
            text = text[:self.max_length]
            # Encode each character as its ordinal value (raw character-level tokens).
            tensor = torch.tensor([ord(char) for char in text])
            batch_texts.append(tensor)
            self.current_index += 1
        padded_tensors = pad_sequence(batch_texts, batch_first=True)
        return padded_tensors
class ImageDataset:
    def __init__(self, image_paths, batch_size=4, image_size=(256, 256)):
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __iter__(self):
        self.current_index = 0  # Reset the index at the start of iteration.
        return self

    def __next__(self):
        batch_images = []
        for _ in range(self.batch_size):
            if self.current_index >= len(self.image_paths):
                raise StopIteration
            path = self.image_paths[self.current_index]
            normalized_path = path.replace('\\', '/')
            #print(normalized_path)
            image = self.load_image(normalized_path)
            batch_images.append(image)
            self.current_index += 1
        return torch.stack(batch_images, dim=0)

    def load_image(self, path):
        transform = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
        ])
        image = Image.open(path).convert("RGB")
        image = transform(image)
        return image
data = load_dataset('json', data_files='data.json')
image_list = data["train"]['image_path']  # list of image paths
text_list = data["train"]['text']  # list of captions

text_dataset = TextDataset(text_list, batch_size=batch_size)
image_dataset = ImageDataset(image_list, batch_size=batch_size)
"""prior networks (with transformer)""" #setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(# diffusion prior network, which contains the CLIP and network (with transformer) above
net = prior_network,
clip = clip,
timesteps = 1000,
sample_timesteps = 64,
cond_drop_prob = 0.2
).cuda()
accelerator = Accelerator()
prior_trainer = DiffusionPriorTrainer(
diffusion_prior,
accelerator=accelerator,
lr = 3e-4,
)
if os.listdir(weight_dir):
    # Load the latest checkpoint with the highest epoch number.
    last_epoch = max([int(name.split('_')[-2]) for name in os.listdir(weight_dir)])
    last_batch_idx = max([int(name.split('_')[-1].split('.')[0]) for name in os.listdir(weight_dir) if name.startswith(f'model_prior_{last_epoch}')])
    checkpoint_path = os.path.join(weight_dir, f'model_prior_{last_epoch}_{last_batch_idx}.pt')
    prior_trainer.load(checkpoint_path, overwrite_lr=True, strict=True)
    start_epoch = last_epoch + 1  # Resume from the epoch after the loaded one.
    print("Checkpoint loaded")
else:
    start_epoch = 0
    print("starting from zero")
# checkpoint_path = './model/prior.pth'
# prior_trainer.load(checkpoint_path, overwrite_lr=True, strict=True)
t = SimpleTokenizer()
num_batches = len(text_dataset) // batch_size  # Calculate the total number of batches per epoch.
print("Number of batches:", num_batches, "Dataset length:", len(text_dataset))

for epoch in range(num_epochs):
    ep = epoch + start_epoch
    # Create the iterators once per epoch; re-creating them inside the batch loop
    # would reset them and repeat the first batch on every step.
    text_loader = iter(text_dataset)
    image_loader = iter(image_dataset)
    for idx in range(num_batches):
        batch_texts = next(text_loader)
        batch_images = next(image_loader)
        loss = prior_trainer(
            batch_texts.to(device),
            batch_images.to(device)
        )
        prior_trainer.update()  # Update the parameters of the model with the optimizer.
        if idx % 10 == 0:
            print(f"epoch {ep}, step {idx}, loss {loss}")
        if idx % (int(num_batches / 10)) == 0:  # Periodically save the model.
            prior_trainer.save(f'./Priortrainer_weight_log/model_prior_{ep}_{idx}.pt')
# do above for many steps ...
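For completeness, a minimal sketch of how the two trained components combine at inference time, following the DALLE2 wrapper from the dalle2_pytorch README. The prompt is illustrative, `decoder` comes from Train_2.py rather than this script, and note this prior was trained on character-ordinal encodings rather than CLIP tokens, so in practice the same encoding would be needed:

# Sketch only; not part of the original scripts.
from dalle2_pytorch import DALLE2

dalle2 = DALLE2(
    prior=diffusion_prior,  # trained by Train_2_Priortrainer.py
    decoder=decoder         # trained by Train_2.py
)
images = dalle2(['a red car on a highway'], cond_scale=2.)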