lightgbmlib.i 10 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2018 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6
7
/* lightgbmlib.i */
%module lightgbmlib
%ignore LGBM_BoosterSaveModelToString;
8
%ignore LGBM_BoosterGetEvalNames;
9
10
11
12
%{
/* Includes the header in the wrapper code */
#include "../include/LightGBM/export.h"
#include "../include/LightGBM/utils/log.h"
13
#include "../include/LightGBM/utils/common.h"
14
15
16
#include "../include/LightGBM/c_api.h"
%}

17
18
19
20
21
22
23
24
25
%include "various.i"
%include "carrays.i"
%include "cpointer.i"

/* Note: instead of using array_functions for string array we apply a typemap instead.
   Future char** parameter names should be added to the typemap.
*/
%apply char **STRING_ARRAY { char **feature_names, char **out_strs }

26
27
28
29
/* header files */
%include "../include/LightGBM/export.h"
%include "../include/LightGBM/c_api.h"

30
31
/* Supply the JNI environment to wrapped functions that take a `JNIEnv *jenv`
   parameter. `numinputs = 0` means the parameter consumes no input from the
   target-language call; the typemap body assigns the wrapper's own `jenv`. */
%typemap(in, numinputs = 0) JNIEnv *jenv %{
  $1 = jenv;
%}

34
35
%inline %{
  /*!
   * \brief Save the model into a freshly allocated buffer, growing it if needed.
   *
   * Calls LGBM_BoosterSaveModelToString with a buffer of buffer_len chars. If
   * that call succeeds but the length reported through out_len is larger than
   * the buffer, the buffer is replaced by one of the reported length and the
   * call is repeated.
   *
   * \param handle Booster handle, forwarded unchanged
   * \param start_iteration Forwarded unchanged
   * \param num_iteration Forwarded unchanged
   * \param buffer_len Size of the first buffer; negative values are treated as 0
   * \param out_len Receives the length reported by LGBM_BoosterSaveModelToString
   * \return Buffer allocated with new[], or nullptr if either call reports a
   *         non-zero result (the buffer is released in that case).
   *         NOTE(review): nothing in this function frees the buffer on success;
   *         confirm that the caller / generated wrapper releases it.
   */
  char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle,
                                           int start_iteration,
                                           int num_iteration,
                                           int64_t buffer_len,
                                           int64_t* out_len) {
    // A negative size must never reach new[].
    if (buffer_len < 0) {
      buffer_len = 0;
    }
    char* dst = new char[buffer_len];
    int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, buffer_len, out_len, dst);
    // Reallocate to use larger length; *out_len is only trusted after a successful call
    if (result == 0 && *out_len > buffer_len) {
      delete [] dst;
      int64_t realloc_len = *out_len;
      dst = new char[realloc_len];
      result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
    }
    if (result != 0) {
      // Do not leak the buffer on the failure path.
      delete [] dst;
      return nullptr;
    }
    return dst;
  }

  /*!
   * \brief Fetch the evaluation names into a freshly allocated array of strings.
   *
   * Allocates eval_counts buffers of a fixed size and passes them to
   * LGBM_BoosterGetEvalNames.
   *
   * \param handle Booster handle, forwarded unchanged
   * \param eval_counts Number of name buffers to allocate
   * \return Array (new[]) of eval_counts buffers (each new[]), or nullptr when
   *         eval_counts is negative or LGBM_BoosterGetEvalNames reports a
   *         non-zero result (all buffers are released in that case).
   *         NOTE(review): each buffer holds kEvalNameBufferLen chars and the
   *         callee is not told that size here; confirm names cannot exceed it.
   */
  char ** LGBM_BoosterGetEvalNamesSWIG(BoosterHandle handle,
                                       int eval_counts) {
    const int kEvalNameBufferLen = 128;
    if (eval_counts < 0) {
      return nullptr;
    }
    // eval_counts is handed to the callee by address and may be overwritten,
    // so remember how many buffers were actually allocated.
    const int allocated = eval_counts;
    char** dst = new char*[allocated];
    for (int i = 0; i < allocated; ++i) {
      dst[i] = new char[kEvalNameBufferLen];
    }
    int result = LGBM_BoosterGetEvalNames(handle, &eval_counts, dst);
    if (result != 0) {
      // Release everything on failure instead of leaking it.
      for (int i = 0; i < allocated; ++i) {
        delete [] dst[i];
      }
      delete [] dst;
      return nullptr;
    }
    return dst;
  }
67

68
  int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
69
70
71
72
73
74
75
76
77
78
79
                                      jdoubleArray data,
                                      BoosterHandle handle,
                                      int data_type,
                                      int ncol,
                                      int is_row_major,
                                      int predict_type,
                                      int num_iteration,
                                      const char* parameter,
                                      int64_t* out_len,
                                      double* out_result) {
    double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
80

81
    int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type,
82
                                                 num_iteration, parameter, out_len, out_result);
83

84
    jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
85

86
87
    return ret;
  }
88

89
  int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
90
91
92
93
94
95
96
97
98
99
100
101
102
                                      jintArray indices,
                                      jdoubleArray values,
                                      int numNonZeros,
                                      BoosterHandle handle,
                                      int indptr_type,
                                      int data_type,
                                      int64_t nelem,
                                      int64_t num_col,
                                      int predict_type,
                                      int num_iteration,
                                      const char* parameter,
                                      int64_t* out_len,
                                      double* out_result) {
103
104
105
106
107
    // Alternatives
    // - GetIntArrayElements: performs copy
    // - GetDirectBufferAddress: fails on wrapped array
    // Some words of warning for GetPrimitiveArrayCritical
    // https://stackoverflow.com/questions/23258357/whats-the-trade-off-between-using-getprimitivearraycritical-and-getprimitivety
108

109
110
111
    jboolean isCopy;
    int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, &isCopy);
    double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, &isCopy);
112

113
    int32_t ind[2] = { 0, numNonZeros };
114

115
    int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
116
117
                                                 nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result);

118
119
    jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
    jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
120

121
122
    return ret;
  }
123

124
  #include <functional>
125
126
  #include <vector>

127
128
129
130
131
132
133
  /*!
   * \brief Per-row record pairing the Java arrays of one sparse vector with
   *        the native element pointers obtained from them.
   *
   * Field order matters: instances are built by aggregate initialisation as
   * {indices, values, indices0, values0, size}.
   */
  struct CSRDirect {
    jintArray indices;     // Java int[] reference, needed again to release indices0
    jdoubleArray values;   // Java double[] reference, needed again to release values0
    int* indices0;         // native pointer to the elements of `indices`
    double* values0;       // native pointer to the elements of `values`
    int size;              // number of entries read from each of the two arrays
  };
134

135
136
137
138
139
140
141
  /*!
   * \brief Create a dataset from a Java array of Spark SparseVector objects.
   *
   * For every row the `indices()` and `values()` arrays of the vector are
   * fetched and their elements obtained up front on the calling thread. The
   * rows are then fed to LGBM_DatasetCreateFromCSRFunc through a std::function
   * that only reads the cached native pointers, and finally every array is
   * released with JNI_ABORT (nothing is copied back).
   *
   * \param jenv JNI environment
   * \param arrayOfSparseVector Java array whose elements provide `indices()`
   *        (int[]) and `values()` (double[])
   * \param num_rows Number of elements read from arrayOfSparseVector
   * \param num_col, parameters, reference, out Forwarded unchanged to
   *        LGBM_DatasetCreateFromCSRFunc
   * \return Result of LGBM_DatasetCreateFromCSRFunc, or -1 if any JNI lookup,
   *         call or array access failed (a Java exception may then be pending)
   */
  int LGBM_DatasetCreateFromCSRSpark(JNIEnv *jenv,
                                     jobjectArray arrayOfSparseVector,
                                     int num_rows,
                                     int64_t num_col,
                                     const char* parameters,
                                     const DatasetHandle reference,
                                     DatasetHandle* out) {
    jclass sparseVectorClass = jenv->FindClass("org/apache/spark/ml/linalg/SparseVector");
    if (sparseVectorClass == nullptr) {
      return -1;
    }
    jmethodID sparseVectorIndices = jenv->GetMethodID(sparseVectorClass, "indices", "()[I");
    if (sparseVectorIndices == nullptr) {
      return -1;
    }
    jmethodID sparseVectorValues = jenv->GetMethodID(sparseVectorClass, "values", "()[D");
    if (sparseVectorValues == nullptr) {
      return -1;
    }

    std::vector<CSRDirect> jniCache;
    if (num_rows > 0) {
      // A negative count must not be converted into a huge reservation.
      jniCache.reserve(num_rows);
    }

    // Releases every array recorded so far (discarding changes) and drops the
    // local references that were kept alive only for that purpose.
    auto release_cache = [&]() {
      for (auto& jc : jniCache) {
        jenv->ReleaseDoubleArrayElements(jc.values, jc.values0, JNI_ABORT);
        jenv->ReleaseIntArrayElements(jc.indices, (jint *)jc.indices0, JNI_ABORT);
        jenv->DeleteLocalRef(jc.values);
        jenv->DeleteLocalRef(jc.indices);
      }
      jniCache.clear();
    };

    // this needs to be done ahead of time as row_func is invoked from multiple threads
    // these threads would have to be registered with the JVM and also unregistered.
    // It is not clear if that can be achieved with OpenMP
    bool ok = true;
    for (int i = 0; i < num_rows; i++) {
      // get the row
      jobject objSparseVec = jenv->GetObjectArrayElement(arrayOfSparseVector, i);
      if (objSparseVec == nullptr) {
        ok = false;
        break;
      }

      // get the indices and values; stop calling into Java once an exception is pending
      jintArray indices = (jintArray)jenv->CallObjectMethod(objSparseVec, sparseVectorIndices);
      jdoubleArray values = nullptr;
      if (indices != nullptr && !jenv->ExceptionCheck()) {
        values = (jdoubleArray)jenv->CallObjectMethod(objSparseVec, sparseVectorValues);
      }
      // The vector object itself is no longer needed; only its arrays are kept.
      jenv->DeleteLocalRef(objSparseVec);
      if (indices == nullptr || values == nullptr || jenv->ExceptionCheck()) {
        ok = false;
        break;
      }
      int size = jenv->GetArrayLength(indices);

      // Note: when testing on larger data (e.g. 288k rows per partition and 36mio rows total)
      // using GetPrimitiveArrayCritical resulted in a dead-lock, so the
      // Get<Type>ArrayElements variants are used here instead.
      int* indices0 = (int *)jenv->GetIntArrayElements(indices, 0);
      if (indices0 == nullptr) {
        ok = false;
        break;
      }
      double* values0 = jenv->GetDoubleArrayElements(values, 0);
      if (values0 == nullptr) {
        jenv->ReleaseIntArrayElements(indices, (jint *)indices0, JNI_ABORT);
        ok = false;
        break;
      }

      jniCache.push_back({indices, values, indices0, values0, size});
    }

    if (!ok) {
      release_cache();
      return -1;
    }

    // type is important here as we want a std::function, rather than a lambda
    std::function<void(int idx, std::vector<std::pair<int, double>>& ret)> row_func = [&](int row_num, std::vector<std::pair<int, double>>& ret) {
      auto& jc = jniCache[row_num];
      ret.clear();  // reset size, but not free()
      ret.reserve(jc.size);  // make sure we have enough allocated

      // copy data
      int* indices0p = jc.indices0;
      double* values0p = jc.values0;
      int* indices0e = indices0p + jc.size;

      for (; indices0p != indices0e; ++indices0p, ++values0p) {
        ret.emplace_back(*indices0p, *values0p);
      }
    };

    int ret = LGBM_DatasetCreateFromCSRFunc(&row_func, num_rows, num_col, parameters, reference, out);

    release_cache();

    return ret;
  }
%}

/* Scalar pointer helpers: one %pointer_functions instantiation per element
   type; the second argument (intp, longp, ...) is the name used for the
   generated helpers. */
%pointer_functions(int, intp)
%pointer_functions(long, longp)
%pointer_functions(double, doublep)
%pointer_functions(float, floatp)
%pointer_functions(int64_t, int64_tp)
%pointer_functions(int32_t, int32_tp)

/* Pointer casts, declared in both directions for each pair of types.
   NOTE(review): the int64_t* <-> long* casts presumably assume `long` is
   64 bits wide; confirm behaviour on platforms where it is 32 bits.
   NOTE(review): the int64_t* <-> double* casts only reinterpret the pointer;
   confirm callers never read the pointee through the other type. */
%pointer_cast(int64_t *, long *, int64_t_to_long_ptr)
%pointer_cast(int64_t *, double *, int64_t_to_double_ptr)
%pointer_cast(int32_t *, int *, int32_t_to_int_ptr)
%pointer_cast(long *, int64_t *, long_to_int64_t_ptr)
%pointer_cast(double *, int64_t *, double_to_int64_t_ptr)
%pointer_cast(int *, int32_t *, int_to_int32_t_ptr)
214
215

%pointer_cast(double *, void *, double_to_voidp_ptr)
216
%pointer_cast(float *, void *, float_to_voidp_ptr)
217
218
219
%pointer_cast(int *, void *, int_to_voidp_ptr)
%pointer_cast(int32_t *, void *, int32_t_to_voidp_ptr)
%pointer_cast(int64_t *, void *, int64_t_to_voidp_ptr)
220
221
222
223
224

%array_functions(double, doubleArray)
%array_functions(float, floatArray)
%array_functions(int, intArray)
%array_functions(long, longArray)
225
226
227
/* Note: there is a bug in the SWIG generated string arrays when creating
   a new array with strings where the strings are prematurely deallocated
*/
228
%array_functions(char *, stringArray)
229
230

/* Custom pointer manipulation template */
231
/* %pointer_manipulation(TYPE, NAME)
   Emits and wraps two helpers:
     TYPE *new_NAME()            - allocates a single TYPE with `new`
     void  delete_NAME(TYPE *p)  - releases it with `delete` (a null pointer is a no-op)
   The braces of each function are emitted through separate %{ ... %} pieces. */
%define %pointer_manipulation(TYPE, NAME)
%{
  static TYPE *new_##NAME() { %}
  %{  TYPE* NAME = new TYPE; return NAME; %}
  %{}

  static void delete_##NAME(TYPE *self) { %}
  %{  delete self; %}
  %{}
  %}

TYPE *new_##NAME();
void  delete_##NAME(TYPE *self);

%enddef

247
/* %pointer_dereference(TYPE, NAME)
   Emits and wraps
     TYPE NAME_value(TYPE *self)
   which returns a copy of the value `self` points to. `self` is dereferenced
   without a null check, so callers must pass a valid pointer. */
%define %pointer_dereference(TYPE, NAME)
%{
  static TYPE NAME ##_value(TYPE *self) {
    TYPE NAME = *self;
    return NAME;
  }
%}

TYPE NAME##_value(TYPE *self);

%enddef

259
/* %pointer_handle(TYPE, NAME)
   Emits and wraps
     TYPE *NAME_handle()
   which allocates a TYPE with `new` and stores in it the address of a fresh
   `operator new(sizeof(int*))` allocation, cast to TYPE.
   NOTE(review): neither allocation is released by this macro; confirm which
   caller owns and frees them. */
%define %pointer_handle(TYPE, NAME)
%{
  static TYPE* NAME ##_handle() { %}
  %{ TYPE* NAME = new TYPE; *NAME = (TYPE)operator new(sizeof(int*)); return NAME; %}
  %{}
%}

TYPE *NAME##_handle();

%enddef

%pointer_manipulation(void*, voidpp)

/* Allow dereferencing of void** to void* */
%pointer_dereference(void*, voidpp)

/* Allow retrieving handle to void** */
%pointer_handle(void*, voidpp)