"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2f8acfea1ca11fe3479fb379ccbded516d0cff57"
Unverified commit 5324bf9c, authored by Sourab Mangrulkar and committed by GitHub
Browse files

update `create_model_card` to properly save peft details when using Trainer with PEFT (#27754)



* update `create_model_card` to properly save peft details when using Trainer with PEFT

* nit

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent 52746922
...@@ -48,7 +48,7 @@ import huggingface_hub.utils as hf_hub_utils ...@@ -48,7 +48,7 @@ import huggingface_hub.utils as hf_hub_utils
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from huggingface_hub import create_repo, upload_folder from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version from packaging import version
from torch import nn from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
...@@ -3494,6 +3494,12 @@ class Trainer: ...@@ -3494,6 +3494,12 @@ class Trainer:
if not self.is_world_process_zero(): if not self.is_world_process_zero():
return return
model_card_filepath = os.path.join(self.args.output_dir, "README.md")
is_peft_library = False
if os.path.exists(model_card_filepath):
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"
training_summary = TrainingSummary.from_trainer( training_summary = TrainingSummary.from_trainer(
self, self,
language=language, language=language,
...@@ -3507,9 +3513,12 @@ class Trainer: ...@@ -3507,9 +3513,12 @@ class Trainer:
dataset_args=dataset_args, dataset_args=dataset_args,
) )
model_card = training_summary.to_model_card() model_card = training_summary.to_model_card()
with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: with open(model_card_filepath, "w") as f:
f.write(model_card) f.write(model_card)
if is_peft_library:
unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
def _push_from_checkpoint(self, checkpoint_folder): def _push_from_checkpoint(self, checkpoint_folder):
# Only push from one node. # Only push from one node.
if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment