# convert_fp16.py — optimize an ONNX transformer model and convert it to FP16.
import onnx
from onnxruntime.transformers import optimizer
import onnxruntime as ort
import numpy as np

def convert_to_fp16_with_transformers(input_path, output_path):
    """Optimize an ONNX transformer model and convert its weights to FP16.

    Args:
        input_path: Path to the source FP32 ONNX model file.
        output_path: Path where the optimized FP16 model is written.

    Returns:
        output_path, so the call can be chained.
    """
    print(f"1. 加载模型: {input_path}")

    # Load the graph to infer optimizer hyper-parameters from it.
    model = onnx.load(input_path)

    # Fallback values used when the parameters cannot be inferred.
    hidden_size = 512
    num_heads = 8  # common head count; not derivable from the graph here

    # Infer hidden_size from the first LayerNormalization node: the scale
    # initializer (second input) is a 1-D tensor of length hidden_size.
    # BUGFIX: the original loop had a dead `elif 'attention' ...: break`
    # branch that never changed num_heads but could abort the scan before
    # any LayerNormalization was reached, leaving hidden_size at its default.
    for node in model.graph.node:
        if node.op_type == 'LayerNormalization' and len(node.input) > 1:
            scale_name = node.input[1]
            for init in model.graph.initializer:
                if init.name == scale_name and len(init.dims) > 0:
                    hidden_size = init.dims[0]
                    break
            break  # first LayerNorm is sufficient; all share one hidden size

    print(f"2. 使用参数 - hidden_size: {hidden_size}, num_heads: {num_heads}")

    # 3. Run the onnxruntime transformer-specific graph optimizations.
    print("3. 开始优化模型...")
    optimized_model = optimizer.optimize_model(
        input_path,
        model_type='bert',
        num_heads=num_heads,
        hidden_size=hidden_size,
        opt_level=1,   # basic optimizations only
        use_gpu=False  # conversion itself runs on CPU
    )

    # 4. Cast weights to FP16.
    print("4. 转换为FP16...")
    # BUGFIX: `op_types_to_cast` is not a parameter of
    # convert_float_to_float16 (it raised TypeError), and `node_block_list`
    # expects node *names*, not op types. The intent — keep
    # precision-sensitive ops in FP32 — is expressed with `op_block_list`.
    optimized_model.convert_float_to_float16(
        keep_io_types=True,  # keep graph inputs/outputs in FP32 for callers
        op_block_list=['LayerNormalization', 'Softmax'],
    )

    # 5. Persist the converted model.
    print(f"5. 保存FP16模型: {output_path}")
    optimized_model.save_model_to_file(output_path)

    # 6. Smoke-test: the model must at least load into an inference session.
    print("6. 验证模型...")
    try:
        session = ort.InferenceSession(output_path, providers=['CPUExecutionProvider'])
        print("✓ 模型验证成功!")

        # Report first input/output so the caller can sanity-check I/O types.
        inputs = session.get_inputs()
        outputs = session.get_outputs()
        print(f"   输入: {inputs[0].name} - {inputs[0].type}")
        print(f"   输出: {outputs[0].name} - {outputs[0].type}")

    except Exception as e:
        # Best-effort validation: report the failure but still return the
        # path, since the file has already been written.
        print(f"✗ 模型验证失败: {e}")

    return output_path

# Usage: run this module directly to convert the hard-coded encoder model.
# (Guarded so importing this module does not trigger the conversion.)
if __name__ == "__main__":
    input_path = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24_1/transformer_lm/full/default_encoder.onnx"
    output_path = input_path.replace('.onnx', '_fp16.onnx')
    convert_to_fp16_with_transformers(input_path, output_path)