Unverified Commit c62b3a2e authored by Duong A. Nguyen's avatar Duong A. Nguyen Committed by GitHub
Browse files

[Flax] Fix sample batch size DreamBooth (#1129)

fix sample batch size
parent bde4880c
...@@ -361,7 +361,8 @@ def main(): ...@@ -361,7 +361,8 @@ def main():
logger.info(f"Number of class images to sample: {num_new_images}.") logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images) sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) total_sample_batch_size = args.sample_batch_size * jax.local_device_count()
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=total_sample_batch_size)
for example in tqdm( for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0 sample_dataloader, desc="Generating class images", disable=not jax.process_index() == 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment