Unverified Commit 84a9f5d6 authored by Chenxi Li's avatar Chenxi Li Committed by GitHub
Browse files

Feature/make PEFT adapter module format compatibile (#11080)

parent 8ce830a8
......@@ -98,6 +98,7 @@ def get_normalized_target_modules(
) -> set[str]:
"""
Mapping a list of target module name to names of the normalized LoRA weights.
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
"""
params_mapping = {
"q_proj": "qkv_proj",
......@@ -109,7 +110,8 @@ def get_normalized_target_modules(
result = set()
for name in target_modules:
normalized_name = params_mapping.get(name, name)
base_name = name.split(".")[-1]
normalized_name = params_mapping.get(base_name, base_name)
result.add(normalized_name)
return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment