Unverified commit 07860f99 authored by Leo Jiang, committed by GitHub

NPU Adaptation for Sana (#10409)



* NPU Adaptation for Sana


---------
Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 87252d80
@@ -63,6 +63,7 @@ from diffusers.utils import (
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -74,6 +75,9 @@ check_min_version("0.33.0.dev0")
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
+
 
 def save_model_card(
     repo_id: str,
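For context, this guard is a no-op on machines without the torch_npu extension. A hedged standalone sketch of the same setup (the `import torch_npu` line and the layout comment are assumptions, not part of the diff):

```python
import torch

from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401  -- registers the torch.npu namespace

    # Keeping tensors in the standard ND layout (no private internal formats)
    # is assumed here to avoid implicit format conversions on Ascend NPUs.
    torch.npu.config.allow_internal_format = False
```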
@@ -601,6 +605,7 @@ def parse_args(input_args=None):
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable VAE tiling in log validation")
+    parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
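The new flag is a plain `store_true` switch next to the existing VAE-tiling one. A minimal, self-contained sketch (hypothetical argument subset) of how it behaves:

```python
import argparse

# Both flags default to False and flip to True only when passed explicitly.
parser = argparse.ArgumentParser()
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable VAE tiling in log validation")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enable Flash Attention for NPU")

args = parser.parse_args(["--enable_npu_flash_attention"])
assert args.enable_npu_flash_attention and not args.enable_vae_tiling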
@@ -924,8 +929,7 @@ def main(args):
             image.save(image_filename)
 
         del pipeline
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        free_memory()
 
     # Handle the repository creation
     if accelerator.is_main_process:
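The CUDA-only cache flush is replaced by `free_memory()`, which keeps the validation cleanup backend-agnostic. A hedged sketch of what such a helper amounts to (the actual diffusers implementation may differ in detail):

```python
import gc

import torch

def free_memory_sketch():
    # Release Python-level references first, then ask the active backend
    # to return cached device memory.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "npu") and torch.npu.is_available():
        torch.npu.empty_cache()
```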
@@ -988,6 +992,13 @@ def main(args):
     # because Gemma2 is particularly suited for bfloat16.
     text_encoder.to(dtype=torch.bfloat16)
 
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            transformer.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
+
     # Initialize a text encoding pipeline and keep it to CPU for now.
     text_encoding_pipeline = SanaPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
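`enable_npu_flash_attention()` is only attempted when torch_npu is actually available; otherwise the script fails fast with a ValueError. Conceptually (a hedged sketch, not the library's actual code), enabling it routes every attention layer through the NPU fused-attention processor:

```python
from diffusers.models.attention_processor import Attention, AttnProcessorNPU

def enable_npu_flash_attention_sketch(model):
    # Swap the processor on every Attention module; AttnProcessorNPU itself
    # requires torch_npu at construction time, so this only works on NPU hosts.
    for module in model.modules():
        if isinstance(module, Attention):
            module.set_processor(AttnProcessorNPU())
```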
...
@@ -3154,6 +3154,11 @@ class AttnProcessorNPU:
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
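The added lines broadcast the mask over the query dimension and flip its boolean convention, on the assumption that the NPU fused-attention kernel treats True as "mask this position out", the inverse of `scaled_dot_product_attention`'s boolean mask. A standalone illustration under that assumption:

```python
import torch

batch_size, heads, query_len, key_len = 2, 4, 8, 8
keep = torch.tensor([True] * 6 + [False] * 2)                        # True = attend (SDPA style)
mask = keep.view(1, 1, 1, key_len).expand(batch_size, heads, 1, key_len)
mask = mask.repeat(1, 1, query_len, 1)                               # repeat over query positions
npu_mask = torch.logical_not(mask)                                   # True now means "mask out"
print(npu_mask.shape)  # torch.Size([2, 4, 8, 8])
```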
...