Unverified Commit 8874027e authored by Katsuya, committed by GitHub

Make xformers optional even if it is available (#1753)

* Make xformers optional even if it is available

* Raise exception if xformers is used but not available

* Rename use_xformers to enable_xformers_memory_efficient_attention

* Add a note about xformers in README

* Reformat code style
parent b693aff7
@@ -317,4 +317,7 @@ python train_dreambooth_flax.py \
--max_train_steps=800
```
You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
### Training with xformers:
You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
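As a minimal sketch, the flag is simply appended to the PyTorch training command (`train_dreambooth.py`, not the Flax script); the model name, paths, and hyperparameters below are illustrative:

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=400 \
  --enable_xformers_memory_efficient_attention
```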
@@ -248,6 +248,9 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -516,14 +519,11 @@ def main(args):
revision=args.revision,
)
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
vae.requires_grad_(False)
if not args.train_text_encoder:
@@ -160,3 +160,6 @@ python train_text_to_image_flax.py \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
```
### Training with xformers:
You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
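A minimal sketch of the same idea for the PyTorch text-to-image script (dataset name and hyperparameters are illustrative):

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=15000 \
  --output_dir="sd-pokemon-model" \
  --enable_xformers_memory_efficient_attention
```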
@@ -234,6 +234,9 @@ def parse_args():
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -383,14 +386,11 @@ def main():
revision=args.revision,
)
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
# Freeze vae and text_encoder
vae.requires_grad_(False)
@@ -124,3 +124,6 @@ python textual_inversion_flax.py \
--output_dir="textual_inversion_cat"
```
It should be at least 70% faster than the PyTorch script with the same configuration.
### Training with xformers:
You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
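A minimal sketch for the PyTorch textual inversion script (token names, paths, and hyperparameters are illustrative):

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATA_DIR="path-to-dir-containing-images"

accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" \
  --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --output_dir="textual_inversion_cat" \
  --enable_xformers_memory_efficient_attention
```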
@@ -222,6 +222,9 @@ def parse_args():
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -457,14 +460,11 @@ def main():
revision=args.revision,
)
-    if is_xformers_available():
-        try:
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))