# coding: utf-8 # pylint: disable = C0111, C0103 """convert LightGBM model to pmml""" from __future__ import absolute_import from sys import argv from itertools import count def get_value_string(line): return line[line.find('=') + 1:] def get_array_strings(line): return get_value_string(line).split() def get_array_ints(line): return [int(token) for token in get_array_strings(line)] def get_field_name(node_id, prev_node_idx, is_child): idx = leaf_parent[node_id] if is_child else prev_node_idx return feature_names[split_feature[idx]] def get_threshold(node_id, prev_node_idx, is_child): idx = leaf_parent[node_id] if is_child else prev_node_idx return threshold[idx] def print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf): if is_left_child: op = 'equal' if decision_type[prev_node_idx] == 1 else 'lessOrEqual' else: op = 'notEqual' if decision_type[prev_node_idx] == 1 else 'greaterThan' out_('\t' * (tab_len + 1) + ("").format( get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf))) def print_nodes_pmml(node_id, tab_len, is_left_child, prev_node_idx): if node_id < 0: node_id = ~node_id score = leaf_value[node_id] recordCount = leaf_count[node_id] is_leaf = True else: score = internal_value[node_id] recordCount = internal_count[node_id] is_leaf = False out_('\t' * tab_len + ("").format( next(unique_id), score, recordCount)) print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf) if not is_leaf: print_nodes_pmml(left_child[node_id], tab_len + 1, True, node_id) print_nodes_pmml(right_child[node_id], tab_len + 1, False, node_id) out_('\t' * tab_len + "") # print out the pmml for a decision tree def print_pmml(): # specify the objective as function name and binarySplit for # splitCharacteristic because each node has 2 children out_("\t\t\t\t") out_("\t\t\t\t\t") # list each feature name as a mining field, and treat all outliers as is, # unless specified for feature in feature_names: out_("\t\t\t\t\t\t" % (feature)) out_("\t\t\t\t\t") # begin printing out the decision tree out_("\t\t\t\t\t".format( next(unique_id), internal_value[0], internal_count[0])) out_("\t\t\t\t\t\t") print_nodes_pmml(left_child[0], 6, True, 0) print_nodes_pmml(right_child[0], 6, False, 0) out_("\t\t\t\t\t") out_("\t\t\t\t") if len(argv) != 2: raise ValueError('usage: pmml.py ') # open the model file and then process it with open(argv[1], 'r') as model_in: # ignore first 6 and empty lines model_content = iter([line for line in model_in.read().splitlines() if line][6:]) feature_names = get_array_strings(next(model_content)) feature_infos = get_array_strings(next(model_content)) segment_id = count(1) with open('LightGBM_pmml.xml', 'w') as pmml_out: def out_(string): pmml_out.write(string + '\n') out_( "") out_("\t
") out_("\t\t") out_("\t
") # print out data dictionary entries for each column out_("\t" % len(feature_names)) # not adding any interval definition, all values are currently # valid for feature in feature_names: out_("\t\t") out_("\t") out_("\t") out_("\t\t") # list each feature name as a mining field, and treat all outliers # as is, unless specified for feature in feature_names: out_("\t\t\t" % (feature)) out_("\t\t") out_("\t\t") # read each array that contains pertinent information for the pmml # these arrays will be used to recreate the traverse the decision tree while True: tree_start = next(model_content, '') if not tree_start.startswith('Tree'): break out_("\t\t\t" % next(segment_id)) out_("\t\t\t\t") tree_no = tree_start[5:] num_leaves = int(get_value_string(next(model_content))) split_feature = get_array_ints(next(model_content)) split_gain = next(model_content) # unused threshold = get_array_strings(next(model_content)) decision_type = get_array_ints(next(model_content)) left_child = get_array_ints(next(model_content)) right_child = get_array_ints(next(model_content)) leaf_parent = get_array_ints(next(model_content)) leaf_value = get_array_strings(next(model_content)) leaf_count = get_array_strings(next(model_content)) internal_value = get_array_strings(next(model_content)) internal_count = get_array_strings(next(model_content)) shrinkage = get_value_string(next(model_content)) has_categorical = get_value_string(next(model_content)) unique_id = count(1) print_pmml() out_("\t\t\t") out_("\t\t") out_("\t") out_("
")