Unverified Commit 2c04e585 authored by Patrick von Platen, committed by GitHub

Multi Vector Textual Inversion (#3144)



* Multi Vector

* Improve

* fix multi token

* improve test

* make style

* Update examples/test_examples.py

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* update

* Finish

* Apply suggestions from code review

---------
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 391cfcd7
......@@ -122,6 +122,18 @@ accelerate launch textual_inversion.py \
--lr_warmup_steps=0 \
--output_dir="textual_inversion_cat"
```
<Tip>
💡 If you want to increase the trainable capacity, you can associate your placeholder token, *e.g.* `<cat-toy>`, with
multiple embedding vectors. This can help the model better capture the style of more complex images.
To enable training multiple embedding vectors, simply pass:
```bash
--num_vectors=5
```
</Tip>
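Under the hood, the training script expands the placeholder into extra dummy tokens, one per vector. Here is a minimal sketch of that mechanism (the checkpoint name `runwayml/stable-diffusion-v1-5` is an assumption for illustration; the token-naming scheme matches the training script changes further down this diff):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

# `--num_vectors=5` registers `<cat-toy>` plus four dummy tokens `<cat-toy>_1` ... `<cat-toy>_4`
num_vectors = 5
placeholder_tokens = ["<cat-toy>"] + [f"<cat-toy>_{i}" for i in range(1, num_vectors)]
tokenizer.add_tokens(placeholder_tokens)
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

# Grow the embedding matrix and initialize every new row from the initializer token
text_encoder.resize_token_embeddings(len(tokenizer))
initializer_token_id = tokenizer.encode("toy", add_special_tokens=False)[0]
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
    for token_id in placeholder_token_ids:
        token_embeds[token_id] = token_embeds[initializer_token_id].clone()
```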
</pt>
<jax>
If you have access to TPUs, try out the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py) to train even faster (this'll also work for GPUs). With the same configuration settings, the Flax training script should be at least 70% faster than the PyTorch training script! ⚡️
......
## Multi Token Textual Inversion
## [Deprecated] Multi Token Textual Inversion
**IMPORTANT: This research project is deprecated. Multi Token Textual Inversion is now supported natively in [the official textual inversion example](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion#running-locally-with-pytorch).**
The author of this project is [Isamu Isozaki](https://github.com/isamu-isozaki) - please make sure to tag the author for issues and PRs as well as @patrickvonplaten.
We add multi token support to textual inversion. I added
......
......@@ -105,6 +105,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
......
......@@ -36,7 +36,6 @@ And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) e
accelerate config
```
### Cat toy example
First, let's login so that we can upload the checkpoint to the Hub during training:
......@@ -83,6 +82,18 @@ accelerate launch textual_inversion.py \
A full training run takes ~1 hour on one V100 GPU.
**Note**: As described in [the official paper](https://arxiv.org/abs/2208.01618),
only one embedding vector is used for the placeholder token, *e.g.* `"<cat-toy>"`.
However, one can also add multiple embedding vectors for the placeholder token
to increase the number of fine-tunable parameters. This can help the model to learn
more complex details. To use multiple embedding vectors, set `--num_vectors`
to a number larger than one, *e.g.*:
```bash
--num_vectors 5
```
The saved textual inversion vectors will then be larger in size compared to the default case.
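For reference, a minimal sketch of how to inspect the saved file (assuming the `textual_inversion_cat` output directory from the command above and the `learned_embeds.bin` filename used by the script below):

```python
import torch

# The file maps the placeholder token to one tensor holding all of its vectors;
# with `--num_vectors 5` the shape is [5, hidden_dim], e.g. (5, 768) for SD 1.x.
learned_embeds = torch.load("textual_inversion_cat/learned_embeds.bin")
for token, embeds in learned_embeds.items():
    print(token, tuple(embeds.shape))
```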
### Inference
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt.
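For example (a minimal sketch, assuming the full pipeline was saved to the `textual_inversion_cat` output directory, i.e. full-model saving was not disabled via `--only_save_embeds`):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "textual_inversion_cat", torch_dtype=torch.float16
).to("cuda")

# The placeholder token has to appear in the prompt for the learned vectors to be used
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```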
......
......@@ -82,6 +82,34 @@ check_min_version("0.16.0.dev0")
logger = get_logger(__name__)
def save_model_card(repo_id: str, images=None, base_model: str = None, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- textual_inversion
inference: true
---
"""
model_card = f"""
# Textual inversion text2image fine-tuning - {repo_id}
These are textual inversion adaptation weights for {base_model}. You can find some example images below. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
......@@ -94,6 +122,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
tokenizer=tokenizer,
unet=unet,
vae=vae,
safety_checker=None,
revision=args.revision,
torch_dtype=weight_dtype,
)
......@@ -124,11 +153,16 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
del pipeline
torch.cuda.empty_cache()
return images
def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
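# NOTE: `tokenizer.add_tokens` appends new tokens at the end of the vocabulary, so the
# placeholder ids form a contiguous range and a single slice covers every learned vector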
learned_embeds = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
)
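# All vectors are stored under the single placeholder key as one [num_vectors, hidden_dim] tensor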
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
......@@ -144,9 +178,15 @@ def parse_args():
parser.add_argument(
"--only_save_embeds",
action="store_true",
default=False,
default=True,
help="Save only the embeddings for the new concept.",
)
parser.add_argument(
"--num_vectors",
type=int,
default=1,
help="How many textual inversion vectors shall be used to learn the concept.",
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
......@@ -581,8 +621,19 @@ def main():
)
# Add the placeholder tokens to the tokenizer
num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
if num_added_tokens == 0:
placeholder_tokens = [args.placeholder_token]
if args.num_vectors < 1:
raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
# add dummy tokens for multi-vector
additional_tokens = []
for i in range(1, args.num_vectors):
additional_tokens.append(f"{args.placeholder_token}_{i}")
placeholder_tokens += additional_tokens
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != args.num_vectors:
raise ValueError(
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
......@@ -595,14 +646,16 @@ def main():
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder tokens with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
with torch.no_grad():
for token_id in placeholder_token_ids:
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# Freeze vae and unet
vae.requires_grad_(False)
......@@ -810,7 +863,9 @@ def main():
optimizer.zero_grad()
# Let's make sure we don't update any embedding weights besides the newly added tokens
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
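# Build a mask that is False only over the contiguous block of newly added placeholder rows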
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
......@@ -818,11 +873,12 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
images = []
progress_bar.update(1)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
......@@ -831,7 +887,9 @@ def main():
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
images = log_validation(
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
......@@ -858,9 +916,15 @@ def main():
pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
......