"benchmark/vscode:/vscode.git/clone" did not exist on "8cdc76f6d4cd61ced1d84a44c243b8a89e0a1f74"
Unverified Commit 8df28bb3 authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Push sharded checkpoint to hub when `push_to_hub=True` in `TrainingArguments` (#31808)

Save sharded checkpoint in Trainer
parent da79b180
...@@ -22,6 +22,7 @@ import functools ...@@ -22,6 +22,7 @@ import functools
import glob import glob
import importlib.metadata import importlib.metadata
import inspect import inspect
import json
import math import math
import os import os
import random import random
...@@ -4215,6 +4216,15 @@ class Trainer: ...@@ -4215,6 +4216,15 @@ class Trainer:
output_dir = self.args.output_dir output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
# Add sharded checkpoints if we have an index
for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
index_path = os.path.join(checkpoint_folder, index_file)
if os.path.isfile(index_path):
modeling_files.append(index_file)
with open(index_path) as f:
index = json.loads(f.read())
shard_files = list(set(index["weight_map"].values()))
modeling_files.extend(shard_files)
if is_peft_available(): if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files: for modeling_file in modeling_files:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment