Commit 089fa091 authored by GoatWu

converter support lora

parent 11fcc3fb
@@ -96,7 +96,7 @@ class WanLoraWrapper:
                name_diff = lora_diffs[name]
                lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
-               param += lora_diff
+               param += lora_diff * alpha
                applied_count += 1
        logger.info(f"Applied {applied_count} LoRA weight adjustments")
...
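Editorial note (not part of the commit): the one-line change above scales a precomputed "diff"-style LoRA delta by the new alpha factor before it is added to the base weight. A minimal sketch with toy tensors:

```python
# Minimal sketch (toy tensors, not from the repository): effect of the new
# alpha scaling on a "diff"-style LoRA entry.
import torch

param = torch.zeros(4, 4)      # stands in for a base model weight
lora_diff = torch.ones(4, 4)   # stands in for a precomputed "diff" tensor from the LoRA file
alpha = 0.5

param += lora_diff * alpha     # previously: param += lora_diff (alpha implicitly 1.0)
print(param[0, 0].item())      # 0.5 -- the adjustment is now scaled by alpha
```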
@@ -398,6 +398,53 @@ def quantize_model(
    return weights
def load_loras(lora_path, weight_dict, alpha):
    # Assumes torch, safe_open (safetensors) and logger are imported at module level in converter.py.
    logger.info(f"Loading LoRA from: {lora_path}")
    with safe_open(lora_path, framework="pt") as f:
        lora_weights = {k: f.get_tensor(k) for k in f.keys()}

    lora_pairs = {}
    lora_diffs = {}
    prefix = "diffusion_model."

    def try_lora_pair(key, suffix_a, suffix_b, target_suffix):
        # Collect low-rank A/B (down/up) pairs keyed by the base weight name.
        if key.endswith(suffix_a):
            base_name = key[len(prefix) :].replace(suffix_a, target_suffix)
            pair_key = key.replace(suffix_a, suffix_b)
            if pair_key in lora_weights:
                lora_pairs[base_name] = (key, pair_key)

    def try_lora_diff(key, suffix, target_suffix):
        # Collect precomputed weight/bias deltas keyed by the base parameter name.
        if key.endswith(suffix):
            base_name = key[len(prefix) :].replace(suffix, target_suffix)
            lora_diffs[base_name] = key

    for key in lora_weights.keys():
        if not key.startswith(prefix):
            continue
        try_lora_pair(key, "lora_A.weight", "lora_B.weight", "weight")
        try_lora_pair(key, "lora_down.weight", "lora_up.weight", "weight")
        try_lora_diff(key, "diff", "weight")
        try_lora_diff(key, "diff_b", "bias")

    applied_count = 0
    for name, param in weight_dict.items():
        if name in lora_pairs:
            name_lora_A, name_lora_B = lora_pairs[name]
            lora_A = lora_weights[name_lora_A].to(param.device, param.dtype)
            lora_B = lora_weights[name_lora_B].to(param.device, param.dtype)
            param += torch.matmul(lora_B, lora_A) * alpha
            applied_count += 1
        elif name in lora_diffs:
            name_diff = lora_diffs[name]
            lora_diff = lora_weights[name_diff].to(param.device, param.dtype)
            param += lora_diff * alpha
            applied_count += 1

    logger.info(f"Applied {applied_count} LoRA weight adjustments")
def convert_weights(args):
    if os.path.isdir(args.source):
        src_files = glob.glob(os.path.join(args.source, "*.safetensors"), recursive=True)
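Side note (not part of the diff): the pair branch in `load_loras` performs the standard LoRA fold-in `W' = W + alpha * (B @ A)` after stripping the `diffusion_model.` prefix so the remaining name matches a key in `weight_dict`. A minimal sketch with assumed shapes and a hypothetical base key:

```python
# Minimal sketch (assumed shapes, hypothetical key names, not from the repository):
# folding a low-rank A/B LoRA pair into a base weight, mirroring the matmul above.
import torch

out_dim, in_dim, rank = 8, 8, 2
weight = torch.randn(out_dim, in_dim)   # e.g. base key "blocks.0.self_attn.q.weight" (hypothetical)
lora_A = torch.randn(rank, in_dim)      # "...lora_A.weight" or "...lora_down.weight"
lora_B = torch.randn(out_dim, rank)     # "...lora_B.weight" or "...lora_up.weight"
alpha = 1.0

weight += torch.matmul(lora_B, lora_A) * alpha   # W' = W + alpha * (B @ A)
print(weight.shape)                              # torch.Size([8, 8])
```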
@@ -423,6 +470,16 @@ def convert_weights(args):
            raise ValueError(f"Duplicate keys found: {duplicate_keys} in file {file_path}")
        merged_weights.update(weights)
    if args.lora_path is not None:
        # Handle alpha list - if single alpha, replicate for all LoRAs
        if len(args.lora_alpha) == 1 and len(args.lora_path) > 1:
            args.lora_alpha = args.lora_alpha * len(args.lora_path)
        elif len(args.lora_alpha) != len(args.lora_path):
            raise ValueError(f"Number of lora_alpha ({len(args.lora_alpha)}) must match number of lora_path ({len(args.lora_path)}) or be 1")
        for path, alpha in zip(args.lora_path, args.lora_alpha):
            load_loras(path, merged_weights, alpha)
    if args.direction is not None:
        rules = get_key_mapping_rules(args.direction, args.model_type)
        converted_weights = {}
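For illustration (hypothetical file names, not from the repository), the alpha handling above broadcasts a single `--lora_alpha` value across every `--lora_path` entry:

```python
# Hedged illustration of the broadcast rule above (hypothetical file names).
lora_path = ["style_lora.safetensors", "detail_lora.safetensors"]
lora_alpha = [0.8]

if len(lora_alpha) == 1 and len(lora_path) > 1:
    lora_alpha = lora_alpha * len(lora_path)   # -> [0.8, 0.8]

print(list(zip(lora_path, lora_alpha)))
# [('style_lora.safetensors', 0.8), ('detail_lora.safetensors', 0.8)]
```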
@@ -584,6 +641,14 @@ def main():
        choices=["torch.int8", "torch.float8_e4m3fn"],
        help="Data type for quantization",
    )
parser.add_argument("--lora_path", type=str, nargs="*", help="Path(s) to LoRA file(s). Can specify multiple paths separated by spaces.")
parser.add_argument(
"--lora_alpha",
type=float,
nargs="*",
default=[1.0],
help="Alpha for LoRA weight scaling",
)
    args = parser.parse_args()
    if args.quantized:
...
@@ -50,6 +50,21 @@ python converter.py \
    --model_type wan_dit
```

### Wan DiT + LoRA

```bash
python converter.py \
    --quantized \
    --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
    --output /Path/To/output \
    --output_ext .safetensors \
    --output_name wan_int8 \
    --dtype torch.int8 \
    --model_type wan_dit \
    --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
    --lora_alpha 1.0 1.0
```
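(Note: when only one `--lora_alpha` value is supplied for several `--lora_path` entries, the converter reuses it for every LoRA; otherwise the number of alpha values must equal the number of paths.)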
### Hunyuan DIT
```bash
...
@@ -50,6 +50,21 @@ python converter.py \
    --model_type wan_dit
```

### Wan DiT + LoRA

```bash
python converter.py \
    --quantized \
    --source /Path/To/Wan-AI/Wan2.1-T2V-14B/ \
    --output /Path/To/output \
    --output_ext .safetensors \
    --output_name wan_int8 \
    --dtype torch.int8 \
    --model_type wan_dit \
    --lora_path /Path/To/LoRA1/ /Path/To/LoRA2/ \
    --lora_alpha 1.0 1.0
```
### Hunyuan DIT
```bash
...