"src/vscode:/vscode.git/clone" did not exist on "91fd181245d7b287c735f2f479e4c498d5458462"
Unverified commit 9a2600ed authored by Oleh, committed by GitHub

Map speedup (#6745)

* Speed up dataset mapping

* Fix missing columns

* Remove cache files cleanup

* Update examples/text_to_image/train_text_to_image_sdxl.py

* make style

* Fix code style

* style

* Empty-Commit

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
parent 5f150c4c
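The restructuring is visible in the main hunk below: instead of chaining the second .map() on top of the first map's output, both precomputation passes now run over the original dataset, and the results are joined column-wise with concatenate_datasets(..., axis=1), so the VAE pass no longer reads and rewrites the freshly computed embedding columns through its own cache file. A minimal sketch of the pattern, with hypothetical stand-ins (add_text_embeddings, add_vae_latents) for the script's compute_embeddings_fn and compute_vae_encodings_fn:

```python
from datasets import Dataset, concatenate_datasets

base = Dataset.from_dict({"image": ["a.png", "b.png"], "text": ["a cat", "a dog"]})

# Hypothetical stand-ins for the script's compute_embeddings_fn and
# compute_vae_encodings_fn; each returns only its new column.
def add_text_embeddings(batch):
    return {"prompt_embeds": [[0.1, 0.2] for _ in batch["text"]]}

def add_vae_latents(batch):
    return {"model_input": [[0.3, 0.4] for _ in batch["image"]]}

# Both passes map the *original* dataset, so neither has to copy the
# other's output columns through its cache file.
with_embeddings = base.map(add_text_embeddings, batched=True)
with_latents = base.map(add_vae_latents, batched=True)

# Column-wise join of the two results. The shared base columns are dropped
# from one side because an axis=1 concatenation cannot contain duplicates.
precomputed = concatenate_datasets(
    [with_embeddings, with_latents.remove_columns(["image", "text"])], axis=1
)
print(precomputed.column_names)
# ['image', 'text', 'prompt_embeds', 'model_input']
```

The remove_columns(["image", "text"]) call in the diff serves the same purpose: it keeps the shared base columns on only one side of the join, since duplicate column names would make the concatenation fail.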
@@ -35,7 +35,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
@@ -896,13 +896,19 @@ def main(args):
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
         new_fingerprint_for_vae = Hasher.hash(vae_path)
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-        train_dataset = train_dataset.map(
+        train_dataset_with_embeddings = train_dataset.map(
+            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        )
+        train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
             batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
             new_fingerprint=new_fingerprint_for_vae,
         )
+        precomputed_dataset = concatenate_datasets(
+            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+        )
+        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)

     del text_encoders, tokenizers, vae
     gc.collect()
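Note the new with_transform(preprocess_train) line: unlike .map(), which eagerly computes and caches its output, with_transform registers a transform that runs on each batch at access time, so the image preprocessing stays dynamic while the expensive encodings are precomputed. A small illustration of the on-access behaviour, with a hypothetical upper_case transform standing in for the script's preprocess_train:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a cat", "a dog"]})

# Hypothetical stand-in for the script's preprocess_train; with_transform
# functions receive a batch dict, even when a single row is accessed.
def upper_case(batch):
    return {"text": [t.upper() for t in batch["text"]]}

lazy = ds.with_transform(upper_case)
print(ds[0]["text"])    # 'a cat'  -- underlying data is unchanged
print(lazy[0]["text"])  # 'A CAT'  -- transform applied on access
```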
@@ -925,7 +931,7 @@ def main(args):

     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
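The hunk above simply swaps precomputed_dataset in as the DataLoader's dataset; the unchanged collate_fn is what assembles the precomputed columns into batch tensors. A toy sketch of that wiring, using a hypothetical minimal collate function in place of the script's collate_fn:

```python
import torch
from datasets import Dataset

# Stand-in for the precomputed dataset: one precomputed column per row.
ds = Dataset.from_dict({"model_input": [[1.0], [2.0], [3.0], [4.0]]})

# Hypothetical minimal collate: the DataLoader hands it a list of example
# dicts, and it stacks each precomputed column into a tensor.
def collate(examples):
    return {"model_input": torch.tensor([e["model_input"] for e in examples])}

loader = torch.utils.data.DataLoader(ds, shuffle=True, collate_fn=collate, batch_size=2)
for batch in loader:
    print(batch["model_input"].shape)  # torch.Size([2, 1])
```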
@@ -976,7 +982,7 @@ def main(args):

     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {len(precomputed_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")