Unverified Commit 8874027e authored by Katsuya, committed by GitHub

Make xformers optional even if it is available (#1753)

* Make xformers optional even if it is available

* Raise exception if xformers is used but not available

* Rename use_xformers to enable_xformers_memory_efficient_attention

* Add a note about xformers in README

* Reformat code style
parent b693aff7
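
All three training scripts get the same change, shown in the diffs below: the xformers code path is now gated behind a new `--enable_xformers_memory_efficient_attention` flag, and a missing installation raises an error instead of logging a warning. A minimal standalone sketch of that pattern follows; the import path and the `runwayml/stable-diffusion-v1-5` model id are illustrative assumptions, not part of this commit.

```python
# Sketch of the opt-in xformers pattern this commit applies to each training script.
import argparse

from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available  # import path assumed

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
args = parser.parse_args()

# Example model id only; the real scripts use args.pretrained_model_name_or_path.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# xformers is now opt-in: it is only enabled when the flag is passed, and a missing
# installation is a hard error instead of a silently ignored warning.
if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")
```
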
@@ -317,4 +317,7 @@ python train_dreambooth_flax.py \
   --max_train_steps=800
 ```

-You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
\ No newline at end of file
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
+
+You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
@@ -248,6 +248,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
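
Because the new argument uses `action="store_true"`, it defaults to `False`, so xformers stays disabled unless the user explicitly opts in. A quick illustration of that argparse behavior (not part of the commit):

```python
# Demonstrates that the store_true flag is False by default and True only when passed.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)

print(parser.parse_args([]).enable_xformers_memory_efficient_attention)
# False
print(parser.parse_args(["--enable_xformers_memory_efficient_attention"]).enable_xformers_memory_efficient_attention)
# True
```
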
@@ -516,14 +519,11 @@ def main(args):
         revision=args.revision,
     )

-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")

     vae.requires_grad_(False)
     if not args.train_text_encoder:
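
The `main()` change replaces the old try/except-and-warn block with a fail-fast guard. A hypothetical, self-contained check of the resulting behavior; the `maybe_enable_xformers` helper and `_FakeUNet` stub are inventions for illustration only:

```python
# Illustrative check of the new behavior: flag off -> nothing happens; flag on with
# xformers installed -> attention is enabled; flag on without xformers -> hard error.
def maybe_enable_xformers(unet, enable_flag, xformers_available):
    if enable_flag:
        if xformers_available:
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")


class _FakeUNet:
    def __init__(self):
        self.enabled = False

    def enable_xformers_memory_efficient_attention(self):
        self.enabled = True


unet = _FakeUNet()
maybe_enable_xformers(unet, enable_flag=False, xformers_available=False)
assert unet.enabled is False  # flag not passed: xformers stays off

maybe_enable_xformers(unet, enable_flag=True, xformers_available=True)
assert unet.enabled is True  # flag passed and xformers installed: enabled

try:
    maybe_enable_xformers(_FakeUNet(), enable_flag=True, xformers_available=False)
except ValueError as err:
    print(f"fails fast as expected: {err}")
```
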
@@ -160,3 +160,6 @@ python train_text_to_image_flax.py \
   --max_grad_norm=1 \
   --output_dir="sd-pokemon-model"
 ```
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
@@ -234,6 +234,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -383,14 +386,11 @@ def main():
         revision=args.revision,
     )

-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")

     # Freeze vae and text_encoder
     vae.requires_grad_(False)
@@ -124,3 +124,6 @@ python textual_inversion_flax.py \
   --output_dir="textual_inversion_cat"
 ```
 It should be at least 70% faster than the PyTorch script with the same configuration.
+
+### Training with xformers:
+You can enable memory efficient attention by [installing xFormers](https://github.com/facebookresearch/xformers#installing-xformers) and passing the `--enable_xformers_memory_efficient_attention` argument to the script. This is not available with the Flax/JAX implementation.
@@ -222,6 +222,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -457,14 +460,11 @@ def main():
         revision=args.revision,
     )

-    if is_xformers_available():
-        try:
-            unet.enable_xformers_memory_efficient_attention()
-        except Exception as e:
-            logger.warning(
-                "Could not enable memory efficient attention. Make sure xformers is installed"
-                f" correctly and a GPU is available: {e}"
-            )
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            unet.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")

     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))