"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "263fd3c4c72752cd83c474d5b5264dad118fda38"
Unverified Commit 9e287502 authored by Sourab Mangrulkar's avatar Sourab Mangrulkar Committed by GitHub
Browse files

fix peft ckpts not being pushed to hub (#24578)

* fix push to hub for peft ckpts

* oops
parent 232c898f
...@@ -119,6 +119,7 @@ from .trainer_utils import ( ...@@ -119,6 +119,7 @@ from .trainer_utils import (
) )
from .training_args import OptimizerNames, ParallelMode, TrainingArguments from .training_args import OptimizerNames, ParallelMode, TrainingArguments
from .utils import ( from .utils import (
ADAPTER_CONFIG_NAME,
ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME,
CONFIG_NAME, CONFIG_NAME,
...@@ -3533,6 +3534,8 @@ class Trainer: ...@@ -3533,6 +3534,8 @@ class Trainer:
output_dir = self.args.output_dir output_dir = self.args.output_dir
# To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
if is_peft_available():
modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
for modeling_file in modeling_files: for modeling_file in modeling_files:
if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
......
...@@ -178,6 +178,7 @@ from .import_utils import ( ...@@ -178,6 +178,7 @@ from .import_utils import (
WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin" ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_NAME = "tf_model.h5"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment