Unverified Commit a6314a8d authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add `--dataloader_num_workers` to the DDPM training example (#1027)

parent 939ec17e
...@@ -83,7 +83,16 @@ def parse_args(): ...@@ -83,7 +83,16 @@ def parse_args():
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
) )
parser.add_argument( parser.add_argument(
"--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader." "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
" process."
),
) )
parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.") parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
...@@ -249,7 +258,9 @@ def main(args): ...@@ -249,7 +258,9 @@ def main(args):
return {"input": images} return {"input": images}
dataset.set_transform(transforms) dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
)
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
args.lr_scheduler, args.lr_scheduler,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment