Unverified Commit 1a3079a1 authored by 길재은's avatar 길재은 Committed by GitHub
Browse files

chore: support pytorch format in lora (#22790)


Signed-off-by: default avatarjaeeun.kil <rha3122@naver.com>
Signed-off-by: default avatar길재은 <rha3122@naver.com>
parent 941f5685
...@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel): ...@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
""" """
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join( new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors") lora_dir, "new_embeddings.safetensors")
new_embeddings_bin_file_path = os.path.join(lora_dir, new_embeddings_bin_file_path = os.path.join(lora_dir,
...@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel): ...@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules(f) check_unexpected_modules(f)
for module in f.keys(): # noqa for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module) tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path): elif os.path.isfile(lora_bin_file_path) or os.path.isfile(
# When a bin file is provided, we rely on config to find unexpected lora_pt_file_path):
# modules. # When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = [] unexpected_modules = []
target_modules = peft_helper.target_modules target_modules = peft_helper.target_modules
if not isinstance(target_modules, list): if not isinstance(target_modules, list):
...@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel): ...@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}." f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct") f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, lora_file_path = (lora_bin_file_path
if os.path.isfile(lora_bin_file_path) else
lora_pt_file_path)
tensors = torch.load(lora_file_path,
map_location=device, map_location=device,
weights_only=True) weights_only=True)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment