"vscode:/vscode.git/clone" did not exist on "51a21df7288a7e2f78c10778493f9ba554694e81"
Unverified Commit 5956b68a authored by Leo Jiang's avatar Leo Jiang Committed by GitHub
Browse files

Improve the performance and suitable for NPU computing (#9642)



* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU computing

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

* Improve the performance and suitable for NPU

---------
Co-authored-by: default avatar蒋硕 <jiangshuo9@h-partners.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 8d81564b
......@@ -59,6 +59,8 @@ check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
import torch_npu
torch.npu.config.allow_internal_format = False
DATASET_NAME_MAPPING = {
......@@ -540,6 +542,9 @@ def compute_vae_encodings(batch, vae):
with torch.no_grad():
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
# There might have slightly performance improvement
# by changing model_input.cpu() to accelerator.gather(model_input)
return {"model_input": model_input.cpu()}
......@@ -935,7 +940,10 @@ def main(args):
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
del text_encoders, tokenizers, vae
gc.collect()
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
def collate_fn(examples):
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
......@@ -1091,8 +1099,7 @@ def main(args):
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
target_size = (args.resolution, args.resolution)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
return add_time_ids
add_time_ids = torch.cat(
......@@ -1261,7 +1268,10 @@ def main(args):
)
del pipeline
torch.cuda.empty_cache()
if is_torch_npu_available():
torch_npu.npu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
if args.use_ema:
# Switch back to the original UNet parameters.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment