Unverified Commit 875afe38 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add blacklist in model checkpoint (#1325)

parent ee8217e5
......@@ -144,8 +144,18 @@ def prepare_hf_model_weights(
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors:
# Exclude files that are not needed for inference.
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
blacklist = [
"training_args.bin",
"optimizer.bin",
"optimizer.pt",
"scheduler.pt",
"scaler.pt",
]
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
]
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment