Unverified Commit 98ba18ba authored by HelloWorldBeginner, committed by GitHub

Add Ascend NPU support for SDXL. (#7916)


Co-authored-by: mhh001 <mahonghao1@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 5bb38586
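
In short, the patch teaches the SDXL text-to-image training script to run on Huawei Ascend NPUs: it imports is_torch_npu_available, gates a torch.npu configuration setting behind it, and adds an opt-in --enable_npu_flash_attention flag. The following is a minimal inference-side sketch of the same calls, assuming a machine with the torch_npu extensions installed; the checkpoint id and prompt are illustrative and not part of this commit.

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    # Same module-level setting the script adds below: forbid Ascend's private
    # "internal" tensor formats so memory layouts stay predictable across ops.
    torch.npu.config.allow_internal_format = False

# Illustrative SDXL checkpoint; any SDXL-architecture model behaves the same.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
if is_torch_npu_available():
    pipe.unet.enable_npu_flash_attention()  # the same model-level call the script uses
    pipe = pipe.to("npu")
image = pipe("a photo of an astronaut riding a horse").images[0]
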
@@ -50,7 +50,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -58,7 +58,8 @@ from diffusers.utils.torch_utils import is_compiled_module
 check_min_version("0.28.0.dev0")
 logger = get_logger(__name__)
+if is_torch_npu_available():
+    torch.npu.config.allow_internal_format = False
 DATASET_NAME_MAPPING = {
     "lambdalabs/naruto-blip-captions": ("image", "text"),
@@ -460,6 +461,9 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
+    )
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
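
For readers skimming the diff, the two attention switches are plain store_true flags. A tiny standalone parser with just these options; the flag names match the script, while the driver lines at the end are illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_npu_flash_attention", action="store_true", help="Whether or not to use npu flash attention."
)
parser.add_argument(
    "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)

# store_true flags default to False and flip to True only when passed.
args = parser.parse_args(["--enable_npu_flash_attention"])
assert args.enable_npu_flash_attention and not args.enable_xformers_memory_efficient_attention
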
@@ -716,7 +720,12 @@ def main(args):
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
         )
         ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+    if args.enable_npu_flash_attention:
+        if is_torch_npu_available():
+            logger.info("npu flash attention enabled.")
+            unet.enable_npu_flash_attention()
+        else:
+            raise ValueError("npu flash attention requires torch_npu extensions and is supported only on npu devices.")
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
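
Taken together, the last hunk is backend selection with a hard failure when the requested backend is missing. A sketch of that control flow factored into a hypothetical helper; enable_memory_efficient_attention and model are illustrative names, while the availability checks and model-level calls follow the script.

from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available

def enable_memory_efficient_attention(model, use_npu_fa: bool, use_xformers: bool) -> None:
    # Hypothetical helper mirroring the script's control flow; `model` stands
    # in for the script's `unet`.
    if use_npu_fa:
        if not is_torch_npu_available():
            raise ValueError(
                "npu flash attention requires torch_npu extensions and is supported only on npu devices."
            )
        model.enable_npu_flash_attention()
    if use_xformers:
        if not is_xformers_available():
            raise ValueError("xformers is not available. Make sure it is installed correctly.")
        model.enable_xformers_memory_efficient_attention()

Note that the two checks are kept independent, as in the script; if both flags were set, the later xformers call would replace the NPU attention processors.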