Unverified Commit 1a3079a1 authored by 길재은's avatar 길재은 Committed by GitHub
Browse files

chore: support pytorch format in lora (#22790)


Signed-off-by: default avatarjaeeun.kil <rha3122@naver.com>
Signed-off-by: default avatar길재은 <rha3122@naver.com>
parent 941f5685
......@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir,
......@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
......@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path,
lora_file_path = (lora_bin_file_path
if os.path.isfile(lora_bin_file_path) else
lora_pt_file_path)
tensors = torch.load(lora_file_path,
map_location=device,
weights_only=True)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment