import sys

import numpy as np

from onnx_modifier import ONNXModifier


def change_inf_to_value(om: ONNXModifier):
    """Replace scalar +/-inf initializers feeding Where nodes with finite values.

    Uses float16 max/min (not float32) so the constants survive a later FP16
    conversion without overflowing back to inf. Only scalar (size == 1)
    initializers are rewritten, and each at most once.
    """
    records = set()
    for where_node in om.get_nodes("Where"):
        # inputs[1:] are the X/Y branches of Where(cond, X, Y).
        for input_name in where_node.inputs[1:]:
            init = om.get_initializer(input_name)
            if init is None:
                continue
            assert input_name == init.name
            if input_name in records:
                continue
            # float16 limits keep the model finite after FP16 conversion.
            info = np.finfo(np.float16)
            data = om.get_initializer_value(init.name)
            if data.size > 1:
                continue
            if data == np.inf:
                om.set_initializer_value(input_name, np.array(info.max, dtype=np.float32))
            elif data == -np.inf:
                om.set_initializer_value(input_name, np.array(info.min, dtype=np.float32))
            else:
                continue
            records.add(input_name)


def optimize_where_ndoes(om: ONNXModifier):
    """Equivalently rewrite attention-mask Where nodes (with safety checks).

    Transformations, with cond = inputs[0], X = inputs[1], Y = inputs[2]:
      (1) cond is an initializer and X == 0:
          Where(cond, X, Y)  ==>  Mul(Y, ~cond)
      (2) cond is an initializer and X == -inf:
          Where(cond, X, Y)  ==>  Sub(Y, np.where(cond, fp16_max, 0))
      (3) cond is a runtime input and X == -inf:
          Where(cond, X, Y)  ==>  Sub(Y, Mul(Cast(cond, float32), fp16_max))

    Nodes whose X is not a scalar constant equal to 0.0 or -inf are skipped.
    NOTE(review): the name keeps the historical typo ("ndoes") because
    external callers may reference it.
    """
    for where_node in om.get_nodes("Where"):
        where_name = where_node.name

        # X must be a constant (initializer); otherwise this Where is not an
        # attention-mask pattern and is left untouched.
        x_init = om.get_initializer(where_node.inputs[1])
        if x_init is None:
            continue
        x_value = om.get_initializer_value(where_node.inputs[1])

        # Non-scalar X means this is not the mask node we are looking for.
        if x_value.size != 1:
            continue

        is_zero = (x_value == np.array(0.0, dtype=np.float32))
        is_neg_inf = (x_value == np.array(-np.inf, dtype=np.float32))
        if not (is_zero or is_neg_inf):
            continue

        cond_init = om.get_initializer(where_node.inputs[0])
        if cond_init is not None:
            cond_value = om.get_initializer_value(where_node.inputs[0])
            if is_zero:
                # (1) Where(cond, 0, Y) ==> Mul(Y, ~cond)
                mul_name = where_name.replace("Where", "NewMul")
                mul_b_init = om.create_initializer(mul_name + "_B",
                                                   (~cond_value).astype(np.float32))
                mul_node = om.create_node("Mul",
                                          mul_name,
                                          [where_node.inputs[2], mul_b_init.name],
                                          [mul_name + "_output_0"],
                                          index=where_node.index)
                for next_node in where_node.next_nodes:
                    next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
            elif is_neg_inf:
                # (2) Where(cond, -inf, Y) ==> Sub(Y, np.where(cond, fp16_max, 0))
                sub_name = where_name.replace("Where", "NewSub")
                sub_b_init = om.create_initializer(
                    sub_name + "_B",
                    np.where(cond_value.astype(np.float32),
                             np.finfo(np.float16).max, 0.0).astype(np.float32)
                )
                sub_node = om.create_node("Sub",
                                          sub_name,
                                          [where_node.inputs[2], sub_b_init.name],
                                          [sub_name + "_output_0"],
                                          index=where_node.index)
                for next_node in where_node.next_nodes:
                    next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
        else:
            # (3) Runtime condition: only the -inf form is handled.
            if not is_neg_inf:
                continue
            cast_name = where_name.replace("Where", "NewCast")
            mul_name = where_name.replace("Where", "NewMul")
            sub_name = where_name.replace("Where", "NewSub")
            cast_node = om.create_node("Cast",
                                       cast_name,
                                       [where_node.inputs[0]],
                                       [cast_name + "_output_0"],
                                       to=1,  # 1 == TensorProto.FLOAT
                                       index=where_node.index)
            mul_b_init = om.create_initializer(mul_name + "_B",
                                               np.array([np.finfo(np.float16).max], np.float32))
            mul_node = om.create_node("Mul",
                                      mul_name,
                                      [cast_node.outputs[0], mul_b_init.name],
                                      [mul_name + "_output_0"],
                                      index=cast_node.index + 1)
            sub_node = om.create_node("Sub",
                                      sub_name,
                                      [where_node.inputs[2], mul_node.outputs[0]],
                                      [sub_name + "_output_0"],
                                      index=mul_node.index + 1)
            for next_node in where_node.next_nodes:
                next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
    om.update_map()


def optimize_transpose_nodes(om: ONNXModifier):
    """Bypass known no-op-for-the-runtime Transpose nodes and pin dynamic
    Reshape targets to static shapes for this model configuration.

    All node lookups are guarded: absent nodes simply mean this model
    variant does not need that particular rewrite.
    """
    # Equivalent to the previous hand-written 44-entry list.
    transpose_list = (
        ["/transformer/encoder/Transpose"]
        + [f"/transformer/encoder/Transpose_{i}" for i in range(1, 12)]
        + [f"/transformer/decoder/layers.{i}/Transpose{suffix}"
           for i in range(6) for suffix in ("", "_1", "_2")]
        + ["/transformer/Transpose_8"]
        + ["/transformer/decoder/Transpose"]
        + [f"/transformer/decoder/Transpose_{i}" for i in range(1, 12)]
    )
    for name in transpose_list:
        node = om.get_node(name)
        # Safety check: a missing node means nothing to optimize here.
        if node is None:
            continue
        if 'perm' in node.attrs and node.attrs['perm'] in ([1, 0, 2], [1, 0, 2, 3]):
            # Reconnect consumers straight to the Transpose input.
            for consumer in om.get_next_nodes(node):
                consumer.replace_input(node.outputs[0], node.inputs[0])

    # Pin /transformer/encoder/text_layers.*/self_attn/Reshape_4 to a static shape.
    shape_init1 = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        if reshape_node is not None:
            reshape_node.set_input(1, shape_init1.name)

    # Fix the permutation on /transformer/enc_out_class_embed/Transpose.
    trans_node = om.get_node("/transformer/enc_out_class_embed/Transpose")
    if trans_node is not None:
        trans_node.set_attribute("perm", [0, 2, 1])

    # "_v_5525" is an export-generated name that may not exist in every
    # model dump — guard against a hard crash.
    init_5525 = om.get_initializer("_v_5525")
    if init_5525 is not None:
        om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))

    # Pin decoder self_attn/Reshape_4 and ca_text/Reshape_6 shapes.
    shape_init3 = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
        if reshape_node1 is not None:
            reshape_node1.set_input(1, shape_init3.name)
        reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
        if reshape_node2 is not None:
            reshape_node2.set_input(1, shape_init3.name)

    # Pre-transpose the tiled addend used by /transformer/decoder/layers.0/Add
    # and Add_1 so the bypassed Transposes above stay equivalent.
    init_name = "/transformer/Tile_1_output_0"
    tile_init = om.get_initializer(init_name)
    if tile_init is not None:
        add_value = om.get_initializer_value(init_name)
        om.set_initializer_value(init_name,
                                 np.ascontiguousarray(add_value.transpose(1, 0, 2)))

    om.update_map()
    # Shape inference may fail on custom ops (e.g. MSDeformAttn); warn, don't abort.
    try:
        om.infer_shape(strict_mode=False)
    except Exception as e:
        print(f"[Warning] infer_shape 跳过 (可能由于自定义算子引起). 详细信息: {e}")


def optmize_sin_cos_block(om: ONNXModifier):
    """Rewrite the decoder sin/cos positional-embedding blocks into a more
    runtime-friendly form (static Slice parameters, fused scale, 5-D layout).

    Each Gather→...→MatMul chain is pattern-matched step by step; if a block
    does not match (missing node, unexpected op type), it is skipped and the
    remaining blocks are still processed.
    NOTE(review): the name keeps the historical typo ("optmize") because
    external callers may reference it.
    """
    node_pairs = [
        ("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
        ("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
        ("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
        ("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
        ("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
        ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
    ]
    # Shared initializers, created once and reused by every matched block.
    # (The "slice_steps2" name below is a historical slip for "slice_starts2";
    # it is kept byte-identical to avoid renaming graph tensors.)
    unsqueeze_axes_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/unsqueeze_axes1",
                                                 np.array([3, 4], np.int64))
    slice_axes_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_axes",
                                            np.array([4], np.int64))
    slice_steps_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps",
                                             np.array([1], np.int64))
    slice_starts_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_starts1",
                                               np.array([0], np.int64))
    slice_ends_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends1",
                                             np.array([1], np.int64))
    slice_starts_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps2",
                                               np.array([1], np.int64))
    slice_ends_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends2",
                                             np.array([2], np.int64))
    reshape_init = om.create_initializer("/transformer/decoder/sin_cos_block/reshape_dst_shape",
                                         np.array([1, 900, -1], np.int64))

    for i, (gather_name, matmul_name) in enumerate(node_pairs):
        gather_node = om.get_node(gather_name)
        matmul_node = om.get_node(matmul_name)
        # Safety check: a missing pair means this block cannot be optimized.
        if gather_node is None or matmul_node is None:
            continue
        try:
            next_node = om.get_next_nodes(gather_node)[0]
            if next_node.op_type != "Mul":
                continue
            mul_init_value = om.get_initializer_value(next_node.inputs[1])
            if mul_init_value.size != 1:
                continue
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Unsqueeze":
                continue
            # Bypass the Mul: feed the Gather input straight into Unsqueeze
            # with the new trailing axes.
            next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Div":
                continue
            # Fold the removed Mul scale into the Div constant and reshape it
            # to the 5-D layout used by the rewritten block.
            div_init_value = om.get_initializer_value(next_node.inputs[1])
            new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
            new_init = om.create_initializer(next_node.name + "_B", new_value)
            next_node.set_input(1, new_init.name)
            next_nodes = om.get_next_nodes(next_node)
            if len(next_nodes) != 2 or not all(x.op_type == 'Slice' for x in next_nodes):
                continue
            sin_node, cos_node = None, None
            for j, slice_node in enumerate(next_nodes):
                # Rewire both Slices to static start/end/axis/step tensors.
                slice_node.set_inputs([slice_node.inputs[0],
                                       slice_starts_init1.name if j == 0 else slice_starts_init2.name,
                                       slice_ends_init1.name if j == 0 else slice_ends_init2.name,
                                       slice_axes_init.name,
                                       slice_steps_init.name])
                n_node = om.get_next_nodes(slice_node)[0]
                if n_node.op_type == "Sin":
                    sin_node = n_node
                elif n_node.op_type == "Cos":
                    cos_node = n_node
                else:
                    raise RuntimeError("match fail!")
                # Walk two nodes forward to reach the Concat.
                n_node = om.get_next_nodes(n_node)[0]
                n_node = om.get_next_nodes(n_node)[0]
            next_node = n_node  # Concat node
            if next_node.op_type != "Concat":
                continue
            next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
            next_node.set_attribute("axis", 4)
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Reshape":
                continue
            next_node.set_input(1, reshape_init.name)
            matmul_node.set_input(0, next_node.outputs[0])
            if i == 0:
                # The first head's weights were stored with the 128-wide halves
                # swapped; reorder rows to match the new sin/cos concat order.
                mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
                mm_b_value = np.concatenate([mm_b_value[128:256, ...],
                                             mm_b_value[0:128, ...],
                                             mm_b_value[256:, ...]], axis=0)
                om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
        except Exception:
            # Best-effort: any unexpected graph shape aborts only this block.
            continue

    om.update_map()
    try:
        om.infer_shape(strict_mode=False)
    except Exception:
        # Shape inference is best-effort here (custom ops may break it).
        pass


def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: str = None, num_heads: int = 12):
    """Fuse one Q/K/V attention subgraph around *softmax_name* into a single
    com.microsoft MultiHeadAttention node.

    Args:
        om: graph modifier holding the model.
        softmax_name: name of the Softmax node at the center of the pattern.
        new_mask: name of the tensor to use as the MHA mask input; required
            when the pattern contains a masked (Add) branch.
        num_heads: number of attention heads for the fused node.

    A trailing Gemm projection, if present, is split into MatMul + Add so it
    can consume the 3-D MHA output directly.
    """
    softmax_node = om.get_node(softmax_name)
    tmp_node = om.get_prev_nodes(softmax_node)[0]
    assert tmp_node.op_type in ["MatMul", "Add"]
    mask = None
    if tmp_node.op_type == "Add":
        # Masked attention: Add(QK[, /scale], mask) feeds the Softmax.
        mask_node = tmp_node
        tmp_node = om.get_from_node(mask_node.inputs[0])
        if tmp_node.op_type == "Div":
            tmp_node = om.get_from_node(tmp_node.inputs[0])
        assert tmp_node.op_type == "MatMul"
        mask = mask_node.inputs[1]
        assert new_mask is not None
    # Trace Q and K back through Transpose -> Reshape to the projections.
    tmp_node1 = om.get_from_node(tmp_node.inputs[0])
    if tmp_node1.op_type == "Mul":
        tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_from_node(tmp_node.inputs[1])
    assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
    tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
    assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
    q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
    tmp_node = om.get_next_nodes(softmax_node)[0]
    assert tmp_node.op_type == "MatMul"
    tmp_node3 = om.get_from_node(tmp_node.inputs[1])
    if tmp_node3 is not None:
        # V comes from a Reshape -> Transpose chain.
        assert tmp_node3.op_type == "Transpose"
        tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
        assert tmp_node3.op_type == "Reshape"
        v = tmp_node3.inputs[0]
    else:
        # V is a constant: flatten (H, S, D) into the (B, S, H*D) layout MHA expects.
        v_init = om.get_initializer(tmp_node.inputs[1])
        v_init_value = om.get_initializer_value(tmp_node.inputs[1])
        v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
        B, S, H, D = v_init_value.shape
        v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H * D))
        om.set_initializer_value(tmp_node.inputs[1], v_init_value)
        v = v_init.name
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Transpose"
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Reshape"
    mha_next_node = om.get_next_nodes(tmp_node)[0]
    if mha_next_node.op_type == "Gemm":
        gemm_next_node = om.get_next_nodes(mha_next_node)[0]
        assert gemm_next_node.op_type == "Reshape"
        reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
        assert reshape_next_node.op_type == "Add"
    else:
        assert mha_next_node.op_type == "MatMul"
    name_prefix = '/'.join(softmax_name.split('/')[:-1])
    mha_name = f"{name_prefix}/MultiHeadAttention"
    mha_node = om.create_node("MultiHeadAttention",
                              mha_name,
                              [q, k, v] if mask is None else [q, k, v, new_mask],
                              [mha_name + '_output_0'],
                              num_heads=num_heads,
                              domain="com.microsoft",
                              index=mha_next_node.index - 1)
    mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
    if mha_next_node.op_type == "Gemm":
        # Replace Gemm (2-D only) with MatMul + Add operating on the 3-D MHA output.
        weights = om.get_initializer_value(mha_next_node.inputs[1])
        transB = mha_next_node.attrs["transB"]
        assert transB == 1
        weights = np.ascontiguousarray(weights.transpose(1, 0))
        om.set_initializer_value(mha_next_node.inputs[1], weights)
        new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
        new_matmul_node = om.create_node("MatMul",
                                         new_matmul_name,
                                         [mha_node.outputs[0], mha_next_node.inputs[1]],
                                         [new_matmul_name + "_output_0"],
                                         index=mha_next_node.index)
        new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
        new_add_node = om.create_node("Add",
                                      new_bias_name,
                                      [new_matmul_node.outputs[0], mha_next_node.inputs[2]],
                                      [new_bias_name + "_output_0"],
                                      index=new_matmul_node.index + 1)
        reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])


def optimize_normal_attention(om: ONNXModifier):
    """Fuse the decoder self/cross attention blocks into MultiHeadAttention
    nodes, sharing one mask tensor derived from "attention_mask"."""

    def create_new_attention_mask():
        # Cast the boolean mask to float and reduce over axis 1 to get the
        # key-padding mask form expected by com.microsoft MultiHeadAttention.
        mask_next_node = om.get_to_nodes("attention_mask")[0]
        cast_node = om.create_node("Cast",
                                   "Cast_for_attention_mask",
                                   ["attention_mask"],
                                   ["Cast_for_attention_mask_output_0"],
                                   to=1,
                                   index=mask_next_node.index)
        reducesum_node = om.create_node("ReduceSum",
                                        "ReduceSum_for_mask",
                                        [cast_node.outputs[0]],
                                        ["ReduceSum_for_mask_output_0"],
                                        axes=1,
                                        keepdims=0,
                                        index=cast_node.index + 1)
        return reducesum_node.outputs[0]

    new_mask = create_new_attention_mask()
    for i in range(6):
        # Bert and encoder text-layer fusion are handled elsewhere / disabled;
        # here only the decoder attention blocks are fused.
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
    om.update_map()


def optimize_backbone_attention(om: ONNXModifier):
    """Fuse the Swin-style backbone window-attention blocks into
    com.microsoft MultiHeadAttention nodes."""

    def get_original_mask(mask_name, name_prefix):
        # Invert the additive mask (0 -> keep) back into a boolean mask.
        mask_value = om.get_initializer_value(mask_name)
        orig_mask = np.where(mask_value == 0, 1, 0).astype(np.bool_)
        orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
        return orig_mask_init.name

    def _fuse_one_attention(softmax_name: str):
        name_prefix = '/'.join(softmax_name.split('/')[:-1])
        softmax_node = om.get_node(softmax_name)
        tmp_node = om.get_prev_nodes(softmax_node)[0]
        pos_bias_init = None
        if tmp_node.op_type == "Reshape":
            # Shifted-window variant: Reshape(Add(pos_bias, Reshape(...))).
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Add"
            pos_bias_init = om.get_initializer(tmp_node.inputs[1])
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Reshape"
            tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Add"
        mask = get_original_mask(tmp_node.inputs[1], name_prefix)
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "MatMul"
        qk_matmul = tmp_node
        # Trace Q back: Mul(scale) <- Gather <- Transpose <- Reshape.
        tmp_node = om.get_from_node(qk_matmul.inputs[0])
        assert tmp_node.op_type == "Mul"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        q_gather_node = tmp_node
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        reshape_node = tmp_node
        # Trace K back: Transpose <- Gather.
        tmp_node = om.get_from_node(qk_matmul.inputs[1])
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        k_gather_node = tmp_node
        # Trace V and the attention output chain forward from the Softmax.
        tmp_node = om.get_next_nodes(softmax_node)[0]
        assert tmp_node.op_type == "MatMul"
        v_gather_node = om.get_from_node(tmp_node.inputs[1])
        assert v_gather_node.op_type == "Gather"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        mha_out = tmp_node.outputs[0]
        # Collapse the (b, s, 3, h, d) QKV reshape to (b, s, 3, h*d) so each
        # Gather output already has the layout MHA expects.
        old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
        b, s, _, h, d = old_dst_shape
        new_dst_shape = [b, s, _, h * d]
        new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
                                                   np.array(new_dst_shape, np.int64))
        reshape_node.set_input(1, new_dst_shape_init.name)
        for node in [q_gather_node, k_gather_node, v_gather_node]:
            node.set_input(0, reshape_node.outputs[0])
            node.set_attribute("axis", 2)
        mha_name = f"{name_prefix}/MultiHeadAttention"
        inputs = [q_gather_node.outputs[0],
                  k_gather_node.outputs[0],
                  v_gather_node.outputs[0],
                  mask]
        if pos_bias_init is not None:
            inputs.append(pos_bias_init.name)
        mha_node = om.create_node("MultiHeadAttention",
                                  mha_name,
                                  inputs,
                                  [mha_name + '_output_0'],
                                  num_heads=h,
                                  domain="com.microsoft",
                                  index=softmax_node.index)
        mha_next_node = om.get_to_nodes(mha_out)[0]
        mha_next_node.replace_input(mha_out, mha_node.outputs[0])

    num_layers = 4
    for layer_idx in range(num_layers):
        # Swin layout: layer 2 has 18 blocks, the others have 2.
        num_blocks = 18 if layer_idx == 2 else 2
        for block_idx in range(num_blocks):
            _fuse_one_attention(
                f"/backbone/backbone.0/layers.{layer_idx}/blocks.{block_idx}/attn/softmax/Softmax")


def optimize_bidirect_attention(om: ONNXModifier):
    """Insert a 1x1 identity MatMul after each fusion-layer ReduceMax_1 and
    reroute the following Sub through it (a runtime-specific workaround)."""
    for i in range(6):
        reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1"
        reduce_max_node = om.get_node(reduce_max_name)
        # Safety check: skip model variants without this node.
        if reduce_max_node is None:
            continue
        next_nodes = om.get_next_nodes(reduce_max_node)
        if not next_nodes:
            continue
        next_node = next_nodes[0]
        if next_node.op_type != "Sub":
            continue
        name_prefix = '/'.join(reduce_max_name.split('/')[:-1])
        matmul_name = f"{name_prefix}/identity_MatMul"
        # 1x1 identity weight: multiplying by it leaves the values unchanged.
        matmul_init = om.create_initializer(matmul_name + "_B",
                                            np.eye(1, dtype=np.float32))
        matmul_node = om.create_node(
            "MatMul",
            matmul_name,
            [reduce_max_node.outputs[0], matmul_init.name],
            [f"{matmul_name}_output_0"],
            index=reduce_max_node.index + 1
        )
        next_node.set_input(1, matmul_node.outputs[0])


def main():
    """Run the optimization pipeline and save the rewritten model.

    Input/output paths may be overridden on the command line:
        python <script> <input.onnx> <output.onnx>
    With no arguments the hard-coded defaults below are used.
    """
    if len(sys.argv) > 2:
        input_onnx_path = sys.argv[1]
        output_onnx_path = sys.argv[2]
    else:
        input_onnx_path = "../weights/ground_deform_sim.onnx"
        output_onnx_path = "../weights_opt/ground_deform_opt.onnx"

    print(f"Loading ONNX model from {input_onnx_path}...")
    om = ONNXModifier(input_onnx_path)

    print("1. Optimizing Where nodes (Crucial for FP16 & MIGraphX)...")
    optimize_where_ndoes(om)

    print("2. Optimizing Transpose nodes...")
    optimize_transpose_nodes(om)

    # Currently disabled pipeline stages:
    # optmize_sin_cos_block(om)        # 3. sin/cos positional encoding
    # optimize_bidirect_attention(om)  # 4. bidirectional attention

    print(f"Saving optimized model to {output_onnx_path}...")
    om.save(output_onnx_path, save_as_external_data=False)
    print("Optimization Done!")


if __name__ == "__main__":
    main()