pmml.py 5.53 KB
Newer Older
wxchan's avatar
wxchan committed
1
2
3
4
# coding: utf-8
# pylint: disable = C0111, C0103
"""convert LightGBM model to pmml"""
from __future__ import absolute_import

from itertools import count
from sys import argv
from xml.sax.saxutils import escape
8
9
10


def get_value_string(line):
    """Return the text that follows the first '=' in ``line``.

    When ``line`` contains no '=', ``str.find`` returns -1, the slice
    starts at index 0 and the whole line is returned unchanged.
    """
    return line[line.find('=') + 1:]
12
13
14


def get_array_strings(line):
    """Return the whitespace-separated tokens after the first '=' in ``line``.

    The value part is obtained with ``get_value_string`` and split with
    ``str.split()``, so runs of whitespace collapse and an empty value
    yields an empty list.
    """
    return get_value_string(line).split()
16
17
18


def get_array_ints(line):
    """Return the tokens after the first '=' in ``line`` as a list of ints.

    Tokens come from ``get_array_strings``; ``int()`` raises ValueError
    for any token that is not an integer literal.
    """
    return [int(token) for token in get_array_strings(line)]
20
21
22


def get_field_name(node_id, prev_node_idx, is_child):
    """Return the feature name tested by the split that leads to a node.

    Reads the module-level ``leaf_parent``, ``split_feature`` and
    ``feature_names`` arrays.

    node_id: index looked up in ``leaf_parent`` when ``is_child`` is true.
    prev_node_idx: split index used directly when ``is_child`` is false.
    is_child: selects between the two lookups above; the caller in this
        file passes its ``is_leaf`` flag here.
    """
    idx = leaf_parent[node_id] if is_child else prev_node_idx
    return feature_names[split_feature[idx]]


def get_threshold(node_id, prev_node_idx, is_child):
    """Return the threshold of the split that leads to a node.

    Reads the module-level ``leaf_parent`` and ``threshold`` arrays.

    node_id: index looked up in ``leaf_parent`` when ``is_child`` is true.
    prev_node_idx: split index used directly when ``is_child`` is false.
    is_child: selects between the two lookups above; the caller in this
        file passes its ``is_leaf`` flag here.
    """
    idx = leaf_parent[node_id] if is_child else prev_node_idx
    return threshold[idx]


wxchan's avatar
wxchan committed
32
def print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf):
    """Write the <SimplePredicate> element for one tree node via ``out_``.

    tab_len: indentation of the enclosing <Node>; the predicate is written
        one tab deeper.
    node_id, prev_node_idx, is_leaf: forwarded to ``get_field_name`` and
        ``get_threshold`` to find the split that leads to this node.
    is_left_child: a left child gets operator 'lessOrEqual', a right child
        gets 'greaterThan'.
    """
    op = 'lessOrEqual' if is_left_child else 'greaterThan'
    # The field name lands inside a double-quoted XML attribute, so markup
    # characters and the double quote itself must be escaped.
    field = escape(get_field_name(node_id, prev_node_idx, is_leaf), {'"': '&quot;'})
    value = get_threshold(node_id, prev_node_idx, is_leaf)
    # The double space before "operator" is kept so output stays
    # byte-identical to earlier versions for ordinary feature names.
    out_('\t' * (tab_len + 1) +
         '<SimplePredicate field="{0}"  operator="{1}" value="{2}" />'.format(field, op, value))
39
40


wxchan's avatar
wxchan committed
41
def print_nodes_pmml(node_id, tab_len, is_left_child, prev_node_idx):
    """Recursively write the <Node> element for ``node_id`` and its subtree.

    node_id: a negative value denotes a leaf whose index is ``~node_id``;
        otherwise it indexes the internal-node arrays.
    tab_len: number of tabs written before the <Node> tags.
    is_left_child: whether this node is the left child of its parent;
        forwarded to ``print_simple_predicate``.
    prev_node_idx: index of the parent split node.

    Reads the module-level ``leaf_value``, ``leaf_count``,
    ``internal_value``, ``internal_count``, ``left_child``, ``right_child``
    arrays and takes node ids from the ``unique_id`` counter.
    """
    is_leaf = node_id < 0
    if is_leaf:
        # leaves are stored as the bitwise complement of their index
        node_id = ~node_id
        score = leaf_value[node_id]
        record_count = leaf_count[node_id]
    else:
        score = internal_value[node_id]
        record_count = internal_count[node_id]
    indent = '\t' * tab_len
    # The double space before "recordCount" is kept so output stays
    # byte-identical to earlier versions.
    out_(indent + '<Node id="{0}" score="{1}"  recordCount="{2}">'.format(
        next(unique_id), score, record_count))
    print_simple_predicate(tab_len, node_id, is_left_child, prev_node_idx, is_leaf)
    if not is_leaf:
        # left subtree first, then right, both one level deeper
        print_nodes_pmml(left_child[node_id], tab_len + 1, True, node_id)
        print_nodes_pmml(right_child[node_id], tab_len + 1, False, node_id)
    out_(indent + "</Node>")
58
59
60


# print out the pmml for a decision tree
def print_pmml():
    """Write one <TreeModel> element for the tree currently held in the
    module-level arrays (``internal_value``, ``internal_count``,
    ``left_child``, ``right_child`` and those used by ``print_nodes_pmml``).

    All output goes through ``out_``; node ids are drawn from ``unique_id``.
    """
    # specify the objective as function name and binarySplit for
    # splitCharacteristic because each node has 2 children
    out_("\t\t\t\t<TreeModel functionName=\"regression\" splitCharacteristic=\"binarySplit\">")
    out_("\t\t\t\t\t<MiningSchema>")
    # list each feature name as a mining field, and treat all outliers as is,
    # unless specified; names are escaped because they are written inside a
    # double-quoted XML attribute
    for feature in feature_names:
        out_("\t\t\t\t\t\t<MiningField name=\"%s\"/>" % escape(feature, {'"': '&quot;'}))
    out_("\t\t\t\t\t</MiningSchema>")
    # begin printing out the decision tree: the root (index 0) has no
    # predicate of its own, so it is written with <True/>
    out_("\t\t\t\t\t<Node id=\"{0}\" score=\"{1}\" recordCount=\"{2}\">".format(
        next(unique_id), internal_value[0], internal_count[0]))
    out_("\t\t\t\t\t\t<True/>")
    print_nodes_pmml(left_child[0], 6, True, 0)
    print_nodes_pmml(right_child[0], 6, False, 0)
    out_("\t\t\t\t\t</Node>")
    out_("\t\t\t\t</TreeModel>")


# exactly one command-line argument is expected: the model file path
if len(argv) != 2:
    raise ValueError('usage: pmml.py <input model file>')

# open the model file and then process it
with open(argv[1], 'r') as model_in:
    # ignore first 6 and empty lines
    non_empty_lines = [line for line in model_in.read().splitlines() if line]
    model_content = iter(non_empty_lines[6:])

# the first remaining line is parsed as the feature name list; fail with
# a clear message instead of a bare StopIteration when the file is too short
feature_names_line = next(model_content, None)
if feature_names_line is None:
    raise ValueError('model file has no feature names line: {0}'.format(argv[1]))
feature_names = get_array_strings(feature_names_line)
# running id for the <Segment> elements, starting at 1
segment_id = count(1)
91
92

with open('LightGBM_pmml.xml', 'w') as pmml_out:
    def out_(string):
        """Write ``string`` followed by a newline to the PMML output file."""
        pmml_out.write(string + '\n')

    out_(
        "<PMML version=\"4.3\" \n" +
        "\t\txmlns=\"http://www.dmg.org/PMML-4_3\"\n" +
        "\t\txmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
        "\t\txsi:schemaLocation=\"http://www.dmg.org/PMML-4_3 http://dmg.org/pmml/v4-3/pmml-4-3.xsd\">")
    out_("\t<Header copyright=\"Microsoft\">")
    out_("\t\t<Application name=\"LightGBM\"/>")
    out_("\t</Header>")
    # print out data dictionary entries for each column
    out_("\t<DataDictionary numberOfFields=\"%d\">" % len(feature_names))
    # not adding any interval definition, all values are currently
    # valid; names are escaped because they are written inside a
    # double-quoted XML attribute
    for feature in feature_names:
        out_("\t\t<DataField name=\"" + escape(feature, {'"': '&quot;'}) +
             "\" optype=\"continuous\" dataType=\"double\"/>")
    out_("\t</DataDictionary>")
    out_("\t<MiningModel functionName=\"regression\">")
    out_("\t\t<MiningSchema>")
    # list each feature name as a mining field, and treat all outliers
    # as is, unless specified
    for feature in feature_names:
        out_("\t\t\t<MiningField name=\"%s\"/>" % escape(feature, {'"': '&quot;'}))
    out_("\t\t</MiningSchema>")
    out_("\t\t<Segmentation multipleModelMethod=\"sum\">")
    # read each array that contains pertinent information for the pmml;
    # these arrays are used to traverse and recreate the decision tree
    while True:
        # stop at end of input or at the first line that does not open a tree
        tree_start = next(model_content, '')
        if not tree_start.startswith('Tree'):
            break
        out_("\t\t\t<Segment id=\"%d\">" % next(segment_id))
        out_("\t\t\t\t<True/>")
        # the per-tree lines are consumed strictly in file order, so the
        # unused ones must still be read to keep the iterator aligned
        tree_no = tree_start[5:]  # unused; presumably the index after 'Tree='
        num_leaves = int(get_value_string(next(model_content)))  # unused
        split_feature = get_array_ints(next(model_content))
        split_gain = next(model_content)  # unused
        threshold = get_array_strings(next(model_content))
        left_child = get_array_ints(next(model_content))
        right_child = get_array_ints(next(model_content))
        leaf_parent = get_array_ints(next(model_content))
        leaf_value = get_array_strings(next(model_content))
        leaf_count = get_array_strings(next(model_content))
        internal_value = get_array_strings(next(model_content))
        internal_count = get_array_strings(next(model_content))
        # node ids restart at 1 for every tree
        unique_id = count(1)
        print_pmml()
        out_("\t\t\t</Segment>")

    out_("\t\t</Segmentation>")
    out_("\t</MiningModel>")
    out_("</PMML>")