Unverified Commit 6fedb29f authored by Suraj Patil, committed by GitHub

[examples] add dataloader_num_workers argument (#2070)

add --dataloader_num_workers argument
parent d75ad93c
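
The new flag defaults to 0, i.e. data is loaded in the main process; the first two scripts below previously hard-coded num_workers=1, the others did not set it at all. A hypothetical invocation showing the flag in use (the script name and the other argument values are placeholders for illustration, only --dataloader_num_workers is introduced by this commit):

python train_dreambooth.py \
    --pretrained_model_name_or_path <model-name-or-path> \
    --train_batch_size 4 \
    --dataloader_num_workers 4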
@@ -240,6 +240,14 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -652,7 +660,7 @@ def main(args):
         batch_size=args.train_batch_size,
         shuffle=True,
         collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=1,
+        num_workers=args.dataloader_num_workers,
     )

     # Scheduler and math around the number of training steps.
...
@@ -285,6 +285,14 @@ def parse_args(input_args=None):
         help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
     )
     parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -746,7 +754,7 @@ def main(args):
         batch_size=args.train_batch_size,
         shuffle=True,
         collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=1,
+        num_workers=args.dataloader_num_workers,
     )

     # Scheduler and math around the number of training steps.
...
@@ -209,6 +209,14 @@ def parse_args():
             " remote repository specified with --pretrained_model_name_or_path."
         ),
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -515,7 +523,11 @@ def main():
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=args.train_batch_size,
+        num_workers=args.dataloader_num_workers,
     )

     # Scheduler and math around the number of training steps.
...
@@ -245,6 +245,14 @@ def parse_args():
             " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
         ),
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -583,7 +591,11 @@ def main():
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
+        train_dataset,
+        shuffle=True,
+        collate_fn=collate_fn,
+        batch_size=args.train_batch_size,
+        num_workers=args.dataloader_num_workers,
     )

     # Scheduler and math around the number of training steps.
...
@@ -194,6 +194,14 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -566,7 +574,9 @@ def main():
         center_crop=args.center_crop,
         set="train",
     )
-    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
...
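
The pattern is the same in every script touched by this commit: expose --dataloader_num_workers in parse_args() and pass it straight through to torch.utils.data.DataLoader. Below is a minimal, self-contained sketch of that wiring, assuming a placeholder dataset and batch-size default rather than the example scripts' real data pipeline:

# Minimal sketch of the --dataloader_num_workers wiring; everything except the
# argument definition itself is a placeholder for illustration.
import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=4)  # placeholder default
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    return parser.parse_args()


def main():
    args = parse_args()
    # Stand-in dataset; the real scripts build an image/caption dataset here.
    train_dataset = TensorDataset(torch.randn(16, 3, 64, 64))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,  # 0 keeps loading in the main process
    )
    for (batch,) in train_dataloader:
        pass  # the training step would consume `batch` here


if __name__ == "__main__":
    main()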