Unverified Commit d3881f35 authored by Hzzone, committed by GitHub

Gligen training (#7906)



* add training code of gligen

* fix code quality tests.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 48207d66
# GLIGEN: Open-Set Grounded Text-to-Image Generation
These scripts contain the code to prepare the grounding data and train the GLIGEN model on the COCO dataset.
### Install the requirements
```bash
conda create -n diffusers python==3.10
conda activate diffusers
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or, for a default accelerate configuration without answering questions about your environment:
```bash
accelerate config default
```
Or, if your environment doesn't support an interactive shell (e.g., a notebook):
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
### Prepare the training data
If you want to build your own grounding data, you need a few additional models: I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) to detect objects, and [BLIP-2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption the detected instances.
Only RAM needs to be installed manually:
```bash
pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
```
Download the pre-trained models:
```bash
huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
huggingface-cli download --resume-download openai/clip-vit-large-patch14
huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
```
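If you prefer Python, roughly the same downloads can be done with `huggingface_hub` (a sketch equivalent to the CLI commands above; the local cache paths may differ on your machine):
```python
# Sketch: fetch the same models from the Hub via Python instead of the CLI.
from huggingface_hub import hf_hub_download, snapshot_download

ram_ckpt = hf_hub_download("xinyu1205/recognize_anything_model", "ram_swin_large_14m.pth")
for repo_id in [
    "IDEA-Research/grounding-dino-base",
    "Salesforce/blip2-flan-t5-xxl",
    "openai/clip-vit-large-patch14",
    "masterful/gligen-1-4-generation-text-box",
]:
    snapshot_download(repo_id)

print(ram_ckpt)  # pass this local path as --ram_checkpoint below
```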
Make the training data on 8 GPUs:
```bash
torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
--data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--save_root /root/gligen_data \
--ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
```
Alternatively, you can download the already-prepared COCO training data with:
```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```
It's in the following format:
```python
[
    ...
    {
        'file_path': str,  # image file name
        'annos': [
            {
                'caption': str,  # instance caption
                'bbox': [x0, y0, x1, y1],  # box in xyxy pixel coordinates
                'text_embeddings_before_projection': Tensor,  # CLIP text embedding before the linear projection
            },
            ...
        ],
    },
    ...
]
```
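For a quick sanity check, you can load the file and inspect one sample (a minimal sketch; the key names follow the format above):
```python
import torch

data = torch.load("coco_train2017.pth", map_location="cpu")
sample = data[0]
print(sample["file_path"], len(sample["annos"]))

anno = sample["annos"][0]
# xyxy box in pixel coordinates and the pooled 768-dim CLIP embedding of its caption
print(anno["caption"], anno["bbox"], anno["text_embeddings_before_projection"].shape)
```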
### Training commands
The training script is heavily based on [`train_controlnet.py`](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py):
```bash
accelerate launch train_gligen_text.py \
--data_path /root/data/zhizhonghuang/coco_train2017.pth \
--image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--train_batch_size 8 \
--max_train_steps 100000 \
--checkpointing_steps 1000 \
--checkpoints_total_limit 10 \
--learning_rate 5e-5 \
--dataloader_num_workers 16 \
--mixed_precision fp16 \
--report_to wandb \
--tracker_project_name gligen \
--output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
```
I trained the model on 8 A100 GPUs for about 11 hours (at least 24GB of GPU memory is required). The generated images typically start to follow the layout at around 50k iterations.
Note that although the pre-trained GLIGEN model is loaded, the parameters of `fuser` and `position_net` are reset (see line 420 in `train_gligen_text.py`), so the grounding modules are trained from scratch.
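For reference, here is a minimal sketch (not the training script itself) of how the grounding-specific modules can be located on the GLIGEN UNet; in the original GLIGEN recipe these are the parameters that get trained, while the rest of the UNet stays frozen (see `train_gligen_text.py` for the exact setup used here):
```python
# Sketch: identify the grounding modules (`fuser`, `position_net`) on the GLIGEN UNet.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", subfolder="unet"
)

grounding_params = [
    p
    for name, module in unet.named_modules()
    if name.endswith("fuser") or name == "position_net"
    for p in module.parameters()
]
print(f"grounding parameters: {sum(p.numel() for p in grounding_params) / 1e6:.1f}M")
```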
The trained model can be downloaded with:
```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```
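If the repository stores the UNet in the standard `save_pretrained` layout (as `demo.ipynb` suggests), it can be loaded directly; a sketch:
```python
from diffusers import UNet2DConditionModel

# Assumes Hzzone/GLIGEN_COCO holds a UNet2DConditionModel checkpoint
# (config.json + diffusion_pytorch_model.safetensors downloaded above).
unet = UNet2DConditionModel.from_pretrained("Hzzone/GLIGEN_COCO")
```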
You can run `demo.ipynb` to visualize the generated images.
Example prompts:
```python
prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
boxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],
[0.525390625, 0.552734375, 0.93359375, 0.865234375],
[0.12890625, 0.015625, 0.412109375, 0.279296875],
[0.578125, 0.08203125, 0.857421875, 0.27734375]]
gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
```
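To try these prompts outside the notebook, here is a minimal inference sketch with the public `StableDiffusionGLIGENPipeline` and the original GLIGEN checkpoint, using the `prompt`, `boxes`, and `gligen_phrases` defined above (to use the retrained model instead, load the UNet separately as in `demo.ipynb`):
```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

images = pipe(
    prompt=prompt,
    gligen_phrases=gligen_phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1.0,
    num_inference_steps=50,
).images
images[0].save("gligen_example.png")
```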
Example images:
![alt text](generated-images-100000-00.png)
### Citation
```
@article{li2023gligen,
title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
journal={CVPR},
year={2023}
}
```
import os
import random

import torch
import torchvision.transforms as transforms
from PIL import Image


def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
    # Map an xywh box from the original image onto the resized + center-cropped image,
    # and drop boxes whose area (relative to the full image) falls below `min_box_size`.
    scale = image_size / min(original_image_size)
    crop_y = (original_image_size[1] * scale - image_size) // 2
    crop_x = (original_image_size[0] * scale - image_size) // 2
    x0 = max(x * scale - crop_x, 0)
    y0 = max(y * scale - crop_y, 0)
    x1 = min((x + w) * scale - crop_x, image_size)
    y1 = min((y + h) * scale - crop_y, image_size)
    if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
        return False, (None, None, None, None)
    return True, (x0, y0, x1, y1)


class COCODataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_path,
        image_path,
        image_size=512,
        min_box_size=0.01,
        max_boxes_per_data=8,
        tokenizer=None,
    ):
        super().__init__()
        self.min_box_size = min_box_size
        self.max_boxes_per_data = max_boxes_per_data
        self.image_size = image_size
        self.image_path = image_path
        self.tokenizer = tokenizer
        self.transforms = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.data_list = torch.load(data_path, map_location="cpu")

    def __getitem__(self, index):
        if self.max_boxes_per_data > 99:
            assert False, "Are you sure you want to allow this many boxes per image?"

        out = {}

        data = self.data_list[index]
        image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
        original_image_size = image.size
        out["pixel_values"] = self.transforms(image)

        annos = data["annos"]

        areas, valid_annos = [], []
        for anno in annos:
            # x, y, w, h = anno['bbox']
            x0, y0, x1, y1 = anno["bbox"]
            x, y, w, h = x0, y0, x1 - x0, y1 - y0

            valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
                x, y, w, h, self.image_size, original_image_size, self.min_box_size
            )

            if valid:
                anno["bbox"] = [x0, y0, x1, y1]
                areas.append((x1 - x0) * (y1 - y0))
                valid_annos.append(anno)

        # Sort according to area and choose the largest N objects
        wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
        wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
        valid_annos = [valid_annos[i] for i in wanted_idxs]

        out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
        out["masks"] = torch.zeros(self.max_boxes_per_data)
        out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)

        for i, anno in enumerate(valid_annos):
            out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
            out["masks"][i] = 1
            out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]

        # Randomly drop all grounding boxes by zeroing their masks
        prob_drop_boxes = 0.1
        if random.random() < prob_drop_boxes:
            out["masks"][:] = 0

        # Randomly replace the image caption with an empty string
        caption = random.choice(data["captions"])
        prob_drop_captions = 0.5
        if random.random() < prob_drop_captions:
            caption = ""
        caption = self.tokenizer(
            caption,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        out["caption"] = caption

        return out

    def __len__(self):
        return len(self.data_list)
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
")\n",
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
"\n",
"noise_scheduler = EulerDiscreteScheduler.from_config(noise_scheduler.config)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"unet = UNet2DConditionModel.from_pretrained(\n",
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion_gligen.pipeline_stable_diffusion_gligen.StableDiffusionGLIGENPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n"
]
}
],
"source": [
"pipe = StableDiffusionGLIGENPipeline(\n",
" vae,\n",
" text_encoder,\n",
" tokenizer,\n",
" unet,\n",
" noise_scheduler,\n",
" safety_checker=None,\n",
" feature_extractor=None,\n",
")\n",
"pipe = pipe.to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
"# prompt = 'A realistic top-down view of a wooden table with two apples on it'\n",
"# gen_boxes = [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]\n",
"\n",
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
"boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n",
"boxes[:, 3] = boxes[:, 1] + boxes[:, 3]\n",
"boxes = boxes.tolist()\n",
"gligen_phrases = [x[0] for x in gen_boxes]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:683: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\n",
" num_channels_latents = self.unet.in_channels\n",
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:716: FutureWarning: Accessing config attribute `cross_attention_dim` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'cross_attention_dim' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.cross_attention_dim'.\n",
" max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype\n",
"100%|██████████| 50/50 [01:21<00:00, 1.64s/it]\n"
]
}
],
"source": [
"images = pipe(\n",
" prompt=prompt,\n",
" gligen_phrases=gligen_phrases,\n",
" gligen_boxes=boxes,\n",
" gligen_scheduled_sampling_beta=1.0,\n",
" output_type=\"pil\",\n",
" num_inference_steps=50,\n",
" negative_prompt=\"artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate\",\n",
" num_images_per_prompt=16,\n",
").images"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "densecaption",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import argparse
import os
import random

import torch
import torchvision
import torchvision.transforms as TS
from PIL import Image
from ram import inference_ram
from ram.models import ram
from tqdm import tqdm
from transformers import (
    AutoModelForZeroShotObjectDetection,
    AutoProcessor,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    CLIPTextModel,
    CLIPTokenizer,
)


torch.autograd.set_grad_enabled(False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Caption Generation script", add_help=False)
    parser.add_argument("--data_root", type=str, required=True, help="path to the COCO images")
    parser.add_argument("--save_root", type=str, required=True, help="path to save the grounding data")
    parser.add_argument("--ram_checkpoint", type=str, required=True, help="path to the RAM checkpoint")
    args = parser.parse_args()

    # ram_checkpoint = '/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth'
    # data_root = '/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017'
    # save_root = '/root/gligen_data'

    box_threshold = 0.25
    text_threshold = 0.2

    import torch.distributed as dist

    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)

    # RAM for image tagging
    ram_model = ram(pretrained=args.ram_checkpoint, image_size=384, vit="swin_l").cuda().eval()
    ram_processor = TS.Compose(
        [TS.Resize((384, 384)), TS.ToTensor(), TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    )

    # Grounding DINO for open-set object detection
    grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
    grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
        "IDEA-Research/grounding-dino-base"
    ).cuda()

    # BLIP-2 for instance captioning
    blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-flan-t5-xxl", torch_dtype=torch.float16
    ).cuda()

    # CLIP text encoder for the grounding-token embeddings
    clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").cuda()
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    image_paths = [os.path.join(args.data_root, x) for x in os.listdir(args.data_root)]
    random.shuffle(image_paths)

    for image_path in tqdm(image_paths):
        pth_path = os.path.join(args.save_root, os.path.basename(image_path))
        if os.path.exists(pth_path):
            continue

        sample = {"file_path": os.path.basename(image_path), "annos": []}

        raw_image = Image.open(image_path).convert("RGB")

        # Tag the image, then turn the tags into a Grounding DINO text prompt
        res = inference_ram(ram_processor(raw_image).unsqueeze(0).cuda(), ram_model)
        text = res[0].replace(" |", ".")

        inputs = grounding_dino_processor(images=raw_image, text=text, return_tensors="pt")
        inputs = {k: v.cuda() for k, v in inputs.items()}
        outputs = grounding_dino_model(**inputs)
        results = grounding_dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs["input_ids"],
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[raw_image.size[::-1]],
        )

        boxes = results[0]["boxes"]
        labels = results[0]["labels"]
        scores = results[0]["scores"]

        indices = torchvision.ops.nms(boxes, scores, 0.5)
        boxes = boxes[indices]
        category_names = [labels[i] for i in indices]

        for i, bbox in enumerate(boxes):
            bbox = bbox.tolist()

            # Caption each detected instance with BLIP-2
            inputs = blip2_processor(images=raw_image.crop(bbox), return_tensors="pt")
            inputs = {k: v.cuda().to(torch.float16) for k, v in inputs.items()}
            outputs = blip2_model.generate(**inputs)
            caption = blip2_processor.decode(outputs[0], skip_special_tokens=True)

            # Encode the instance caption with CLIP (pooled output, before the projection layer)
            inputs = clip_tokenizer(
                caption,
                padding="max_length",
                max_length=clip_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.cuda() for k, v in inputs.items()}
            text_embeddings_before_projection = clip_text_encoder(**inputs).pooler_output.squeeze(0)

            sample["annos"].append(
                {
                    "caption": caption,
                    "bbox": bbox,
                    "text_embeddings_before_projection": text_embeddings_before_projection,
                }
            )

        torch.save(sample, pth_path)
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
diffusers
scipy
timm
fairscale
wandb