Unverified Commit 01ab39b6 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Load state in else (#25318)

* Load else

* New approach

* Propagate
parent 36d5b8b0
...@@ -454,14 +454,18 @@ def main(): ...@@ -454,14 +454,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -664,14 +664,18 @@ def main(): ...@@ -664,14 +664,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -575,14 +575,18 @@ def main(): ...@@ -575,14 +575,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -613,14 +613,18 @@ def main(): ...@@ -613,14 +613,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -560,14 +560,18 @@ def main(): ...@@ -560,14 +560,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -797,14 +797,18 @@ def main(): ...@@ -797,14 +797,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -828,14 +828,18 @@ def main(): ...@@ -828,14 +828,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -558,14 +558,18 @@ def main(): ...@@ -558,14 +558,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -629,14 +629,18 @@ def main(): ...@@ -629,14 +629,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -507,14 +507,18 @@ def main(): ...@@ -507,14 +507,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -667,14 +667,18 @@ def main(): ...@@ -667,14 +667,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
...@@ -610,14 +610,18 @@ def main(): ...@@ -610,14 +610,18 @@ def main():
# Potentially load in the weights and states from a previous save # Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint: if args.resume_from_checkpoint:
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") checkpoint_path = args.resume_from_checkpoint
accelerator.load_state(args.resume_from_checkpoint)
path = os.path.basename(args.resume_from_checkpoint) path = os.path.basename(args.resume_from_checkpoint)
else: else:
# Get the most recent checkpoint # Get the most recent checkpoint
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
dirs.sort(key=os.path.getctime) dirs.sort(key=os.path.getctime)
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
checkpoint_path = path
path = os.path.basename(checkpoint_path)
accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")
accelerator.load_state(path)
# Extract `epoch_{i}` or `step_{i}` # Extract `epoch_{i}` or `step_{i}`
training_difference = os.path.splitext(path)[0] training_difference = os.path.splitext(path)[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment