"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "946bb53c566ef75a6c9417ca399b7e072d243c04"
Unverified Commit 9a2600ed authored by Oleh's avatar Oleh Committed by GitHub
Browse files

Map speedup (#6745)



* Speed up dataset mapping

* Fix missing columns

* Remove cache files cleanup

* Update examples/text_to_image/train_text_to_image_sdxl.py

* make style

* Fix code style

* style

* Empty-Commit

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarQuentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: default avatarQuentin Lhoest <lhoest.q@gmail.com>
parent 5f150c4c
...@@ -35,7 +35,7 @@ import transformers ...@@ -35,7 +35,7 @@ import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset from datasets import concatenate_datasets, load_dataset
from huggingface_hub import create_repo, upload_folder from huggingface_hub import create_repo, upload_folder
from packaging import version from packaging import version
from torchvision import transforms from torchvision import transforms
...@@ -896,13 +896,19 @@ def main(args): ...@@ -896,13 +896,19 @@ def main(args):
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args) new_fingerprint = Hasher.hash(args)
new_fingerprint_for_vae = Hasher.hash(vae_path) new_fingerprint_for_vae = Hasher.hash(vae_path)
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) train_dataset_with_embeddings = train_dataset.map(
train_dataset = train_dataset.map( compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
)
train_dataset_with_vae = train_dataset.map(
compute_vae_encodings_fn, compute_vae_encodings_fn,
batched=True, batched=True,
batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
new_fingerprint=new_fingerprint_for_vae, new_fingerprint=new_fingerprint_for_vae,
) )
precomputed_dataset = concatenate_datasets(
[train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
)
precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
del text_encoders, tokenizers, vae del text_encoders, tokenizers, vae
gc.collect() gc.collect()
...@@ -925,7 +931,7 @@ def main(args): ...@@ -925,7 +931,7 @@ def main(args):
# DataLoaders creation: # DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader( train_dataloader = torch.utils.data.DataLoader(
train_dataset, precomputed_dataset,
shuffle=True, shuffle=True,
collate_fn=collate_fn, collate_fn=collate_fn,
batch_size=args.train_batch_size, batch_size=args.train_batch_size,
...@@ -976,7 +982,7 @@ def main(args): ...@@ -976,7 +982,7 @@ def main(args):
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num examples = {len(precomputed_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment