Unverified Commit 31da6925 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

support adapter lora (#539)

parent 227e48a9
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
Model Merge and Multi-Precision Conversion Script Model Merge and Multi-Precision Conversion Script
This script supports three conversion modes: This script supports three conversion modes:
1. 'both' (default): Convert both merged model and audio adapter 1. 'both' (default): Convert both R2V model and audio adapter
2. 'merged': Only convert merged model (R2V + distill via LoRA) 2. 'r2v': Only convert R2V model (R2V + distill via LoRA)
3. 'audio': Only convert audio adapter 3. 'audio': Only convert audio adapter
Pipeline: Pipeline:
- Merged model: R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8 - R2V model: R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8
- Audio adapter: audio_adapter.pt → BF16 → FP8 - Audio adapter: (optional: + LoRA) → audio_adapter.pt → BF16 → FP8
Usage Examples: Usage Examples:
# Convert both (default) # Convert both (default)
...@@ -18,9 +18,9 @@ Usage Examples: ...@@ -18,9 +18,9 @@ Usage Examples:
--audio_adapter /path/to/audio_adapter.pt \ --audio_adapter /path/to/audio_adapter.pt \
--output_dir /data/output --output_dir /data/output
# Only convert merged model # Only convert R2V model
python tools/convert/seko_talk_converter.py \ python tools/convert/seko_talk_converter.py \
--mode merged \ --mode r2v \
--r2v_model /path/to/model.pt \ --r2v_model /path/to/model.pt \
--distill_model /path/to/model_ema.pt \ --distill_model /path/to/model_ema.pt \
--output_dir /data/output --output_dir /data/output
...@@ -31,10 +31,18 @@ Usage Examples: ...@@ -31,10 +31,18 @@ Usage Examples:
--audio_adapter /path/to/audio_adapter.pt \ --audio_adapter /path/to/audio_adapter.pt \
--output_dir /data/output --output_dir /data/output
# Convert audio adapter with LoRA merge
python tools/convert/seko_talk_converter.py \
--mode audio \
--audio_adapter /path/to/audio_adapter.pt \
--audio_lora /path/to/audio_lora.pt \
--output_dir /data/output
Output files (depending on mode): Output files (depending on mode):
- merged.safetensors (FP32, R2V + distill merged) - merged.safetensors (FP32, R2V + distill merged)
- merged_bf16.safetensors (BF16) - merged_bf16.safetensors (BF16)
- merged_fp8.safetensors (FP8) - merged_fp8.safetensors (FP8)
- audio_adapter_merged.safetensors (FP32, audio + lora merged, optional)
- audio_adapter_model.safetensors (BF16) - audio_adapter_model.safetensors (BF16)
- audio_adapter_model_fp8.safetensors (FP8) - audio_adapter_model_fp8.safetensors (FP8)
""" """
...@@ -177,8 +185,8 @@ def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str ...@@ -177,8 +185,8 @@ def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str
str(output_dir), str(output_dir),
"-o_n", "-o_n",
"merged_fp8", "merged_fp8",
"--linear_dtype", "--linear_type",
"torch.float8_e4m3fn", "fp8",
"--quantized", "--quantized",
"--device", "--device",
device, device,
...@@ -191,12 +199,62 @@ def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str ...@@ -191,12 +199,62 @@ def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str
logger.info(f" ✓ Created: {fp8_path}") logger.info(f" ✓ Created: {fp8_path}")
def step_audio_merge_lora(audio_adapter_path: Path, audio_lora_path: Path, output_dir: Path, lora_alpha: float, temp_dir: Path) -> Path:
    """
    Merge an audio adapter with its LoRA weights via converter.py.

    Both checkpoints are re-saved as FP32 safetensors in *temp_dir* so that
    converter.py can consume them, then converter.py performs the LoRA merge
    and writes audio_adapter_merged.safetensors (FP32) into *output_dir*.

    Args:
        audio_adapter_path: Base audio adapter checkpoint (.pt).
        audio_lora_path: Audio LoRA checkpoint (.pt/.safetensors).
        output_dir: Directory that receives the merged safetensors file.
        lora_alpha: Scaling factor passed to converter.py for the merge.
        temp_dir: Scratch directory for the intermediate safetensors files.

    Returns:
        Path to the merged audio_adapter_merged.safetensors file.

    Raises:
        FileNotFoundError: If converter.py did not produce the merged file.
    """
    banner = "=" * 80
    logger.info(banner)
    logger.info("AUDIO STEP 1: Merge Audio Adapter + LoRA (FP32)")
    logger.info(banner)

    temp_dir.mkdir(parents=True, exist_ok=True)

    # Re-save the base adapter as safetensors (converter.py input format).
    logger.info("\n[1.1] Converting audio adapter to safetensors (FP32)...")
    adapter_st = temp_dir / "audio_adapter.safetensors"
    save_file(load_checkpoint(audio_adapter_path), str(adapter_st))
    logger.info(f" Saved: {adapter_st}")

    # Re-save the LoRA checkpoint the same way.
    logger.info("\n[1.2] Converting audio LoRA to safetensors (FP32)...")
    lora_st = temp_dir / "audio_lora.safetensors"
    save_file(load_checkpoint(audio_lora_path), str(lora_st))
    logger.info(f" Saved: {lora_st}")

    # Delegate the actual LoRA merge to the shared converter script.
    logger.info("\n[1.3] Merging via LoRA (converter.py)...")
    merge_cmd = [
        "python", "tools/convert/converter.py",
        "-s", str(adapter_st),
        "-o", str(output_dir),
        "-o_n", "audio_adapter_merged",
        "--lora_path", str(lora_st),
        "--lora_alpha", str(lora_alpha),
        "--single_file",
    ]
    run_command(merge_cmd, "Audio LoRA merge")

    merged_path = output_dir / "audio_adapter_merged.safetensors"
    if not merged_path.exists():
        raise FileNotFoundError(f"Merged audio file not found: {merged_path}")
    logger.info(f" ✓ Created: {merged_path} (FP32)")
    return merged_path
def step4_convert_audio_adapter_to_bf16(audio_adapter_path: Path, output_dir: Path): def step4_convert_audio_adapter_to_bf16(audio_adapter_path: Path, output_dir: Path):
""" """
Step 4: Convert audio adapter to BF16. Step 4: Convert audio adapter to BF16.
""" """
logger.info("=" * 80) logger.info("=" * 80)
logger.info("STEP 4: Convert audio adapter → BF16") logger.info("AUDIO STEP 2: Convert audio adapter → BF16")
logger.info("=" * 80) logger.info("=" * 80)
audio_dict = load_checkpoint(audio_adapter_path) audio_dict = load_checkpoint(audio_adapter_path)
...@@ -212,7 +270,7 @@ def step5_convert_audio_adapter_to_fp8(output_dir: Path): ...@@ -212,7 +270,7 @@ def step5_convert_audio_adapter_to_fp8(output_dir: Path):
Step 5: Convert audio adapter BF16 to FP8 using quant_adapter.py. Step 5: Convert audio adapter BF16 to FP8 using quant_adapter.py.
""" """
logger.info("=" * 80) logger.info("=" * 80)
logger.info("STEP 5: Convert audio adapter → FP8") logger.info("AUDIO STEP 3: Convert audio adapter → FP8")
logger.info("=" * 80) logger.info("=" * 80)
input_path = output_dir / "audio_adapter_model.safetensors" input_path = output_dir / "audio_adapter_model.safetensors"
...@@ -229,12 +287,14 @@ def main(): ...@@ -229,12 +287,14 @@ def main():
parser = argparse.ArgumentParser(description="Merge R2V+distill via LoRA and convert to multiple formats") parser = argparse.ArgumentParser(description="Merge R2V+distill via LoRA and convert to multiple formats")
# Mode selection # Mode selection
parser.add_argument("--mode", type=str, choices=["both", "merged", "audio"], default="both", help="Conversion mode: 'both' (default), 'merged' (only model), or 'audio' (only audio adapter)") parser.add_argument("--mode", type=str, choices=["both", "r2v", "audio"], default="both", help="Conversion mode: 'both' (default), 'r2v' (only R2V model), or 'audio' (only audio adapter)")
# Inputs (conditionally required based on mode) # Inputs (conditionally required based on mode)
parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both' and 'merged' modes]") parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both' and 'r2v' modes]")
parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'merged' modes]") parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'r2v' modes]")
parser.add_argument("--audio_adapter", type=str, help="Path to audio adapter (.pt) [required for 'both' and 'audio' modes]") parser.add_argument("--audio_adapter", type=str, help="Path to audio adapter (.pt) [required for 'both' and 'audio' modes]")
parser.add_argument("--audio_lora", type=str, help="Path to audio LoRA (.pt/.safetensors) [optional, for merging with audio adapter]")
parser.add_argument("--audio_lora_alpha", type=float, default=8.0, help="Alpha for audio LoRA merge (default: 8.0)")
# Outputs # Outputs
parser.add_argument("--output_dir", type=str, required=True, help="Output directory") parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
...@@ -251,9 +311,9 @@ def main(): ...@@ -251,9 +311,9 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Validate required arguments based on mode # Validate required arguments based on mode
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
if not args.r2v_model or not args.distill_model: if not args.r2v_model or not args.distill_model:
parser.error("--r2v_model and --distill_model are required for 'both' and 'merged' modes") parser.error("--r2v_model and --distill_model are required for 'both' and 'r2v' modes")
if args.mode in ["both", "audio"]: if args.mode in ["both", "audio"]:
if not args.audio_adapter: if not args.audio_adapter:
...@@ -266,6 +326,7 @@ def main(): ...@@ -266,6 +326,7 @@ def main():
r2v_path = Path(args.r2v_model) if args.r2v_model else None r2v_path = Path(args.r2v_model) if args.r2v_model else None
distill_path = Path(args.distill_model) if args.distill_model else None distill_path = Path(args.distill_model) if args.distill_model else None
audio_path = Path(args.audio_adapter) if args.audio_adapter else None audio_path = Path(args.audio_adapter) if args.audio_adapter else None
audio_lora_path = Path(args.audio_lora) if args.audio_lora else None
# Validate file existence # Validate file existence
if r2v_path and not r2v_path.exists(): if r2v_path and not r2v_path.exists():
...@@ -274,6 +335,8 @@ def main(): ...@@ -274,6 +335,8 @@ def main():
raise FileNotFoundError(f"Distill model not found: {distill_path}") raise FileNotFoundError(f"Distill model not found: {distill_path}")
if audio_path and not audio_path.exists(): if audio_path and not audio_path.exists():
raise FileNotFoundError(f"Audio adapter not found: {audio_path}") raise FileNotFoundError(f"Audio adapter not found: {audio_path}")
if audio_lora_path and not audio_lora_path.exists():
raise FileNotFoundError(f"Audio LoRA not found: {audio_lora_path}")
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
...@@ -287,9 +350,13 @@ def main(): ...@@ -287,9 +350,13 @@ def main():
logger.info(f"Distill model: {distill_path}") logger.info(f"Distill model: {distill_path}")
if audio_path: if audio_path:
logger.info(f"Audio adapter: {audio_path}") logger.info(f"Audio adapter: {audio_path}")
if audio_lora_path:
logger.info(f"Audio LoRA: {audio_lora_path}")
logger.info(f"Output dir: {output_dir}") logger.info(f"Output dir: {output_dir}")
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
logger.info(f"LoRA alpha: {args.lora_alpha}") logger.info(f"LoRA alpha: {args.lora_alpha}")
if audio_lora_path:
logger.info(f"Audio LoRA alpha: {args.audio_lora_alpha}")
logger.info(f"Device: {args.device}") logger.info(f"Device: {args.device}")
logger.info("=" * 80) logger.info("=" * 80)
...@@ -297,9 +364,9 @@ def main(): ...@@ -297,9 +364,9 @@ def main():
try: try:
merged_path = None merged_path = None
# Process merged model (modes: 'both', 'merged') # Process R2V model (modes: 'both', 'r2v')
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
logger.info("\n>>> Processing MERGED MODEL") logger.info("\n>>> Processing R2V MODEL")
# Step 1: Merge R2V + Distill via LoRA # Step 1: Merge R2V + Distill via LoRA
merged_path = step1_merge_via_lora(r2v_path, distill_path, output_dir, args.lora_alpha, temp_dir) merged_path = step1_merge_via_lora(r2v_path, distill_path, output_dir, args.lora_alpha, temp_dir)
...@@ -315,10 +382,16 @@ def main(): ...@@ -315,10 +382,16 @@ def main():
if args.mode in ["both", "audio"]: if args.mode in ["both", "audio"]:
logger.info("\n>>> Processing AUDIO ADAPTER") logger.info("\n>>> Processing AUDIO ADAPTER")
# Step 4: Convert audio adapter to BF16 audio_source_path = audio_path
step4_convert_audio_adapter_to_bf16(audio_path, output_dir)
# Optional: Merge audio adapter + LoRA
if audio_lora_path:
audio_source_path = step_audio_merge_lora(audio_path, audio_lora_path, output_dir, args.audio_lora_alpha, temp_dir)
# Convert audio adapter to BF16
step4_convert_audio_adapter_to_bf16(audio_source_path, output_dir)
# Step 5: Convert audio adapter to FP8 # Convert audio adapter to FP8
if not args.skip_audio_fp8: if not args.skip_audio_fp8:
step5_convert_audio_adapter_to_fp8(output_dir) step5_convert_audio_adapter_to_fp8(output_dir)
...@@ -338,24 +411,26 @@ def main(): ...@@ -338,24 +411,26 @@ def main():
logger.info("Generated files:") logger.info("Generated files:")
# Show files based on mode # Show files based on mode
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
logger.info(" ✓ merged.safetensors (FP32, R2V+distill merged)") logger.info(" ✓ merged.safetensors (FP32, R2V+distill merged)")
logger.info(" ✓ merged_bf16.safetensors (BF16)") logger.info(" ✓ merged_bf16.safetensors (BF16)")
if not args.skip_merged_fp8: if not args.skip_merged_fp8:
logger.info(" ✓ merged_fp8.safetensors (FP8)") logger.info(" ✓ merged_fp8.safetensors (FP8)")
if args.mode in ["both", "audio"]: if args.mode in ["both", "audio"]:
if audio_lora_path:
logger.info(" ✓ audio_adapter_merged.safetensors (FP32, audio+lora merged)")
logger.info(" ✓ audio_adapter_model.safetensors (BF16)") logger.info(" ✓ audio_adapter_model.safetensors (BF16)")
if not args.skip_audio_fp8: if not args.skip_audio_fp8:
logger.info(" ✓ audio_adapter_model_fp8.safetensors (FP8)") logger.info(" ✓ audio_adapter_model_fp8.safetensors (FP8)")
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
logger.info(f"\nTemp files: {temp_dir}") logger.info(f"\nTemp files: {temp_dir}")
# Show conversion flow # Show conversion flow
logger.info("\nConversion flow:") logger.info("\nConversion flow:")
if args.mode in ["both", "merged"]: if args.mode in ["both", "r2v"]:
logger.info(" Merged model:") logger.info(" R2V model:")
logger.info(" 1. R2V (FP32) + Distill (FP32) --LoRA--> merged.safetensors (FP32)") logger.info(" 1. R2V (FP32) + Distill (FP32) --LoRA--> merged.safetensors (FP32)")
logger.info(" 2. merged.safetensors (FP32) --> merged_bf16.safetensors") logger.info(" 2. merged.safetensors (FP32) --> merged_bf16.safetensors")
if not args.skip_merged_fp8: if not args.skip_merged_fp8:
...@@ -363,9 +438,14 @@ def main(): ...@@ -363,9 +438,14 @@ def main():
if args.mode in ["both", "audio"]: if args.mode in ["both", "audio"]:
logger.info(" Audio adapter:") logger.info(" Audio adapter:")
logger.info(" 1. audio_adapter.pt --> audio_adapter_model.safetensors (BF16)") step_num = 1
if audio_lora_path:
logger.info(f" {step_num}. audio_adapter.pt + audio_lora --LoRA--> audio_adapter_merged.safetensors (FP32)")
step_num += 1
logger.info(f" {step_num}. audio_adapter --> audio_adapter_model.safetensors (BF16)")
step_num += 1
if not args.skip_audio_fp8: if not args.skip_audio_fp8:
logger.info(" 2. audio_adapter_model.safetensors --> audio_adapter_model_fp8.safetensors") logger.info(f" {step_num}. audio_adapter_model.safetensors --> audio_adapter_model_fp8.safetensors")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment