Unverified commit e9442440, authored by Rahul A R and committed by GitHub
Browse files

streamlining 'checkpointing_steps' parsing (#18755)

parent fbf382c8
@@ -406,12 +406,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -508,12 +508,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -552,12 +552,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -505,12 +505,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -764,12 +764,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration
     if args.with_tracking:
......
@@ -779,12 +779,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -475,12 +475,9 @@ def main():
     )
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
......
@@ -451,12 +451,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -566,12 +566,9 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
......
@@ -549,12 +549,9 @@ def main():
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
     # Figure out how many steps we should save the Accelerator states
-    if hasattr(args.checkpointing_steps, "isdigit"):
-        checkpointing_steps = args.checkpointing_steps
-        if args.checkpointing_steps.isdigit():
-            checkpointing_steps = int(args.checkpointing_steps)
-    else:
-        checkpointing_steps = None
+    checkpointing_steps = args.checkpointing_steps
+    if checkpointing_steps is not None and checkpointing_steps.isdigit():
+        checkpointing_steps = int(checkpointing_steps)
 
     # We need to initialize the trackers we use, and also store our configuration.
     # We initialize the trackers only on main process because `accelerator.log`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment