from __future__ import print_function
from decimal import Decimal
import sys
import os
import traceback
def unique_id():
global unique_node_id
nid = unique_node_id
unique_node_id += 1
return nid
def get_value_string(line):
return line[line.index('=') + 1:]
def get_array_strings(line):
return line[line.index('=') + 1:].split()
def get_array_ints(line):
return map(lambda x: int(x), line[line.index('=') + 1:].split())
def get_array_floats(line):
return map(lambda x: Decimal(x), line[line.index('=') + 1:].split())
def get_field_name(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id - 1] if is_child else prev_node_idx
return feature_names[split_feature[idx]]
def get_threshold(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id - 1] if is_child else prev_node_idx
return threshold[idx]
def print_simple_predicate(
tab_length,
node_id,
is_left_child,
prev_node_idx,
is_leaf,
pmml_out):
if is_left_child:
op = 'equal' if decision_type[prev_node_idx] == 1 else 'lessOrEqual'
else:
op = 'notEqual' if decision_type[prev_node_idx] == 1 else 'greaterThan'
print('\t' * (tab_length + 1) + ("") .format(
get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf)), file=pmml_out)
def print_nodes_pmml(**kwargs):
node_id = kwargs['node_id']
pmml_out = kwargs['out_file']
tab_len = kwargs['tab_length']
if node_id < 0:
node_id = -1 * node_id
score = leaf_value[node_id - 1]
recordCount = leaf_count[node_id - 1]
is_leaf = True
else:
score = internal_value[node_id]
recordCount = internal_count[node_id]
is_leaf = False
print(
'\t' *
tab_len +
(
"").format(
unique_id(),
score,
recordCount),
file=pmml_out)
print_simple_predicate(
tab_len,
node_id,
kwargs['is_left_child'],
kwargs['prev_node_idx'],
is_leaf,
pmml_out)
if not is_leaf:
print_nodes_pmml(
node_id=left_child[node_id],
tab_length=tab_len + 1,
is_left_child=True,
prev_node_idx=node_id,
out_file=pmml_out)
print_nodes_pmml(
node_id=right_child[node_id],
tab_length=tab_len + 1,
is_left_child=False,
prev_node_idx=node_id,
out_file=pmml_out)
print('\t' * tab_len + "", file=pmml_out)
# print out the pmml for a decision tree
def print_pmml(pmml_out):
# specify the objective as function name and binarySplit for
# splitCharacteristic because each node has 2 children
print(
"\t\t\t\t",
file=pmml_out)
print("\t\t\t\t\t", file=pmml_out)
# list each feature name as a mining field, and treat all outliers as is,
# unless specified
for feature in feature_names:
print(
"\t\t\t\t\t\t" %
(feature), file=pmml_out)
print("\t\t\t\t\t", file=pmml_out)
# begin printing out the decision tree
print("\t\t\t\t\t" %
(unique_id(), internal_value[0], internal_count[0]), file=pmml_out)
print("\t\t\t\t\t\t", file=pmml_out)
print_nodes_pmml(
node_id=left_child[0],
tab_length=6,
is_left_child=True,
prev_node_idx=0,
out_file=pmml_out)
print_nodes_pmml(
node_id=right_child[0],
tab_length=6,
is_left_child=False,
prev_node_idx=0,
out_file=pmml_out)
print("\t\t\t\t\t", file=pmml_out)
print("\t\t\t\t", file=pmml_out)
if len(sys.argv) != 2:
print('usage: pmml.py ')
sys.exit(0)
# open the model file and then process it
try:
with open(sys.argv[1]) as model_in:
model_content = filter(
lambda line: line != '',
model_in.read().strip().split('\n'))
objective = get_value_string(model_content[4])
sigmoid = Decimal(get_value_string(model_content[5]))
feature_names = get_array_strings(model_content[6])
model_content = model_content[7:]
line_no = 0
segment_id = 1
with open('LightGBM_pmml.xml', 'w') as pmml_out:
print(
"",
file=pmml_out)
print("\t", file=pmml_out)
print("\t\t", file=pmml_out)
print("\t", file=pmml_out)
# print out data dictionary entries for each column
print(
"\t" %
len(feature_names), file=pmml_out)
# not adding any interval definition, all values are currently
# valid
for feature in feature_names:
print(
"\t\t",
file=pmml_out)
print("\t", file=pmml_out)
print("\t", file=pmml_out)
print("\t\t", file=pmml_out)
# list each feature name as a mining field, and treat all outliers
# as is, unless specified
for feature in feature_names:
print(
"\t\t\t" %
(feature), file=pmml_out)
print("\t\t", file=pmml_out)
print(
"\t\t",
file=pmml_out)
# read each array that contains pertinent information for the pmml
# these arrays will be used to recreate the traverse the decision
# tree
while model_content[line_no][:4] == 'Tree':
print("\t\t\t" % segment_id, file=pmml_out)
print("\t\t\t\t", file=pmml_out)
tree_no = model_content[line_no][5:]
num_leaves = int(get_value_string(model_content[line_no + 1]))
split_feature = get_array_ints(model_content[line_no + 2])
threshold = get_array_floats(model_content[line_no + 4])
decision_type = get_array_ints(model_content[line_no + 5])
left_child = get_array_ints(model_content[line_no + 6])
right_child = get_array_ints(model_content[line_no + 7])
leaf_parent = get_array_ints(model_content[line_no + 8])
leaf_value = get_array_floats(model_content[line_no + 9])
leaf_count = get_array_ints(model_content[line_no + 10])
internal_value = get_array_floats(model_content[line_no + 11])
internal_count = get_array_ints(model_content[line_no + 12])
unique_node_id = 0
print_pmml(pmml_out)
print("\t\t\t", file=pmml_out)
line_no += 13
segment_id += 1
print("\t\t", file=pmml_out)
print("\t", file=pmml_out)
print("", file=pmml_out)
except Exception as ioex:
print(ioex)