Commit 81a6a442, authored by Qiwei Ye, committed by GitHub
Browse files

Merge pull request #172 from rmhasan/master

Removing the string to decimal conversion for float values
parents a87af879 fa15332e
from __future__ import print_function from __future__ import print_function
from builtins import map
from builtins import next
from decimal import Decimal from decimal import Decimal
import sys import sys
import os import os
import traceback import traceback
import itertools
def unique_id():
    """Return the current node id and advance the module-level counter.

    The counter lives in the global ``unique_node_id``. The value held
    before the call is returned, and the global is then incremented by one.

    Returns:
        The value of ``unique_node_id`` before it was incremented, or 0 if
        the global had not been assigned yet.
    """
    global unique_node_id
    try:
        nid = unique_node_id
    except NameError:
        # The counter was never initialised by the caller; start at 0
        # instead of failing on the very first call.
        nid = 0
    unique_node_id = nid + 1
    return nid
def get_value_string(line): def get_value_string(line):
...@@ -22,11 +18,7 @@ def get_array_strings(line): ...@@ -22,11 +18,7 @@ def get_array_strings(line):
def get_array_ints(line):
    """Parse the whitespace-separated integers that follow the first '='.

    Args:
        line: A ``name=v1 v2 ...`` string.

    Returns:
        A list of ints (empty when nothing follows the '='). A real list is
        returned, not a lazy ``map`` object, so the result can be indexed
        and iterated more than once.

    Raises:
        ValueError: If ``line`` has no '=' or a token is not an integer.
    """
    values = line[line.index('=') + 1:]
    return [int(token) for token in values.split()]
def get_array_floats(line):
    """Parse the whitespace-separated decimal numbers that follow the first '='.

    Each token is converted with ``Decimal`` rather than ``float``, so the
    textual value is kept exactly as written.

    Args:
        line: A ``name=v1 v2 ...`` string.

    Returns:
        A list of ``Decimal`` values (empty when nothing follows the '=').
        A real list is returned, not a lazy ``map`` object, so the result
        can be indexed and iterated more than once.

    Raises:
        ValueError: If ``line`` has no '='.
        decimal.InvalidOperation: If a token is not a valid decimal literal.
    """
    values = line[line.index('=') + 1:]
    return [Decimal(token) for token in values.split()]
def get_field_name(node_id, prev_node_idx, is_child): def get_field_name(node_id, prev_node_idx, is_child):
...@@ -73,7 +65,7 @@ def print_nodes_pmml(**kwargs): ...@@ -73,7 +65,7 @@ def print_nodes_pmml(**kwargs):
( (
"<Node id=\"{0}\" score=\"{1}\" " + "<Node id=\"{0}\" score=\"{1}\" " +
" recordCount=\"{2}\">").format( " recordCount=\"{2}\">").format(
unique_id(), next(unique_id),
score, score,
recordCount), recordCount),
file=pmml_out) file=pmml_out)
...@@ -116,8 +108,8 @@ def print_pmml(pmml_out): ...@@ -116,8 +108,8 @@ def print_pmml(pmml_out):
(feature), file=pmml_out) (feature), file=pmml_out)
print("\t\t\t\t\t</MiningSchema>", file=pmml_out) print("\t\t\t\t\t</MiningSchema>", file=pmml_out)
# begin printing out the decision tree # begin printing out the decision tree
print("\t\t\t\t\t<Node id=\"%d\" score=\"%s\" recordCount=\"%d\">" % print("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
(unique_id(), internal_value[0], internal_count[0]), file=pmml_out) next(unique_id), internal_value[0], internal_count[0]), file=pmml_out)
print("\t\t\t\t\t\t<True/>", file=pmml_out) print("\t\t\t\t\t\t<True/>", file=pmml_out)
print_nodes_pmml( print_nodes_pmml(
node_id=left_child[0], node_id=left_child[0],
...@@ -139,19 +131,16 @@ if len(sys.argv) != 2: ...@@ -139,19 +131,16 @@ if len(sys.argv) != 2:
sys.exit(0) sys.exit(0)
# open the model file and then process it # open the model file and then process it
try: with open(sys.argv[1], 'r') as model_in:
with open(sys.argv[1]) as model_in: model_content = [l for l in model_in.read().splitlines() if l]
model_content = filter(
lambda line: line != '', objective = get_value_string(model_content[4])
model_in.read().strip().split('\n')) sigmoid = Decimal(get_value_string(model_content[5]))
objective = get_value_string(model_content[4]) feature_names = get_array_strings(model_content[6])
sigmoid = Decimal(get_value_string(model_content[5])) model_content = model_content[7:]
feature_names = get_array_strings(model_content[6]) segment_id = 1
model_content = model_content[7:]
line_no = 0 with open('LightGBM_pmml.xml', 'w') as pmml_out:
segment_id = 1
with open('LightGBM_pmml.xml', 'w') as pmml_out:
print( print(
"<PMML version=\"4.3\" \n" + "<PMML version=\"4.3\" \n" +
"\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" + "\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
...@@ -190,29 +179,30 @@ try: ...@@ -190,29 +179,30 @@ try:
# read each array that contains pertinent information for the pmml # read each array that contains pertinent information for the pmml
# these arrays will be used to recreate the traverse the decision # these arrays will be used to recreate the traverse the decision
# tree # tree
while model_content[line_no][:4] == 'Tree': model_content = iter(model_content)
tree_start = next(model_content)
while tree_start[:4] == 'Tree':
print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out) print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out)
print("\t\t\t\t<True/>", file=pmml_out) print("\t\t\t\t<True/>", file=pmml_out)
tree_no = model_content[line_no][5:] tree_no = tree_start[5:]
num_leaves = int(get_value_string(model_content[line_no + 1])) num_leaves = int(get_value_string(next(model_content)))
split_feature = get_array_ints(model_content[line_no + 2]) split_feature = get_array_ints(next(model_content))
threshold = get_array_floats(model_content[line_no + 4]) split_gain = next(model_content)
decision_type = get_array_ints(model_content[line_no + 5]) threshold = get_array_strings(next(model_content))
left_child = get_array_ints(model_content[line_no + 6]) decision_type = get_array_ints(next(model_content))
right_child = get_array_ints(model_content[line_no + 7]) left_child = get_array_ints(next(model_content))
leaf_parent = get_array_ints(model_content[line_no + 8]) right_child = get_array_ints(next(model_content))
leaf_value = get_array_floats(model_content[line_no + 9]) leaf_parent = get_array_ints(next(model_content))
leaf_count = get_array_ints(model_content[line_no + 10]) leaf_value = get_array_strings(next(model_content))
internal_value = get_array_floats(model_content[line_no + 11]) leaf_count = get_array_strings(next(model_content))
internal_count = get_array_ints(model_content[line_no + 12]) internal_value = get_array_strings(next(model_content))
unique_node_id = 0 internal_count = get_array_strings(next(model_content))
tree_start = next(model_content)
unique_id = itertools.count(1)
print_pmml(pmml_out) print_pmml(pmml_out)
print("\t\t\t</Segment>", file=pmml_out) print("\t\t\t</Segment>", file=pmml_out)
line_no += 13
segment_id += 1 segment_id += 1
print("\t\t</Segmentation>", file=pmml_out) print("\t\t</Segmentation>", file=pmml_out)
print("\t</MiningModel>", file=pmml_out) print("\t</MiningModel>", file=pmml_out)
print("</PMML>", file=pmml_out) print("</PMML>", file=pmml_out)
except Exception as ioex:
print(ioex)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment