Unverified Commit a23ad87d authored by Duong A. Nguyen, committed by GitHub

[Flax] Add Textual Inversion (#880)



* add textual inversion flax

* make style

* make style

* replicate vae and unet params

* make style

* minor

* save after end of training

* style

* Temporary fix
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Add Flax instruction
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent d3d22ce5
@@ -68,6 +68,24 @@ accelerate launch textual_inversion.py \
A full training run takes ~1 hour on one V100 GPU.
If you want to speed it up even more, a Flax implementation is available:
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="path-to-dir-containing-images"
python textual_inversion_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="<cat-toy>" --initializer_token="toy" \
--resolution=512 \
--train_batch_size=1 \
--max_train_steps=3000 \
--learning_rate=5.0e-04 --scale_lr \
--output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.
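
Before launching a long run, it can help to confirm that `DATA_DIR` actually contains training images. The following is a minimal pre-flight sketch, not part of `textual_inversion_flax.py`: the helper name `count_training_images` and the list of extensions are assumptions made for illustration, and it assumes the images sit directly inside the directory.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not provided by the training script): count image
# files located directly inside the directory passed as the first argument.
# The extension list is an assumption; adjust it to match your data.
count_training_images() {
  find "$1" -maxdepth 1 -type f \
    \( -iname '*.png' -o -iname '*.jpg' -o -iname '*.jpeg' \) | wc -l | tr -d ' '
}

# Fall back to the placeholder path used above when DATA_DIR is not set.
DATA_DIR="${DATA_DIR:-path-to-dir-containing-images}"

# Warn (without aborting) if the directory is missing or holds no images.
if [ ! -d "$DATA_DIR" ] || [ "$(count_training_images "$DATA_DIR")" -eq 0 ]; then
  echo "No training images found in: $DATA_DIR" >&2
fi
```

Running this in the same shell session as the `export` lines above prints a warning to stderr when the directory is missing or empty, and prints nothing otherwise.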
### Inference