Unverified Commit 5956b68a authored by Leo Jiang, committed by GitHub

Improve performance and suitability for NPU computing (#9642)



* Improve performance and suitability for NPU

* Improve performance and suitability for NPU computing

* Improve performance and suitability for NPU

* Improve performance and suitability for NPU

* Improve performance and suitability for NPU

* Improve performance and suitability for NPU

---------
Co-authored-by: 蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8d81564b
```diff
@@ -59,6 +59,8 @@ check_min_version("0.31.0.dev0")
 logger = get_logger(__name__)
 
 if is_torch_npu_available():
+    import torch_npu
+
     torch.npu.config.allow_internal_format = False
 
 DATASET_NAME_MAPPING = {
```
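The guarded import keeps the script usable on machines without Ascend hardware. A minimal sketch of how the same guard extends to device selection (the `get_device` helper is hypothetical, not part of this commit; `is_torch_npu_available` is the helper the script already imports):

```python
import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # noqa: F401 -- registers the "npu" device backend with PyTorch


def get_device() -> torch.device:
    # Hypothetical helper: prefer the NPU when torch_npu is present,
    # then fall back to CUDA, then CPU.
    if is_torch_npu_available():
        return torch.device("npu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
```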
```diff
@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
     with torch.no_grad():
         model_input = vae.encode(pixel_values).latent_dist.sample()
     model_input = model_input * vae.config.scaling_factor
+
+    # There might be a slight performance improvement
+    # by changing model_input.cpu() to accelerator.gather(model_input)
     return {"model_input": model_input.cpu()}
```
```diff
@@ -935,7 +940,10 @@ def main(args):
     del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
     del text_encoders, tokenizers, vae
     gc.collect()
-    torch.cuda.empty_cache()
+    if is_torch_npu_available():
+        torch_npu.npu.empty_cache()
+    elif torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     def collate_fn(examples):
         model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
```
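The same backend-aware cleanup now appears twice in the script (here and again after validation). A sketch of factoring it into one helper (the `free_memory` name is hypothetical; the commit inlines the branches instead):

```python
import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu


def free_memory() -> None:
    # Drop dead Python references first so the allocator's cached blocks
    # become reclaimable, then hand them back to the NPU or CUDA runtime.
    gc.collect()
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```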
def collate_fn(examples): def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
```diff
@@ -1091,8 +1099,7 @@ def main(args):
                     # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
                     target_size = (args.resolution, args.resolution)
                     add_time_ids = list(original_size + crops_coords_top_left + target_size)
-                    add_time_ids = torch.tensor([add_time_ids])
-                    add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+                    add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
                     return add_time_ids
 
                 add_time_ids = torch.cat(
```
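The `compute_time_ids` change folds an allocate-then-copy into a single allocation. A minimal illustration of the two patterns (the values are placeholders):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16
add_time_ids = [1024, 1024, 0, 0, 1024, 1024]  # original_size + crop_top_left + target_size

# Before: build the tensor on the CPU, then issue a separate transfer to the
# accelerator with the dtype conversion bundled into .to().
before = torch.tensor([add_time_ids]).to(device, dtype=dtype)

# After: allocate directly on the target device with the final dtype, skipping
# the intermediate CPU tensor and the extra host-to-device copy.
after = torch.tensor([add_time_ids], device=device, dtype=dtype)

assert torch.equal(before, after)
```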
```diff
@@ -1261,7 +1268,10 @@ def main(args):
                 )
 
                 del pipeline
-                torch.cuda.empty_cache()
+                if is_torch_npu_available():
+                    torch_npu.npu.empty_cache()
+                elif torch.cuda.is_available():
+                    torch.cuda.empty_cache()
 
                 if args.use_ema:
                     # Switch back to the original UNet parameters.
```