Unverified Commit 0fd7ee79 authored by Leo Jiang, committed by GitHub

NPU attention refactor for FLUX (#12209)



* NPU attention refactor for FLUX transformer

* Apply style fixes

---------
Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 0d1c5b0c
@@ -642,6 +642,7 @@ def parse_args(input_args=None):
],
help="The image interpolation method to use for resizing images.",
)
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1182,6 +1183,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
    if args.enable_npu_flash_attention:
        if is_torch_npu_available():
            logger.info("npu flash attention enabled.")
            transformer.set_attention_backend("_native_npu")
        else:
            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
......
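The guard-and-switch pattern added above can also be reused outside the training scripts. A minimal sketch, assuming a diffusers build whose Flux transformer exposes `set_attention_backend` (as this commit relies on); `black-forest-labs/FLUX.1-dev` is only a placeholder checkpoint, and the `.to("npu")` call assumes `torch_npu` is installed and has registered the `npu` device:

```python
import torch

from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

# Placeholder checkpoint; substitute whichever FLUX checkpoint you actually use.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

if is_torch_npu_available():
    # Same opt-in as --enable_npu_flash_attention: route attention through the
    # torch_npu fused kernels instead of the default attention backend.
    transformer.set_attention_backend("_native_npu")
    transformer.to("npu")
```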
@@ -80,6 +80,7 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -686,6 +687,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1213,6 +1215,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
    if args.enable_npu_flash_attention:
        if is_torch_npu_available():
            logger.info("npu flash attention enabled.")
            transformer.set_attention_backend("_native_npu")
        else:
            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
......
@@ -706,6 +706,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1354,6 +1355,13 @@ def main(args):
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
    if args.enable_npu_flash_attention:
        if is_torch_npu_available():
            logger.info("npu flash attention enabled.")
            transformer.set_attention_backend("_native_npu")
        else:
            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
......
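For reference, the new option is a plain `store_true` switch: it defaults to `False` and flips to `True` only when the flag is passed. A self-contained sketch with a stand-in parser (not the scripts' full argument set) showing how it parses:

```python
import argparse

# Stand-in parser mirroring just the two flags shown in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")

# Omitting the flag leaves it False; passing it (with no value) sets it to True.
assert parser.parse_args([]).enable_npu_flash_attention is False
assert parser.parse_args(["--enable_npu_flash_attention"]).enable_npu_flash_attention is True
```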
@@ -22,8 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -354,25 +353,13 @@ class FluxSingleTransformerBlock(nn.Module):
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
from ..attention_processor import FluxAttnProcessor2_0_NPU
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
)
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor()
self.attn = FluxAttention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
processor=FluxAttnProcessor(),
eps=1e-6,
pre_only=True,
)
......
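Net effect of the transformer change: `FluxSingleTransformerBlock` no longer defaults to `FluxAttnProcessor2_0_NPU` on NPU devices (the deprecated branch above is removed) and always constructs `FluxAttention` with the plain `FluxAttnProcessor()`, so NPU flash attention becomes an explicit opt-in. A hedged migration sketch for code that relied on the old implicit default; `black-forest-labs/FLUX.1-dev` is a placeholder checkpoint:

```python
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"  # placeholder checkpoint
)

# New opt-in used by this commit: switch the whole model to the NPU attention backend.
transformer.set_attention_backend("_native_npu")

# The removed deprecation message pointed at setting the processor explicitly instead;
# that path assumes FluxAttnProcessor2_0_NPU is still available in your diffusers version.
# from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU
# transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())
```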