Commit 7f4610a8 authored by wxchan, committed by Guolin Ke
Browse files

refine pmml.py (#179)

* add pmml to test

* refine pmml.py

* use ~n instead of -n-1

* change map to list comprehension

* fix check

* fix 'use ~n instead of -n-1'

* fix exception
parent 3f4ef95b
from __future__ import print_function # coding: utf-8
from builtins import map # pylint: disable = C0111, C0103
from builtins import next """convert LightGBM model to pmml"""
from decimal import Decimal from __future__ import absolute_import
import sys from sys import argv
import os from itertools import count
import traceback
import itertools
def get_value_string(line): def get_value_string(line):
return line[line.index('=') + 1:] return line[line.find('=') + 1:]
def get_array_strings(line): def get_array_strings(line):
return line[line.index('=') + 1:].split() return get_value_string(line).split()
def get_array_ints(line): def get_array_ints(line):
return list(map(int, line[line.index('=') + 1:].split())) return [int(token) for token in get_array_strings(line)]
def get_field_name(node_id, prev_node_idx, is_child): def get_field_name(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id - 1] if is_child else prev_node_idx idx = leaf_parent[node_id] if is_child else prev_node_idx
return feature_names[split_feature[idx]] return feature_names[split_feature[idx]]
def get_threshold(node_id, prev_node_idx, is_child): def get_threshold(node_id, prev_node_idx, is_child):
idx = leaf_parent[node_id - 1] if is_child else prev_node_idx idx = leaf_parent[node_id] if is_child else prev_node_idx
return threshold[idx] return threshold[idx]
def print_simple_predicate( def print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf):
tab_length,
node_id,
is_left_child,
prev_node_idx,
is_leaf,
pmml_out):
if is_left_child: if is_left_child:
op = 'equal' if decision_type[prev_node_idx] == 1 else 'lessOrEqual' op = 'equal' if decision_type[prev_node_idx] == 1 else 'lessOrEqual'
else: else:
op = 'notEqual' if decision_type[prev_node_idx] == 1 else 'greaterThan' op = 'notEqual' if decision_type[prev_node_idx] == 1 else 'greaterThan'
print('\t' * (tab_length + 1) + ("<SimplePredicate field=\"{0}\" " + " operator=\"{1}\" value=\"{2}\" />") .format( out_('\t' * (tab_len + 1) + ("<SimplePredicate field=\"{0}\" " + " operator=\"{1}\" value=\"{2}\" />").format(
get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf)), file=pmml_out) get_field_name(node_id, prev_node_idx, is_leaf), op, get_threshold(node_id, prev_node_idx, is_leaf)))
def print_nodes_pmml(**kwargs): def print_nodes_pmml(node_id, tab_len, is_left_child, prev_node_idx):
node_id = kwargs['node_id']
pmml_out = kwargs['out_file']
tab_len = kwargs['tab_length']
if node_id < 0: if node_id < 0:
node_id = -1 * node_id node_id = ~node_id
score = leaf_value[node_id - 1] score = leaf_value[node_id]
recordCount = leaf_count[node_id - 1] recordCount = leaf_count[node_id]
is_leaf = True is_leaf = True
else: else:
score = internal_value[node_id] score = internal_value[node_id]
recordCount = internal_count[node_id] recordCount = internal_count[node_id]
is_leaf = False is_leaf = False
print( out_('\t' * tab_len + ("<Node id=\"{0}\" score=\"{1}\" " + " recordCount=\"{2}\">").format(
'\t' * next(unique_id), score, recordCount))
tab_len + print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf)
(
"<Node id=\"{0}\" score=\"{1}\" " +
" recordCount=\"{2}\">").format(
next(unique_id),
score,
recordCount),
file=pmml_out)
print_simple_predicate(
tab_len,
node_id,
kwargs['is_left_child'],
kwargs['prev_node_idx'],
is_leaf,
pmml_out)
if not is_leaf: if not is_leaf:
print_nodes_pmml( print_nodes_pmml(left_child[node_id], tab_len + 1, True, node_id)
node_id=left_child[node_id], print_nodes_pmml(right_child[node_id], tab_len + 1, False, node_id)
tab_length=tab_len + 1, out_('\t' * tab_len + "</Node>")
is_left_child=True,
prev_node_idx=node_id,
out_file=pmml_out)
print_nodes_pmml(
node_id=right_child[node_id],
tab_length=tab_len + 1,
is_left_child=False,
prev_node_idx=node_id,
out_file=pmml_out)
print('\t' * tab_len + "</Node>", file=pmml_out)
# print out the pmml for a decision tree # print out the pmml for a decision tree
def print_pmml(pmml_out): def print_pmml():
# specify the objective as function name and binarySplit for # specify the objective as function name and binarySplit for
# splitCharacteristic because each node has 2 children # splitCharacteristic because each node has 2 children
print( out_("\t\t\t\t<TreeModel functionName=\"regression\" splitCharacteristic=\"binarySplit\">")
"\t\t\t\t<TreeModel functionName=\"regression\" splitCharacteristic=\"binarySplit\">", out_("\t\t\t\t\t<MiningSchema>")
file=pmml_out)
print("\t\t\t\t\t<MiningSchema>", file=pmml_out)
# list each feature name as a mining field, and treat all outliers as is, # list each feature name as a mining field, and treat all outliers as is,
# unless specified # unless specified
for feature in feature_names: for feature in feature_names:
print( out_("\t\t\t\t\t\t<MiningField name=\"%s\"/>" % (feature))
"\t\t\t\t\t\t<MiningField name=\"%s\"/>" % out_("\t\t\t\t\t</MiningSchema>")
(feature), file=pmml_out)
print("\t\t\t\t\t</MiningSchema>", file=pmml_out)
# begin printing out the decision tree # begin printing out the decision tree
print("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format( out_("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
next(unique_id), internal_value[0], internal_count[0]), file=pmml_out) next(unique_id), internal_value[0], internal_count[0]))
print("\t\t\t\t\t\t<True/>", file=pmml_out) out_("\t\t\t\t\t\t<True/>")
print_nodes_pmml( print_nodes_pmml(left_child[0], 6, True, 0)
node_id=left_child[0], print_nodes_pmml(right_child[0], 6, False, 0)
tab_length=6, out_("\t\t\t\t\t</Node>")
is_left_child=True, out_("\t\t\t\t</TreeModel>")
prev_node_idx=0,
out_file=pmml_out)
print_nodes_pmml( if len(argv) != 2:
node_id=right_child[0], raise ValueError('usage: pmml.py <input model file>')
tab_length=6,
is_left_child=False,
prev_node_idx=0,
out_file=pmml_out)
print("\t\t\t\t\t</Node>", file=pmml_out)
print("\t\t\t\t</TreeModel>", file=pmml_out)
if len(sys.argv) != 2:
print('usage: pmml.py <input model file>')
sys.exit(0)
# open the model file and then process it # open the model file and then process it
with open(sys.argv[1], 'r') as model_in: with open(argv[1], 'r') as model_in:
model_content = [l for l in model_in.read().splitlines() if l] # ignore first 6 and empty lines
model_content = iter([line for line in model_in.read().splitlines() if line][6:])
objective = get_value_string(model_content[4]) feature_names = get_array_strings(next(model_content))
sigmoid = Decimal(get_value_string(model_content[5])) segment_id = count(1)
feature_names = get_array_strings(model_content[6])
model_content = model_content[7:]
segment_id = 1
with open('LightGBM_pmml.xml', 'w') as pmml_out: with open('LightGBM_pmml.xml', 'w') as pmml_out:
print( def out_(string):
pmml_out.write(string + '\n')
out_(
"<PMML version=\"4.3\" \n" + "<PMML version=\"4.3\" \n" +
"\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" + "\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
"\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" + "\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
"\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\"" + "\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\">")
">", out_("\t<Header copyright=\"Microsoft\">")
file=pmml_out) out_("\t\t<Application name=\"LightGBM\"/>")
print("\t<Header copyright=\"Microsoft\">", file=pmml_out) out_("\t</Header>")
print("\t\t<Application name=\"LightGBM\"/>", file=pmml_out)
print("\t</Header>", file=pmml_out)
# print out data dictionary entries for each column # print out data dictionary entries for each column
print( out_("\t<DataDictionary numberOfFields=\"%d\">" % len(feature_names))
"\t<DataDictionary numberOfFields=\"%d\">" %
len(feature_names), file=pmml_out)
# not adding any interval definition, all values are currently # not adding any interval definition, all values are currently
# valid # valid
for feature in feature_names: for feature in feature_names:
print( out_("\t\t<DataField name=\"" + feature + "\" optype=\"continuous\" dataType=\"double\"/>")
"\t\t<DataField name=\"" + out_("\t</DataDictionary>")
feature + out_("\t<MiningModel functionName=\"regression\">")
"\" optype=\"continuous\" dataType=\"double\"/>", out_("\t\t<MiningSchema>")
file=pmml_out)
print("\t</DataDictionary>", file=pmml_out)
print("\t<MiningModel functionName=\"regression\">", file=pmml_out)
print("\t\t<MiningSchema>", file=pmml_out)
# list each feature name as a mining field, and treat all outliers # list each feature name as a mining field, and treat all outliers
# as is, unless specified # as is, unless specified
for feature in feature_names: for feature in feature_names:
print( out_("\t\t\t<MiningField name=\"%s\"/>" % (feature))
"\t\t\t<MiningField name=\"%s\"/>" % out_("\t\t</MiningSchema>")
(feature), file=pmml_out) out_("\t\t<Segmentation multipleModelMethod=\"sum\">")
print("\t\t</MiningSchema>", file=pmml_out)
print(
"\t\t<Segmentation multipleModelMethod=\"sum\">",
file=pmml_out)
# read each array that contains pertinent information for the pmml # read each array that contains pertinent information for the pmml
# these arrays will be used to recreate the traverse the decision # these arrays will be used to recreate the traverse the decision tree
# tree while True:
model_content = iter(model_content) tree_start = next(model_content, '')
tree_start = next(model_content) if not tree_start.startswith('Tree'):
while tree_start[:4] == 'Tree': break
print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out) out_("\t\t\t<Segment id=\"%d\">" % next(segment_id))
print("\t\t\t\t<True/>", file=pmml_out) out_("\t\t\t\t<True/>")
tree_no = tree_start[5:] tree_no = tree_start[5:]
num_leaves = int(get_value_string(next(model_content))) num_leaves = int(get_value_string(next(model_content)))
split_feature = get_array_ints(next(model_content)) split_feature = get_array_ints(next(model_content))
split_gain = next(model_content) split_gain = next(model_content) # unused
threshold = get_array_strings(next(model_content)) threshold = get_array_strings(next(model_content))
decision_type = get_array_ints(next(model_content)) decision_type = get_array_ints(next(model_content))
left_child = get_array_ints(next(model_content)) left_child = get_array_ints(next(model_content))
...@@ -197,12 +136,10 @@ with open('LightGBM_pmml.xml', 'w') as pmml_out: ...@@ -197,12 +136,10 @@ with open('LightGBM_pmml.xml', 'w') as pmml_out:
leaf_count = get_array_strings(next(model_content)) leaf_count = get_array_strings(next(model_content))
internal_value = get_array_strings(next(model_content)) internal_value = get_array_strings(next(model_content))
internal_count = get_array_strings(next(model_content)) internal_count = get_array_strings(next(model_content))
tree_start = next(model_content) unique_id = count(1)
unique_id = itertools.count(1) print_pmml()
print_pmml(pmml_out) out_("\t\t\t</Segment>")
print("\t\t\t</Segment>", file=pmml_out)
segment_id += 1 out_("\t\t</Segmentation>")
out_("\t</MiningModel>")
print("\t\t</Segmentation>", file=pmml_out) out_("</PMML>")
print("\t</MiningModel>", file=pmml_out)
print("</PMML>", file=pmml_out)
...@@ -48,6 +48,8 @@ class TestBasic(unittest.TestCase): ...@@ -48,6 +48,8 @@ class TestBasic(unittest.TestCase):
self.assertEqual(len(pred_from_matr), len(pred_from_model_file)) self.assertEqual(len(pred_from_matr), len(pred_from_model_file))
for preds in zip(pred_from_matr, pred_from_model_file): for preds in zip(pred_from_matr, pred_from_model_file):
self.assertEqual(*preds) self.assertEqual(*preds)
# check pmml
os.system('python ../../pmml/pmml.py model.txt')
print("----------------------------------------------------------------------") print("----------------------------------------------------------------------")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment