{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ip-adapter-plus-intro",
   "metadata": {},
   "source": [
    "# IP-Adapter Plus demo (Stable Diffusion 1.5)\n",
    "\n",
    "Generates images conditioned on an image prompt, first on its own and then combined with a text prompt, using `IPAdapterPlus` on top of a Stable Diffusion pipeline.\n",
    "\n",
    "All model and asset paths are resolved relative to the current working directory, so start the kernel in the directory that contains `pretrained_models/` and `assets/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411c59b3-f177-4a10-8925-d931ce572eaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from diffusers import (\n",
    "    StableDiffusionPipeline,\n",
    "    StableDiffusionImg2ImgPipeline,\n",
    "    StableDiffusionInpaintPipelineLegacy,\n",
    "    DDIMScheduler,\n",
    "    AutoencoderKL,\n",
    ")\n",
    "from PIL import Image\n",
    "\n",
    "from ip_adapter import IPAdapterPlus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6dc69c-192d-4d74-8b1e-f0d9ccfbdb49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every model path below is anchored at the working directory of the kernel.\n",
    "current_dir = os.getcwd()\n",
    "print(current_dir)\n",
    "\n",
    "base_model_path = f\"{current_dir}/pretrained_models/sd1.5/Realistic_Vision_v4.0_noVAE\"\n",
    "vae_model_path = f\"{current_dir}/pretrained_models/sd1.5/sd-vae-ft-mse\"\n",
    "image_encoder_path = f\"{current_dir}/pretrained_models/models/image_encoder\"\n",
    "ip_ckpt = f\"{current_dir}/pretrained_models/models/ip-adapter-plus_sd15.safetensors\"\n",
    "device = \"cuda\"\n",
    "\n",
    "# Sampling settings shared by every generation cell below.\n",
    "NUM_SAMPLES = 4\n",
    "NUM_INFERENCE_STEPS = 50\n",
    "SEED = 42"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63ec542f-8474-4f38-9457-073425578073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_grid(imgs, rows, cols):\n",
    "    \"\"\"Paste PIL images into a single rows x cols grid image.\n",
    "\n",
    "    Images are placed row by row, left to right. The tile size is taken from the\n",
    "    first image, so all images are assumed to share that size.\n",
    "\n",
    "    Args:\n",
    "        imgs: Sequence of PIL images; must contain exactly rows * cols items.\n",
    "        rows: Number of rows in the grid.\n",
    "        cols: Number of columns in the grid.\n",
    "\n",
    "    Returns:\n",
    "        A new RGB PIL image of size (cols * w, rows * h), where (w, h) is the\n",
    "        size of the first image.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If len(imgs) != rows * cols.\n",
    "    \"\"\"\n",
    "    assert len(imgs) == rows * cols, f\"expected {rows * cols} images, got {len(imgs)}\"\n",
    "\n",
    "    w, h = imgs[0].size\n",
    "    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n",
    "    for i, img in enumerate(imgs):\n",
    "        # Row-major placement: index i lands in row i // cols, column i % cols.\n",
    "        row, col = divmod(i, cols)\n",
    "        grid.paste(img, box=(col * w, row * h))\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "load-base-pipeline-header",
   "metadata": {},
   "source": [
    "## Load the base Stable Diffusion pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "scheduler-and-vae",
   "metadata": {},
   "outputs": [],
   "source": [
    "noise_scheduler = DDIMScheduler(\n",
    "    num_train_timesteps=1000,\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False,\n",
    "    set_alpha_to_one=False,\n",
    "    steps_offset=1,\n",
    ")\n",
    "# The base checkpoint directory is named \"noVAE\", so the VAE is loaded separately (in half precision).\n",
    "vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3849f9d0-5f68-4a49-9190-69dd50720cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load SD pipeline\n",
    "# NOTE(review): the source of this cell was missing from the saved notebook, which left `pipe`\n",
    "# undefined for the IP-Adapter cell below. It is reconstructed from the names the surrounding\n",
    "# cells define (base_model_path, noise_scheduler, vae) -- confirm the remaining arguments.\n",
    "pipe = StableDiffusionPipeline.from_pretrained(\n",
    "    base_model_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    scheduler=noise_scheduler,\n",
    "    vae=vae,\n",
    "    feature_extractor=None,\n",
    "    safety_checker=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "image-prompt-header",
   "metadata": {},
   "source": [
    "## Image prompt and IP-Adapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "read-image-prompt",
   "metadata": {},
   "outputs": [],
   "source": [
    "# read image prompt\n",
    "image = Image.open(\"assets/images/statue.png\")\n",
    "# resize() returns a new image, so this only produces a small preview; `image` itself is unchanged.\n",
    "image.resize((256, 256))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a23de3d2-169e-470b-8012-960e3d07b04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load ip-adapter\n",
    "ip_model = IPAdapterPlus(pipe, image_encoder_path, ip_ckpt, device, num_tokens=16)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "generation-header",
   "metadata": {},
   "source": [
    "## Generation\n",
    "\n",
    "Every run below uses the same image prompt, seed, sample count and step count; only the text prompt and the `scale` argument change."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "generate-grid-helper",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_grid(**generate_kwargs):\n",
    "    \"\"\"Generate NUM_SAMPLES images from the image prompt and tile them in one row.\n",
    "\n",
    "    Args:\n",
    "        **generate_kwargs: Extra keyword arguments forwarded unchanged to\n",
    "            `ip_model.generate` (for example `prompt` and `scale`).\n",
    "\n",
    "    Returns:\n",
    "        A PIL image with the generated samples side by side.\n",
    "    \"\"\"\n",
    "    images = ip_model.generate(\n",
    "        pil_image=image,\n",
    "        num_samples=NUM_SAMPLES,\n",
    "        num_inference_steps=NUM_INFERENCE_STEPS,\n",
    "        seed=SEED,\n",
    "        **generate_kwargs,\n",
    "    )\n",
    "    return image_grid(images, 1, NUM_SAMPLES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d83df45f-717d-4bb3-a5fd-0ea30930a431",
   "metadata": {},
   "outputs": [],
   "source": [
    "# only image prompt\n",
    "grid = generate_grid()\n",
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b77f52de-a9e4-44e1-aeec-8165414f1273",
   "metadata": {},
   "outputs": [],
   "source": [
    "# multimodal prompts\n",
    "grid = generate_grid(\n",
    "    prompt=\"best quality, high quality, wearing a hat on the beach\",\n",
    "    scale=0.6,\n",
    ")\n",
    "grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3d874a-49b2-4c7e-ad58-b0ecc085c1fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# multimodal prompts\n",
    "grid = generate_grid(\n",
    "    prompt=\"best quality, high quality, wearing sunglasses in a garden\",\n",
    "    scale=0.6,\n",
    ")\n",
    "grid"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}