Commit fa15332e authored by Rakib Hasan's avatar Rakib Hasan
Browse files

Removing the string to decimal conversion for float

We dont need it we aren't doing any computation with
float values. We print out whatever values are
read from the LightGBM_model.txt as a string.
parent 1b7643ba
PMML Generator PMML Generator
============== ==============
The script pmml.py can be used to translate the LightGBM models, found in LightGBM_model.txt, to predictive model markup language (PMML). These models can then be imported by other analytics applications. The models that the language can describe includes decision trees. The specification of PMML can be found here at the Data Mining Group's [website](http://dmg.org/pmml/v4-3/GeneralStructure.html). The script pmml.py can be used to translate the LightGBM models, found in LightGBM_model.txt, to predictive model markup language (PMML). These models can then be imported by other analytics applications. The models that the language can describe includes decision trees. The specification of PMML can be found here at the Data Mining Group's [website](http://dmg.org/pmml/v4-3/GeneralStructure.html).
In order to generate pmml files do the following steps. In order to generate pmml files do the following steps.
``` ```
......
from __future__ import print_function from __future__ import print_function
from builtins import map
from builtins import next
from decimal import Decimal from decimal import Decimal
import sys import sys
import os import os
import traceback import traceback
import itertools
def unique_id():
global unique_node_id
nid = unique_node_id
unique_node_id += 1
return nid
def get_value_string(line): def get_value_string(line):
...@@ -22,11 +18,7 @@ def get_array_strings(line): ...@@ -22,11 +18,7 @@ def get_array_strings(line):
def get_array_ints(line): def get_array_ints(line):
return map(lambda x: int(x), line[line.index('=') + 1:].split()) return list(map(int, line[line.index('=') + 1:].split()))
def get_array_floats(line):
return map(lambda x: Decimal(x), line[line.index('=') + 1:].split())
def get_field_name(node_id, prev_node_idx, is_child): def get_field_name(node_id, prev_node_idx, is_child):
...@@ -73,7 +65,7 @@ def print_nodes_pmml(**kwargs): ...@@ -73,7 +65,7 @@ def print_nodes_pmml(**kwargs):
( (
"<Node id=\"{0}\" score=\"{1}\" " + "<Node id=\"{0}\" score=\"{1}\" " +
" recordCount=\"{2}\">").format( " recordCount=\"{2}\">").format(
unique_id(), next(unique_id),
score, score,
recordCount), recordCount),
file=pmml_out) file=pmml_out)
...@@ -116,8 +108,8 @@ def print_pmml(pmml_out): ...@@ -116,8 +108,8 @@ def print_pmml(pmml_out):
(feature), file=pmml_out) (feature), file=pmml_out)
print("\t\t\t\t\t</MiningSchema>", file=pmml_out) print("\t\t\t\t\t</MiningSchema>", file=pmml_out)
# begin printing out the decision tree # begin printing out the decision tree
print("\t\t\t\t\t<Node id=\"%d\" score=\"%s\" recordCount=\"%d\">" % print("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
(unique_id(), internal_value[0], internal_count[0]), file=pmml_out) next(unique_id), internal_value[0], internal_count[0]), file=pmml_out)
print("\t\t\t\t\t\t<True/>", file=pmml_out) print("\t\t\t\t\t\t<True/>", file=pmml_out)
print_nodes_pmml( print_nodes_pmml(
node_id=left_child[0], node_id=left_child[0],
...@@ -139,80 +131,78 @@ if len(sys.argv) != 2: ...@@ -139,80 +131,78 @@ if len(sys.argv) != 2:
sys.exit(0) sys.exit(0)
# open the model file and then process it # open the model file and then process it
try: with open(sys.argv[1], 'r') as model_in:
with open(sys.argv[1]) as model_in: model_content = [l for l in model_in.read().splitlines() if l]
model_content = filter(
lambda line: line != '', objective = get_value_string(model_content[4])
model_in.read().strip().split('\n')) sigmoid = Decimal(get_value_string(model_content[5]))
objective = get_value_string(model_content[4]) feature_names = get_array_strings(model_content[6])
sigmoid = Decimal(get_value_string(model_content[5])) model_content = model_content[7:]
feature_names = get_array_strings(model_content[6]) segment_id = 1
model_content = model_content[7:]
line_no = 0 with open('LightGBM_pmml.xml', 'w') as pmml_out:
segment_id = 1 print(
"<PMML version=\"4.3\" \n" +
with open('LightGBM_pmml.xml', 'w') as pmml_out: "\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
print( "\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
"<PMML version=\"4.3\" \n" + "\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\"" +
"\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" + ">",
"\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" + file=pmml_out)
"\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\"" + print("\t<Header copyright=\"Microsoft\">", file=pmml_out)
">", print("\t\t<Application name=\"LightGBM\"/>", file=pmml_out)
file=pmml_out) print("\t</Header>", file=pmml_out)
print("\t<Header copyright=\"Microsoft\">", file=pmml_out) # print out data dictionary entries for each column
print("\t\t<Application name=\"LightGBM\"/>", file=pmml_out) print(
print("\t</Header>", file=pmml_out) "\t<DataDictionary numberOfFields=\"%d\">" %
# print out data dictionary entries for each column len(feature_names), file=pmml_out)
print( # not adding any interval definition, all values are currently
"\t<DataDictionary numberOfFields=\"%d\">" % # valid
len(feature_names), file=pmml_out) for feature in feature_names:
# not adding any interval definition, all values are currently print(
# valid "\t\t<DataField name=\"" +
for feature in feature_names: feature +
print( "\" optype=\"continuous\" dataType=\"double\"/>",
"\t\t<DataField name=\"" + file=pmml_out)
feature + print("\t</DataDictionary>", file=pmml_out)
"\" optype=\"continuous\" dataType=\"double\"/>", print("\t<MiningModel functionName=\"regression\">", file=pmml_out)
file=pmml_out) print("\t\t<MiningSchema>", file=pmml_out)
print("\t</DataDictionary>", file=pmml_out) # list each feature name as a mining field, and treat all outliers
print("\t<MiningModel functionName=\"regression\">", file=pmml_out) # as is, unless specified
print("\t\t<MiningSchema>", file=pmml_out) for feature in feature_names:
# list each feature name as a mining field, and treat all outliers print(
# as is, unless specified "\t\t\t<MiningField name=\"%s\"/>" %
for feature in feature_names: (feature), file=pmml_out)
print( print("\t\t</MiningSchema>", file=pmml_out)
"\t\t\t<MiningField name=\"%s\"/>" % print(
(feature), file=pmml_out) "\t\t<Segmentation multipleModelMethod=\"sum\">",
print("\t\t</MiningSchema>", file=pmml_out) file=pmml_out)
print( # read each array that contains pertinent information for the pmml
"\t\t<Segmentation multipleModelMethod=\"sum\">", # these arrays will be used to recreate the traverse the decision
file=pmml_out) # tree
# read each array that contains pertinent information for the pmml model_content = iter(model_content)
# these arrays will be used to recreate the traverse the decision tree_start = next(model_content)
# tree while tree_start[:4] == 'Tree':
while model_content[line_no][:4] == 'Tree': print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out)
print("\t\t\t<Segment id=\"%d\">" % segment_id, file=pmml_out) print("\t\t\t\t<True/>", file=pmml_out)
print("\t\t\t\t<True/>", file=pmml_out) tree_no = tree_start[5:]
tree_no = model_content[line_no][5:] num_leaves = int(get_value_string(next(model_content)))
num_leaves = int(get_value_string(model_content[line_no + 1])) split_feature = get_array_ints(next(model_content))
split_feature = get_array_ints(model_content[line_no + 2]) split_gain = next(model_content)
threshold = get_array_floats(model_content[line_no + 4]) threshold = get_array_strings(next(model_content))
decision_type = get_array_ints(model_content[line_no + 5]) decision_type = get_array_ints(next(model_content))
left_child = get_array_ints(model_content[line_no + 6]) left_child = get_array_ints(next(model_content))
right_child = get_array_ints(model_content[line_no + 7]) right_child = get_array_ints(next(model_content))
leaf_parent = get_array_ints(model_content[line_no + 8]) leaf_parent = get_array_ints(next(model_content))
leaf_value = get_array_floats(model_content[line_no + 9]) leaf_value = get_array_strings(next(model_content))
leaf_count = get_array_ints(model_content[line_no + 10]) leaf_count = get_array_strings(next(model_content))
internal_value = get_array_floats(model_content[line_no + 11]) internal_value = get_array_strings(next(model_content))
internal_count = get_array_ints(model_content[line_no + 12]) internal_count = get_array_strings(next(model_content))
unique_node_id = 0 tree_start = next(model_content)
print_pmml(pmml_out) unique_id = itertools.count(1)
print("\t\t\t</Segment>", file=pmml_out) print_pmml(pmml_out)
line_no += 13 print("\t\t\t</Segment>", file=pmml_out)
segment_id += 1 segment_id += 1
print("\t\t</Segmentation>", file=pmml_out) print("\t\t</Segmentation>", file=pmml_out)
print("\t</MiningModel>", file=pmml_out) print("\t</MiningModel>", file=pmml_out)
print("</PMML>", file=pmml_out) print("</PMML>", file=pmml_out)
except Exception as ioex:
print(ioex)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment