Unverified commit 98ba18ba, authored by HelloWorldBeginner, committed by GitHub

Add Ascend NPU support for SDXL. (#7916)


Co-authored-by: mhh001 <mahonghao1@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5bb38586
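
For context, the change surfaces as a single new training flag. A hypothetical launch of the patched SDXL script is sketched below; the script name, dataset, and checkpoint id are illustrative assumptions, not taken from this commit (only --enable_npu_flash_attention comes from the diff):

# Hypothetical invocation of the patched script via accelerate; the script
# path and all argument values besides the new flag are illustrative.
import subprocess

subprocess.run(
    [
        "accelerate", "launch", "train_text_to_image_sdxl.py",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-xl-base-1.0",
        "--dataset_name", "lambdalabs/naruto-blip-captions",
        "--enable_npu_flash_attention",  # new flag added by this commit
    ],
    check=True,
)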
@@ -50,7 +50,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -58,7 +58,8 @@ from diffusers.utils.torch_utils import is_compiled_module
 check_min_version("0.28.0.dev0")
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
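
The new NPU branch mirrors the existing xformers guard directly below it: probe for the backend, enable the fused attention path if present, and fail loudly otherwise. A condensed, standalone sketch of that same pattern follows, assuming diffusers >= 0.28.0.dev0 with the torch_npu extension installed; the checkpoint id is illustrative:

# Standalone sketch of the guard this diff wires into the training script.
# Assumes torch_npu is installed and the host exposes an Ascend NPU; the
# checkpoint id is illustrative.
import torch
from diffusers import UNet2DConditionModel
from diffusers.utils.import_utils import is_torch_npu_available

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

if is_torch_npu_available():
    # Keep tensors in standard memory formats rather than NPU-private ones,
    # matching the module-level config the diff sets.
    torch.npu.config.allow_internal_format = False
    # Route attention through the fused Ascend flash-attention kernels.
    unet.enable_npu_flash_attention()
else:
    raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")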