gemm-softmax-gemm.py 1.04 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import onnx
import onnxruntime as ort
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

def fuse_gemm_softmax_gemm(model_path, output_path):
    """Fuse transformer subgraphs (attention: GEMM->Softmax->GEMM, etc.) in an ONNX model.

    Uses onnxruntime's BERT graph optimizer to apply fusion passes and writes
    the optimized model to disk.

    Args:
        model_path: Path to the input ONNX model file.
        output_path: Path where the fused/optimized model is written.

    Returns:
        output_path, for convenient chaining.
    """
    # Configure which fusion passes run.
    opt_options = FusionOptions('bert')
    opt_options.enable_gemm_fast_gelu = True
    opt_options.enable_layer_norm = False
    opt_options.enable_attention = True  # enable attention fusion

    # optimize_model loads the model from disk itself, so no separate
    # onnx.load() call is needed (the original loaded the model into
    # memory a second time and discarded it).
    optimized_model = optimizer.optimize_model(
        model_path,
        model_type='bert',  # other options include 'gpt2', 'bert_tf', ...
        num_heads=12,       # adjust to your model -- TODO confirm for bert_best
        hidden_size=768,    # adjust to your model -- TODO confirm for bert_best
        optimization_options=opt_options
    )

    # Persist the optimized graph.
    optimized_model.save_model_to_file(output_path)
    return output_path

# Guard the script entry point so importing this module does not
# immediately run a (slow, file-writing) model optimization.
if __name__ == "__main__":
    model_path = "/models/bert_best.onnx"
    output_path = "/models/bert_best_fused.onnx"
    res = fuse_gemm_softmax_gemm(model_path, output_path)
    print(res)