Commit 81a6a442 authored by Qiwei Ye's avatar Qiwei Ye Committed by GitHub
Browse files

Merge pull request #172 from rmhasan/master

Removing the string to decimal conversion for float values
parents a87af879 fa15332e
from __future__ import print_function
from builtins import map
from builtins import next
from decimal import Decimal
import sys
import os
import traceback
def unique_id():
global unique_node_id
nid = unique_node_id
unique_node_id += 1
return nid
import itertools
def get_value_string(line):
......@@ -22,11 +18,7 @@ def get_array_strings(line):
def get_array_ints(line):
return map(lambda x: int(x), line[line.index('=') + 1:].split())
def get_array_floats(line):
return map(lambda x: Decimal(x), line[line.index('=') + 1:].split())
return list(map(int, line[line.index('=') + 1:].split()))
def get_field_name(node_id, prev_node_idx, is_child):
......@@ -73,7 +65,7 @@ def print_nodes_pmml(**kwargs):
(
"<Node id=\"{0}\" score=\"{1}\" " +
" recordCount=\"{2}\">").format(
unique_id(),
next(unique_id),
score,
recordCount),
file=pmml_out)
......@@ -116,8 +108,8 @@ def print_pmml(pmml_out):
(feature), file=pmml_out)
print("\t\t\t\t\t</MiningSchema>", file=pmml_out)
# begin printing out the decision tree
print("\t\t\t\t\t<Node id=\"%d\" score=\"%s\" recordCount=\"%d\">" %
(unique_id(), internal_value[0], internal_count[0]), file=pmml_out)
print("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
next(unique_id), internal_value[0], internal_count[0]), file=pmml_out)
print("\t\t\t\t\t\t<True/>", file=pmml_out)
print_nodes_pmml(
node_id=left_child[0],
......@@ -139,19 +131,16 @@ if len(sys.argv) != 2:
sys.exit(0)
# open the model file and then process it
try:
with open(sys.argv[1]) as model_in:
model_content = filter(
lambda line: line != '',
model_in.read().strip().split('\n'))
objective = get_value_string(model_content[4])
sigmoid = Decimal(get_value_string(model_content[5]))
feature_names = get_array_strings(model_content[6])
model_content = model_content[7:]
line_no = 0
segment_id = 1
with open('LightGBM_pmml.xml', 'w') as pmml_out:
with open(sys.argv[1], 'r') as model_in:
model_content = [l for l in model_in.read().splitlines() if l]
objective = get_value_string(model_content[4])
sigmoid = Decimal(get_value_string(model_content[5]))
feature_names = get_array_strings(model_content[6])
model_content = model_content[7:]
segment_id = 1
with open('LightGBM_pmml.xml', 'w') as pmml_out:
print(
"<PMML version=\"4.3\" \n" +
"\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
......@@ -190,29 +179,30 @@ try:
# read each array that contains pertinent information for the pmml
# these arrays will be used to recreate the traverse the decision
# tree
while model_content[line_no][:4] == 'Tree':
model_content = iter(model_content)
tree_start = next(model_content)
while tree_start[:4] == 'Tree':
print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out)
print("\t\t\t\t<True/>", file=pmml_out)
tree_no = model_content[line_no][5:]
num_leaves = int(get_value_string(model_content[line_no + 1]))
split_feature = get_array_ints(model_content[line_no + 2])
threshold = get_array_floats(model_content[line_no + 4])
decision_type = get_array_ints(model_content[line_no + 5])
left_child = get_array_ints(model_content[line_no + 6])
right_child = get_array_ints(model_content[line_no + 7])
leaf_parent = get_array_ints(model_content[line_no + 8])
leaf_value = get_array_floats(model_content[line_no + 9])
leaf_count = get_array_ints(model_content[line_no + 10])
internal_value = get_array_floats(model_content[line_no + 11])
internal_count = get_array_ints(model_content[line_no + 12])
unique_node_id = 0
tree_no = tree_start[5:]
num_leaves = int(get_value_string(next(model_content)))
split_feature = get_array_ints(next(model_content))
split_gain = next(model_content)
threshold = get_array_strings(next(model_content))
decision_type = get_array_ints(next(model_content))
left_child = get_array_ints(next(model_content))
right_child = get_array_ints(next(model_content))
leaf_parent = get_array_ints(next(model_content))
leaf_value = get_array_strings(next(model_content))
leaf_count = get_array_strings(next(model_content))
internal_value = get_array_strings(next(model_content))
internal_count = get_array_strings(next(model_content))
tree_start = next(model_content)
unique_id = itertools.count(1)
print_pmml(pmml_out)
print("\t\t\t</Segment>", file=pmml_out)
line_no += 13
segment_id += 1
print("\t\t</Segmentation>", file=pmml_out)
print("\t</MiningModel>", file=pmml_out)
print("</PMML>", file=pmml_out)
except Exception as ioex:
print(ioex)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment