"""
Model Merge and Multi-Precision Conversion Script

This script supports three conversion modes:
1. 'both' (default): Convert both R2V model and audio adapter
2. 'r2v': Only convert R2V model (R2V + distill via LoRA)
3. 'audio': Only convert audio adapter

Pipeline:
- R2V model: R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8
- Audio adapter: audio_adapter.pt (optional: + LoRA merge) → BF16 → FP8

Usage Examples:
    # Convert both (default)
    python tools/convert/seko_talk_converter.py \
        --r2v_model /path/to/model.pt \
        --distill_model /path/to/model_ema.pt \
        --audio_adapter /path/to/audio_adapter.pt \
        --output_dir /data/output

    # Only convert R2V model
    python tools/convert/seko_talk_converter.py \
        --mode r2v \
        --r2v_model /path/to/model.pt \
        --distill_model /path/to/model_ema.pt \
        --output_dir /data/output

    # Only convert audio adapter
    python tools/convert/seko_talk_converter.py \
        --mode audio \
        --audio_adapter /path/to/audio_adapter.pt \
        --output_dir /data/output

    # Convert audio adapter with LoRA merge
    python tools/convert/seko_talk_converter.py \
        --mode audio \
        --audio_adapter /path/to/audio_adapter.pt \
        --audio_lora /path/to/audio_lora.pt \
        --output_dir /data/output

Output files (depending on mode):
    - merged.safetensors                  (FP32, R2V + distill merged)
    - merged_bf16.safetensors             (BF16)
    - merged_fp8.safetensors              (FP8)
    - audio_adapter_merged.safetensors    (FP32, audio + lora merged, optional)
    - audio_adapter_model.safetensors     (BF16)
    - audio_adapter_model_fp8.safetensors (FP8)
"""

import argparse
import subprocess
import sys
from pathlib import Path

import torch
from loguru import logger
from safetensors.torch import load_file, save_file
from tqdm import tqdm


def run_command(cmd: list, description: str):
    """Run a subprocess command and handle errors."""
    logger.info(f"\n{description}")
    logger.info("Command: " + " \\\n  ".join(cmd))

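    # Capture stdout/stderr so a failing step can be logged in full before aborting.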
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        logger.error(f"{description} FAILED!")
        logger.error(f"STDOUT:\n{result.stdout}")
        logger.error(f"STDERR:\n{result.stderr}")
        raise RuntimeError(f"{description} failed")

    logger.info(f"✓ {description} completed!")
    return result


def load_checkpoint(ckpt_path: Path) -> dict:
    """Load checkpoint from .pt or .safetensors file."""
    logger.info(f"Loading: {ckpt_path.name}")

    if ckpt_path.suffix in [".pt", ".pth"]:
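        # weights_only=True restricts unpickling to tensors and plain containers,
        # guarding against arbitrary code execution from untrusted checkpoints.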
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    elif ckpt_path.suffix == ".safetensors":
        checkpoint = load_file(str(ckpt_path))
    else:
        raise ValueError(f"Unsupported format: {ckpt_path.suffix}")

    logger.info(f"  Loaded {len(checkpoint)} keys")
    return checkpoint


def convert_to_bf16(state_dict: dict) -> dict:
    """Convert all tensors to bfloat16."""
    logger.info("Converting to BF16...")
    bf16_dict = {}
    for key, tensor in tqdm(state_dict.items(), desc="BF16 conversion"):
        bf16_dict[key] = tensor.to(torch.bfloat16)
    return bf16_dict


def step1_merge_via_lora(r2v_model_path: Path, distill_model_path: Path, output_dir: Path, lora_alpha: float, temp_dir: Path) -> Path:
    """
    Step 1: Merge R2V + distillation model via LoRA using converter.py.
    Both models in FP32, output merged.safetensors (FP32).
    """
    logger.info("=" * 80)
    logger.info("STEP 1: Merge R2V + Distillation via LoRA (FP32)")
    logger.info("=" * 80)

    temp_dir.mkdir(parents=True, exist_ok=True)
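    # Stage FP32 safetensors copies of both checkpoints in temp_dir so converter.py can consume them.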

    # Convert R2V to safetensors (keep FP32)
    logger.info("\n[1.1] Converting R2V model to safetensors (FP32)...")
    r2v_dict = load_checkpoint(r2v_model_path)
    r2v_safetensors = temp_dir / "model.safetensors"
    save_file(r2v_dict, str(r2v_safetensors))
    logger.info(f"  Saved: {r2v_safetensors}")

    # Convert distill to safetensors (keep FP32 for LoRA merge)
    logger.info("\n[1.2] Converting distillation model to safetensors (FP32)...")
    distill_dict = load_checkpoint(distill_model_path)
    distill_safetensors = temp_dir / "model_ema.safetensors"
    save_file(distill_dict, str(distill_safetensors))
    logger.info(f"  Saved: {distill_safetensors}")

    # Merge via LoRA using converter.py (FP32 + FP32 → FP32)
    logger.info("\n[1.3] Merging via LoRA (converter.py)...")
    cmd = [
        "python",
        "tools/convert/converter.py",
        "-s",
        str(r2v_safetensors),
        "-o",
        str(output_dir),
        "-o_n",
        "merged",
        "--lora_path",
        str(distill_safetensors),
        "--lora_alpha",
        str(lora_alpha),
        "--single_file",
    ]

    run_command(cmd, "LoRA merge")

    merged_path = output_dir / "merged.safetensors"
    if not merged_path.exists():
        raise FileNotFoundError(f"Merged file not found: {merged_path}")

    logger.info(f"  ✓ Created: {merged_path} (FP32)")
    return merged_path


def step2_convert_merged_to_bf16(merged_path: Path, output_dir: Path):
    """
    Step 2: Convert merged.safetensors (FP32) to BF16.
    """
    logger.info("=" * 80)
    logger.info("STEP 2: Convert merged.safetensors (FP32) → BF16")
    logger.info("=" * 80)

    merged_dict = load_file(str(merged_path))
    merged_bf16 = convert_to_bf16(merged_dict)

    bf16_path = output_dir / "merged_bf16.safetensors"
    save_file(merged_bf16, str(bf16_path))
    logger.info(f"  ✓ Created: {bf16_path}")


def step3_convert_merged_to_fp8(merged_path: Path, output_dir: Path, device: str = "cuda"):
    """
    Step 3: Convert merged.safetensors (FP32) to FP8 using converter.py --quantized.
    """
    logger.info("=" * 80)
    logger.info("STEP 3: Convert merged.safetensors (FP32) → FP8")
    logger.info("=" * 80)

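    # FP8 quantization is delegated to converter.py; --device selects where the cast runs
    # (typically a CUDA GPU).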
    cmd = [
        "python",
        "tools/convert/converter.py",
        "-s",
        str(merged_path),
        "-o",
        str(output_dir),
        "-o_n",
        "merged_fp8",
        "--linear_type",
        "fp8",
        "--quantized",
        "--device",
        device,
        "--single_file",
    ]

    run_command(cmd, "Merged FP8 conversion")

    fp8_path = output_dir / "merged_fp8.safetensors"
    logger.info(f"  ✓ Created: {fp8_path}")


def step_audio_merge_lora(audio_adapter_path: Path, audio_lora_path: Path, output_dir: Path, lora_alpha: float, temp_dir: Path) -> Path:
    """
    Merge audio adapter + LoRA using converter.py.
    Both in FP32, output audio_adapter_merged.safetensors (FP32).
    """
    logger.info("=" * 80)
    logger.info("AUDIO STEP 1: Merge Audio Adapter + LoRA (FP32)")
    logger.info("=" * 80)

    temp_dir.mkdir(parents=True, exist_ok=True)

    logger.info("\n[1.1] Converting audio adapter to safetensors (FP32)...")
    audio_dict = load_checkpoint(audio_adapter_path)
    audio_safetensors = temp_dir / "audio_adapter.safetensors"
    save_file(audio_dict, str(audio_safetensors))
    logger.info(f"  Saved: {audio_safetensors}")

    logger.info("\n[1.2] Converting audio LoRA to safetensors (FP32)...")
    lora_dict = load_checkpoint(audio_lora_path)
    lora_safetensors = temp_dir / "audio_lora.safetensors"
    save_file(lora_dict, str(lora_safetensors))
    logger.info(f"  Saved: {lora_safetensors}")

    logger.info("\n[1.3] Merging via LoRA (converter.py)...")
    cmd = [
        "python",
        "tools/convert/converter.py",
        "-s",
        str(audio_safetensors),
        "-o",
        str(output_dir),
        "-o_n",
        "audio_adapter_merged",
        "--lora_path",
        str(lora_safetensors),
        "--lora_alpha",
        str(lora_alpha),
        "--single_file",
    ]

    run_command(cmd, "Audio LoRA merge")

    merged_path = output_dir / "audio_adapter_merged.safetensors"
    if not merged_path.exists():
        raise FileNotFoundError(f"Merged audio file not found: {merged_path}")

    logger.info(f"  ✓ Created: {merged_path} (FP32)")
    return merged_path


def step4_convert_audio_adapter_to_bf16(audio_adapter_path: Path, output_dir: Path):
    """
    Audio step 2: Convert the audio adapter (raw or LoRA-merged) to BF16.
    """
    logger.info("=" * 80)
    logger.info("AUDIO STEP 2: Convert audio adapter → BF16")
    logger.info("=" * 80)

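    # audio_adapter_path is either the raw adapter or the LoRA-merged FP32 file produced earlier.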
    audio_dict = load_checkpoint(audio_adapter_path)
    audio_bf16 = convert_to_bf16(audio_dict)

    bf16_path = output_dir / "audio_adapter_model.safetensors"
    save_file(audio_bf16, str(bf16_path))
    logger.info(f"  ✓ Created: {bf16_path}")


def step5_convert_audio_adapter_to_fp8(output_dir: Path):
    """
    Audio step 3: Convert the BF16 audio adapter to FP8 using quant_adapter.py.
    """
    logger.info("=" * 80)
    logger.info("AUDIO STEP 3: Convert audio adapter → FP8")
    logger.info("=" * 80)

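    # Quantize the BF16 file written in the previous step; quant_adapter.py emits the FP8 copy alongside it.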
    input_path = output_dir / "audio_adapter_model.safetensors"
    output_path = output_dir / "audio_adapter_model_fp8.safetensors"

    cmd = ["python", "tools/convert/quant_adapter.py", "--model_path", str(input_path), "--output_path", str(output_path)]

    run_command(cmd, "Audio adapter FP8 conversion")

    logger.info(f"  ✓ Created: {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Merge R2V+distill via LoRA and/or the audio adapter, then convert to BF16/FP8")

    # Mode selection
    parser.add_argument("--mode", type=str, choices=["both", "r2v", "audio"], default="both", help="Conversion mode: 'both' (default), 'r2v' (only R2V model), or 'audio' (only audio adapter)")

    # Inputs (conditionally required based on mode)
    parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both' and 'r2v' modes]")
    parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'r2v' modes]")
    parser.add_argument("--audio_adapter", type=str, help="Path to audio adapter (.pt) [required for 'both' and 'audio' modes]")
    parser.add_argument("--audio_lora", type=str, help="Path to audio LoRA (.pt/.safetensors) [optional, for merging with audio adapter]")
    parser.add_argument("--audio_lora_alpha", type=float, default=8.0, help="Alpha for audio LoRA merge (default: 8.0)")

    # Outputs
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--temp_dir", type=str, default=None, help="Temp directory (default: output_dir/temp)")

    # Settings
    parser.add_argument("--lora_alpha", type=float, default=8.0, help="Alpha for LoRA merge (default: 8.0)")
    parser.add_argument("--device", type=str, default="cuda", help="Device for FP8 quantization (default: cuda)")

    # Options
    parser.add_argument("--skip_merged_fp8", action="store_true", help="Skip merged FP8 conversion")
    parser.add_argument("--skip_audio_fp8", action="store_true", help="Skip audio adapter FP8 conversion")

    args = parser.parse_args()

    # Validate required arguments based on mode
    if args.mode in ["both", "r2v"]:
        if not args.r2v_model or not args.distill_model:
            parser.error("--r2v_model and --distill_model are required for 'both' and 'r2v' modes")

    if args.mode in ["both", "audio"]:
        if not args.audio_adapter:
            parser.error("--audio_adapter is required for 'both' and 'audio' modes")

    # Setup paths
    output_dir = Path(args.output_dir)
    temp_dir = Path(args.temp_dir) if args.temp_dir else output_dir / "temp"

    r2v_path = Path(args.r2v_model) if args.r2v_model else None
    distill_path = Path(args.distill_model) if args.distill_model else None
    audio_path = Path(args.audio_adapter) if args.audio_adapter else None
    audio_lora_path = Path(args.audio_lora) if args.audio_lora else None

    # Validate file existence
    if r2v_path and not r2v_path.exists():
        raise FileNotFoundError(f"R2V model not found: {r2v_path}")
    if distill_path and not distill_path.exists():
        raise FileNotFoundError(f"Distill model not found: {distill_path}")
    if audio_path and not audio_path.exists():
        raise FileNotFoundError(f"Audio adapter not found: {audio_path}")
    if audio_lora_path and not audio_lora_path.exists():
        raise FileNotFoundError(f"Audio LoRA not found: {audio_lora_path}")

    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info("=" * 80)
    logger.info("MODEL CONVERSION PIPELINE")
    logger.info("=" * 80)
    logger.info(f"Mode:           {args.mode}")
    if r2v_path:
        logger.info(f"R2V model:      {r2v_path}")
    if distill_path:
        logger.info(f"Distill model:  {distill_path}")
    if audio_path:
        logger.info(f"Audio adapter:  {audio_path}")
    if audio_lora_path:
        logger.info(f"Audio LoRA:     {audio_lora_path}")
    logger.info(f"Output dir:     {output_dir}")
    if args.mode in ["both", "r2v"]:
        logger.info(f"LoRA alpha:     {args.lora_alpha}")
    if audio_lora_path:
        logger.info(f"Audio LoRA alpha: {args.audio_lora_alpha}")
    logger.info(f"Device:         {args.device}")
    logger.info("=" * 80)

    # Execute pipeline based on mode
    try:
        merged_path = None

        # Process R2V model (modes: 'both', 'r2v')
        if args.mode in ["both", "r2v"]:
            logger.info("\n>>> Processing R2V MODEL")

            # Step 1: Merge R2V + Distill via LoRA
            merged_path = step1_merge_via_lora(r2v_path, distill_path, output_dir, args.lora_alpha, temp_dir)

            # Step 2: Convert merged to BF16
            step2_convert_merged_to_bf16(merged_path, output_dir)

            # Step 3: Convert merged to FP8
            if not args.skip_merged_fp8:
                step3_convert_merged_to_fp8(merged_path, output_dir, args.device)

        # Process audio adapter (modes: 'both', 'audio')
        if args.mode in ["both", "audio"]:
            logger.info("\n>>> Processing AUDIO ADAPTER")

            audio_source_path = audio_path

            # Optional: Merge audio adapter + LoRA
            if audio_lora_path:
                audio_source_path = step_audio_merge_lora(audio_path, audio_lora_path, output_dir, args.audio_lora_alpha, temp_dir)

            # Convert audio adapter to BF16
            step4_convert_audio_adapter_to_bf16(audio_source_path, output_dir)

            # Convert audio adapter to FP8
            if not args.skip_audio_fp8:
                step5_convert_audio_adapter_to_fp8(output_dir)

    except Exception as e:
        logger.error(f"\n{'=' * 80}")
        logger.error("PIPELINE FAILED")
        logger.error(f"{'=' * 80}")
        logger.error(f"Error: {e}")
        sys.exit(1)

    # Summary
    logger.info("\n" + "=" * 80)
    logger.info("✓ PIPELINE COMPLETED SUCCESSFULLY!")
    logger.info("=" * 80)
    logger.info(f"\nMode: {args.mode}")
    logger.info(f"Output directory: {output_dir}\n")
    logger.info("Generated files:")

    # Show files based on mode
    if args.mode in ["both", "r2v"]:
        logger.info("  ✓ merged.safetensors                  (FP32, R2V+distill merged)")
        logger.info("  ✓ merged_bf16.safetensors             (BF16)")
        if not args.skip_merged_fp8:
            logger.info("  ✓ merged_fp8.safetensors              (FP8)")

    if args.mode in ["both", "audio"]:
        if audio_lora_path:
            logger.info("  ✓ audio_adapter_merged.safetensors    (FP32, audio+lora merged)")
        logger.info("  ✓ audio_adapter_model.safetensors     (BF16)")
        if not args.skip_audio_fp8:
            logger.info("  ✓ audio_adapter_model_fp8.safetensors (FP8)")

    if args.mode in ["both", "r2v"]:
428
429
430
431
        logger.info(f"\nTemp files: {temp_dir}")

    # Show conversion flow
    logger.info("\nConversion flow:")
    if args.mode in ["both", "r2v"]:
        logger.info("  R2V model:")
        logger.info("    1. R2V (FP32) + Distill (FP32) --LoRA--> merged.safetensors (FP32)")
        logger.info("    2. merged.safetensors (FP32) --> merged_bf16.safetensors")
        if not args.skip_merged_fp8:
            logger.info("    3. merged.safetensors (FP32) --> merged_fp8.safetensors")

    if args.mode in ["both", "audio"]:
        logger.info("  Audio adapter:")
        step_num = 1
        if audio_lora_path:
            logger.info(f"    {step_num}. audio_adapter.pt + audio_lora --LoRA--> audio_adapter_merged.safetensors (FP32)")
            step_num += 1
        logger.info(f"    {step_num}. audio_adapter --> audio_adapter_model.safetensors (BF16)")
        step_num += 1
        if not args.skip_audio_fp8:
            logger.info(f"    {step_num}. audio_adapter_model.safetensors --> audio_adapter_model_fp8.safetensors")


if __name__ == "__main__":
    main()