Unverified Commit 5b933382 authored by Andrew Ishutin's avatar Andrew Ishutin Committed by GitHub
Browse files

fix custom diffusion training with concept list (#6710)


Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 7c1c705f
......@@ -753,7 +753,7 @@ def main(args):
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment