{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ip-adapter-t2i-prior-intro",
   "metadata": {},
   "source": [
    "# IP-Adapter text-to-image with a Kandinsky 2.2 prior\n",
    "\n",
    "The Kandinsky 2.2 prior pipeline turns each text prompt into a CLIP image embedding. ",
    "Those embeddings are then passed to IP-Adapter (SD 1.5, ViT-G checkpoint), which generates one image per prompt.\n",
    "\n",
    "All model paths are configured in the second code cell and are resolved relative to the current working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a36078d9-c788-4323-b9af-88225e6c6c94",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, KandinskyV22PriorPipeline\n",
    "from PIL import Image\n",
    "\n",
    "from ip_adapter import IPAdapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a71bc9-de68-4de4-b6c3-16c92fac3e45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader may need to tune lives in this cell.\n",
    "current_dir = os.getcwd()\n",
    "\n",
    "print(current_dir)\n",
    "# TODO: point these paths at your local copies of the pretrained weights.\n",
    "base_model_path = f\"{current_dir}/pretrained_models/sd1.5/Realistic_Vision_v4.0_noVAE\"\n",
    "vae_model_path = f\"{current_dir}/pretrained_models/sd1.5/sd-vae-ft-mse\"\n",
    "image_encoder_path = f\"{current_dir}/pretrained_models/sdxl_models/image_encoder/\"\n",
    "prior_model_path = f\"{current_dir}/pretrained_models/kandinsky-2-2-prior\"\n",
    "ip_ckpt = f\"{current_dir}/pretrained_models/models/ip-adapter_sd15_vit-G.safetensors\"\n",
    "device = \"cuda\"\n",
    "# Single seed shared by the prior pipeline and the IP-Adapter generation step.\n",
    "seed = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d3092ca-f27e-4491-aacb-f0991f3a30ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_grid(imgs, rows, cols):\n",
    "    \"\"\"Paste PIL images into one `rows` x `cols` grid image, filled row by row.\n",
    "\n",
    "    Every grid cell has the size of the first image; images are pasted as-is,\n",
    "    without resizing.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `len(imgs)` is not `rows * cols`.\n",
    "    \"\"\"\n",
    "    assert len(imgs) == rows * cols, f\"expected {rows * cols} images, got {len(imgs)}\"\n",
    "\n",
    "    w, h = imgs[0].size\n",
    "    grid = Image.new('RGB', size=(cols * w, rows * h))\n",
    "\n",
    "    for i, img in enumerate(imgs):\n",
    "        # i % cols is the column index, i // cols the row index.\n",
    "        grid.paste(img, box=(i % cols * w, i // cols * h))\n",
    "    return grid\n",
    "\n",
    "noise_scheduler = DDIMScheduler(\n",
    "    num_train_timesteps=1000,\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False,\n",
    "    set_alpha_to_one=False,\n",
    "    steps_offset=1,\n",
    ")\n",
    "vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b558ca3-e671-4d10-9137-5bf34f710124",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load SD pipeline\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    base_model_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    scheduler=noise_scheduler,\n",
    "    vae=vae,\n",
    "    feature_extractor=None,\n",
    "    safety_checker=None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "216952a5-f70d-4aec-b705-fb235e540e3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load Prior pipeline\n",
    "pipe_prior = KandinskyV22PriorPipeline.from_pretrained(prior_model_path, torch_dtype=torch.float16).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b02182b7-d3cb-4684-a6dd-8515a7f3f861",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load ip-adapter\n",
    "ip_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0d42e9-6259-48ac-817a-ddf5164cb6ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate clip image embeds\n",
    "prompt = [\n",
    "    \"a photograph of an astronaut riding a horse\",\n",
    "    \"a macro wildlife photo of a green frog in a rainforest pond, highly detailed, eye-level shot\",\n",
    "    \"kid's coloring book, a happy young girl holding a flower, cartoon, thick lines, black and white, white background\",\n",
    "    \"a professional photograph of a woman with red and very short hair\",\n",
    "]\n",
    "clip_image_embeds = pipe_prior(prompt, generator=torch.manual_seed(seed)).image_embeds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2097cffc-bf93-44ca-9e6b-9d099604d4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate image\n",
    "images = ip_model.generate(clip_image_embeds=clip_image_embeds, num_samples=1, width=512, height=512, num_inference_steps=50, seed=seed)\n",
    "# Size the single-row grid from the actual result so editing `prompt` does not break the layout.\n",
    "image_grid(images, 1, len(images))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}