lightgbmlib.i 10.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2018 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
6
7
/* lightgbmlib.i */
%module lightgbmlib
/* These C API functions are not wrapped directly; hand-written *SWIG
   wrappers in the %inline section below are exposed instead. */
%ignore LGBM_BoosterSaveModelToString;
%ignore LGBM_BoosterGetEvalNames;

/* The *SWIG string wrappers below return buffers allocated with new[].
   Marking them as new objects makes the generated JNI code release the
   buffer (via the newfree typemap) once it has been copied into a Java
   String, instead of leaking it on every call. */
%newobject LGBM_BoosterSaveModelToStringSWIG;
%newobject LGBM_BoosterDumpModelSWIG;
%typemap(newfree) char * "delete [] $1;"

%{
/* Includes the header in the wrapper code */
#include "../include/LightGBM/export.h"
#include "../include/LightGBM/utils/log.h"
#include "../include/LightGBM/utils/common.h"
#include "../include/LightGBM/c_api.h"
%}

%include "various.i"
%include "carrays.i"
%include "cpointer.i"
%include "stdint.i"

/* Note: instead of using array_functions for string arrays, we apply a typemap.
   Future char** parameter names should be added to the typemap.
*/
%apply char **STRING_ARRAY { char **feature_names, char **out_strs }

/* header files */
%include "../include/LightGBM/export.h"
%include "../include/LightGBM/c_api.h"
/* Hide JNIEnv* parameters from the Java API: this typemap consumes no Java
   argument (numinputs = 0) and passes the generated wrapper's own `jenv`
   instead. It matches parameters declared exactly as `JNIEnv *jenv`. */
%typemap(in, numinputs = 0) JNIEnv *jenv %{
  $1 = jenv;
%}
%inline %{
  /*!
   * \brief SWIG-friendly wrapper around LGBM_BoosterSaveModelToString that
   *        allocates the output buffer itself.
   * \param handle Booster handle.
   * \param start_iteration Passed through to LGBM_BoosterSaveModelToString.
   * \param num_iteration Passed through to LGBM_BoosterSaveModelToString.
   * \param buffer_len Size of the first buffer tried. If the reported length
   *        (*out_len) is larger, the buffer is reallocated with that size and
   *        the call is repeated.
   * \param[out] out_len Length reported by LGBM_BoosterSaveModelToString.
   * \return new[]-allocated buffer holding the model string, or nullptr if a
   *         call fails (the buffer is freed first).
   * NOTE(review): SWIG frees the returned buffer only if this function is
   * declared %newobject with a matching newfree typemap; confirm.
   */
  char * LGBM_BoosterSaveModelToStringSWIG(BoosterHandle handle,
                                           int start_iteration,
                                           int num_iteration,
                                           int64_t buffer_len,
                                           int64_t* out_len) {
    char* dst = new char[buffer_len];
    int result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, buffer_len, out_len, dst);
    // Check for failure before trusting *out_len, and free the buffer so the
    // error path does not leak it.
    if (result != 0) {
      delete [] dst;
      return nullptr;
    }
    // Reallocate to use larger length
    if (*out_len > buffer_len) {
      delete [] dst;
      int64_t realloc_len = *out_len;
      dst = new char[realloc_len];
      result = LGBM_BoosterSaveModelToString(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
      if (result != 0) {
        delete [] dst;
        return nullptr;
      }
    }
    return dst;
  }
  /*!
   * \brief SWIG-friendly wrapper around LGBM_BoosterDumpModel that allocates
   *        the output buffer itself.
   * \param handle Booster handle.
   * \param start_iteration Passed through to LGBM_BoosterDumpModel.
   * \param num_iteration Passed through to LGBM_BoosterDumpModel.
   * \param buffer_len Size of the first buffer tried. If the reported length
   *        (*out_len) is larger, the buffer is reallocated with that size and
   *        the call is repeated.
   * \param[out] out_len Length reported by LGBM_BoosterDumpModel.
   * \return new[]-allocated buffer holding the dump, or nullptr if a call
   *         fails (the buffer is freed first).
   * NOTE(review): SWIG frees the returned buffer only if this function is
   * declared %newobject with a matching newfree typemap; confirm.
   */
  char * LGBM_BoosterDumpModelSWIG(BoosterHandle handle,
                                   int start_iteration,
                                   int num_iteration,
                                   int64_t buffer_len,
                                   int64_t* out_len) {
    char* dst = new char[buffer_len];
    int result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, buffer_len, out_len, dst);
    // Check for failure before trusting *out_len, and free the buffer so the
    // error path does not leak it.
    if (result != 0) {
      delete [] dst;
      return nullptr;
    }
    // Reallocate to use larger length
    if (*out_len > buffer_len) {
      delete [] dst;
      int64_t realloc_len = *out_len;
      dst = new char[realloc_len];
      result = LGBM_BoosterDumpModel(handle, start_iteration, num_iteration, realloc_len, out_len, dst);
      if (result != 0) {
        delete [] dst;
        return nullptr;
      }
    }
    return dst;
  }
  /*!
   * \brief SWIG-friendly wrapper around LGBM_BoosterGetEvalNames that
   *        allocates the name buffers itself.
   * \param handle Booster handle.
   * \param eval_counts Number of name buffers to allocate (presumably obtained
   *        from LGBM_BoosterGetEvalCounts by the caller; confirm).
   * \return new[]-allocated array of eval_counts new[]-allocated buffers of
   *         kEvalNameBufferLen bytes each, or nullptr on failure (everything
   *         allocated here is freed first).
   * NOTE(review): the C API call below receives no buffer size, so every eval
   * name must fit in kEvalNameBufferLen bytes including the terminator.
   */
  char ** LGBM_BoosterGetEvalNamesSWIG(BoosterHandle handle,
                                       int eval_counts) {
    const int kEvalNameBufferLen = 128;
    // eval_counts is handed to the C API by pointer and may be overwritten,
    // so remember how many buffers were allocated for the cleanup path.
    const int num_buffers = eval_counts;
    char** dst = new char*[num_buffers];
    for (int i = 0; i < num_buffers; ++i) {
      dst[i] = new char[kEvalNameBufferLen];
    }
    int result = LGBM_BoosterGetEvalNames(handle, &eval_counts, dst);
    if (result != 0) {
      for (int i = 0; i < num_buffers; ++i) {
        delete [] dst[i];
      }
      delete [] dst;
      return nullptr;
    }
    return dst;
  }
  int LGBM_BoosterPredictForMatSingle(JNIEnv *jenv,
90
91
92
93
94
95
96
97
98
99
100
                                      jdoubleArray data,
                                      BoosterHandle handle,
                                      int data_type,
                                      int ncol,
                                      int is_row_major,
                                      int predict_type,
                                      int num_iteration,
                                      const char* parameter,
                                      int64_t* out_len,
                                      double* out_result) {
    double* data0 = (double*)jenv->GetPrimitiveArrayCritical(data, 0);
101

102
    int ret = LGBM_BoosterPredictForMatSingleRow(handle, data0, data_type, ncol, is_row_major, predict_type,
103
                                                 num_iteration, parameter, out_len, out_result);
104

105
    jenv->ReleasePrimitiveArrayCritical(data, data0, JNI_ABORT);
106

107
108
    return ret;
  }
109

110
  int LGBM_BoosterPredictForCSRSingle(JNIEnv *jenv,
111
112
113
114
115
116
117
118
119
120
121
122
123
                                      jintArray indices,
                                      jdoubleArray values,
                                      int numNonZeros,
                                      BoosterHandle handle,
                                      int indptr_type,
                                      int data_type,
                                      int64_t nelem,
                                      int64_t num_col,
                                      int predict_type,
                                      int num_iteration,
                                      const char* parameter,
                                      int64_t* out_len,
                                      double* out_result) {
124
125
126
127
128
    // Alternatives
    // - GetIntArrayElements: performs copy
    // - GetDirectBufferAddress: fails on wrapped array
    // Some words of warning for GetPrimitiveArrayCritical
    // https://stackoverflow.com/questions/23258357/whats-the-trade-off-between-using-getprimitivearraycritical-and-getprimitivety
129

130
131
132
    jboolean isCopy;
    int* indices0 = (int*)jenv->GetPrimitiveArrayCritical(indices, &isCopy);
    double* values0 = (double*)jenv->GetPrimitiveArrayCritical(values, &isCopy);
133

134
    int32_t ind[2] = { 0, numNonZeros };
135

136
    int ret = LGBM_BoosterPredictForCSRSingleRow(handle, ind, indptr_type, indices0, values0, data_type, 2,
137
138
                                                 nelem, num_col, predict_type, num_iteration, parameter, out_len, out_result);

139
140
    jenv->ReleasePrimitiveArrayCritical(values, values0, JNI_ABORT);
    jenv->ReleasePrimitiveArrayCritical(indices, indices0, JNI_ABORT);
141

142
143
    return ret;
  }
144

145
  #include <functional>
146
147
  #include <vector>

148
149
150
151
152
153
154
  struct CSRDirect {
          jintArray indices;
          jdoubleArray values;
          int* indices0;
          double* values0;
          int size;
  };
155

156
157
158
159
160
161
162
  /*!
   * \brief Create a LightGBM dataset from a Java array of Spark
   *        org.apache.spark.ml.linalg.SparseVector rows.
   *
   * All rows are read through JNI up front (see the note on the loop), and the
   * element buffers are released again on every exit path.
   *
   * \param jenv JNI environment (filled in by the JNIEnv* typemap, not by Java).
   * \param arrayOfSparseVector Rows; element i is used as row i.
   * \param num_rows Number of leading elements of arrayOfSparseVector to read.
   * \param num_col Passed through to LGBM_DatasetCreateFromCSRFunc.
   * \param parameters Passed through to LGBM_DatasetCreateFromCSRFunc.
   * \param reference Passed through to LGBM_DatasetCreateFromCSRFunc.
   * \param[out] out Passed through to LGBM_DatasetCreateFromCSRFunc.
   * \return -1 if a JNI lookup, call or array access fails (a Java exception is
   *         normally pending then), otherwise the return code of
   *         LGBM_DatasetCreateFromCSRFunc.
   */
  int LGBM_DatasetCreateFromCSRSpark(JNIEnv *jenv,
                                     jobjectArray arrayOfSparseVector,
                                     int num_rows,
                                     int64_t num_col,
                                     const char* parameters,
                                     const DatasetHandle reference,
                                     DatasetHandle* out) {
    jclass sparseVectorClass = jenv->FindClass("org/apache/spark/ml/linalg/SparseVector");
    if (sparseVectorClass == nullptr) {
      return -1;
    }
    jmethodID sparseVectorIndices = jenv->GetMethodID(sparseVectorClass, "indices", "()[I");
    if (sparseVectorIndices == nullptr) {
      return -1;
    }
    jmethodID sparseVectorValues = jenv->GetMethodID(sparseVectorClass, "values", "()[D");
    if (sparseVectorValues == nullptr) {
      return -1;
    }

    std::vector<CSRDirect> jniCache;
    jniCache.reserve(num_rows);

    // Give every element buffer acquired so far back to the JVM. Used on the
    // error paths as well as after the dataset is built, so an early return
    // does not leak the buffers of rows that were already read.
    auto releaseCache = [jenv, &jniCache]() {
      for (auto& jc : jniCache) {
        jenv->ReleaseDoubleArrayElements(jc.values, jc.values0, JNI_ABORT);
        jenv->ReleaseIntArrayElements(jc.indices, (jint *)jc.indices0, JNI_ABORT);
      }
      jniCache.clear();
    };

    // this needs to be done ahead of time as row_func is invoked from multiple threads
    // these threads would have to be registered with the JVM and also unregistered.
    // It is not clear if that can be achieved with OpenMP
    for (int i = 0; i < num_rows; i++) {
      // get the row
      jobject objSparseVec = jenv->GetObjectArrayElement(arrayOfSparseVector, i);
      if (jenv->ExceptionCheck()) {  // e.g. num_rows exceeds the Java array length
        releaseCache();
        return -1;
      }

      // get the indices and values
      auto indices = (jintArray)jenv->CallObjectMethod(objSparseVec, sparseVectorIndices);
      if (jenv->ExceptionCheck()) {
        releaseCache();
        return -1;
      }
      auto values = (jdoubleArray)jenv->CallObjectMethod(objSparseVec, sparseVectorValues);
      if (jenv->ExceptionCheck()) {
        releaseCache();
        return -1;
      }
      // The row object itself is no longer needed. Dropping its local reference
      // leaves two live local references per row (the arrays, which must stay
      // valid until their elements are released).
      jenv->DeleteLocalRef(objSparseVec);

      // NOTE(review): assumes `values` has at least as many entries as `indices`.
      int size = jenv->GetArrayLength(indices);

      // Note: when testing on larger data (e.g. 288k rows per partition and 36mio rows total)
      // locking the arrays with GetPrimitiveArrayCritical resulted in a dead-lock.
      // Get<Type>ArrayElements is used instead; in the tested use case it performs copies.
      int* indices0 = (int *)jenv->GetIntArrayElements(indices, 0);
      if (indices0 == nullptr) {
        releaseCache();
        return -1;
      }
      double* values0 = jenv->GetDoubleArrayElements(values, 0);
      if (values0 == nullptr) {
        jenv->ReleaseIntArrayElements(indices, (jint *)indices0, JNI_ABORT);
        releaseCache();
        return -1;
      }

      jniCache.push_back({indices, values, indices0, values0, size});
    }

    // type is important here as we want a std::function, rather than a lambda
    std::function<void(int idx, std::vector<std::pair<int, double>>& ret)> row_func =
        [&jniCache](int row_num, std::vector<std::pair<int, double>>& ret) {
      auto& jc = jniCache[row_num];
      ret.clear();  // reset size, but not free()
      ret.reserve(jc.size);  // make sure we have enough allocated

      // copy data
      int* indices0p = jc.indices0;
      double* values0p = jc.values0;
      int* indices0e = indices0p + jc.size;
      for (; indices0p != indices0e; ++indices0p, ++values0p) {
        ret.emplace_back(*indices0p, *values0p);
      }
    };

    int ret = LGBM_DatasetCreateFromCSRFunc(&row_func, num_rows, num_col, parameters, reference, out);

    releaseCache();

    return ret;
  }
%}

/* Scalar pointer helpers (cpointer.i), e.g. for out-parameters such as
   int64_t* out_len. */
%pointer_functions(int, intp)
%pointer_functions(long, longp)
%pointer_functions(double, doublep)
%pointer_functions(float, floatp)
%pointer_functions(int64_t, int64_tp)
%pointer_functions(int32_t, int32_tp)

/* Pointer reinterpretations between scalar types (the pointee is not
   converted, only the pointer type changes). */
%pointer_cast(int64_t *, long *, int64_t_to_long_ptr)
%pointer_cast(int64_t *, double *, int64_t_to_double_ptr)
%pointer_cast(int32_t *, int *, int32_t_to_int_ptr)
%pointer_cast(long *, int64_t *, long_to_int64_t_ptr)
%pointer_cast(double *, int64_t *, double_to_int64_t_ptr)
%pointer_cast(int *, int32_t *, int_to_int32_t_ptr)

/* Casts to void*, presumably for C API parameters that take untyped data
   pointers together with a type code. */
%pointer_cast(double *, void *, double_to_voidp_ptr)
%pointer_cast(float *, void *, float_to_voidp_ptr)
%pointer_cast(int *, void *, int_to_voidp_ptr)
%pointer_cast(int32_t *, void *, int32_t_to_voidp_ptr)
%pointer_cast(int64_t *, void *, int64_t_to_voidp_ptr)

/* Array helpers (carrays.i). */
%array_functions(double, doubleArray)
%array_functions(float, floatArray)
%array_functions(int, intArray)
%array_functions(long, longArray)

/* Note: the SWIG-generated string arrays have a bug: when a new array is
   filled with strings, the strings are deallocated prematurely.
*/
%array_functions(char *, stringArray)

/* Custom pointer manipulation template:
   new_NAME() allocates a single TYPE with `new TYPE` (default-initialized, so
   the pointee is uninitialized for scalar and pointer TYPEs) and returns it;
   delete_NAME(self) releases it with `delete` (a null pointer is a no-op).
   The C code is split over several %{ %} blocks, following the pattern
   SWIG's own cpointer.i uses for its macros. */
%define %pointer_manipulation(TYPE, NAME)

%{
  static TYPE *new_##NAME() { %}
  %{  TYPE* NAME = new TYPE; return NAME; %}
  %{}

  static void delete_##NAME(TYPE *self) { %}
  %{  delete self; %}
  %{}
  %}

TYPE *new_##NAME();
void  delete_##NAME(TYPE *self);

%enddef
/* Custom dereference template: NAME_value(self) returns *self by value.
   The pointer is not checked for null. */
%define %pointer_dereference(TYPE, NAME)

%{
  static TYPE NAME ##_value(TYPE *self) {
    TYPE NAME = *self;
    return NAME;
  }
%}

TYPE NAME##_value(TYPE *self);

%enddef
/* Custom handle template: NAME_handle() allocates a TYPE with `new`, stores in
   it a freshly allocated block of sizeof(int*) bytes from `operator new`, and
   returns the outer pointer, so the stored value starts out non-null.
   NOTE(review): releasing the handle with delete_NAME (from
   %pointer_manipulation) frees only the outer TYPE; the inner operator-new
   block is never freed here, and it is lost if a C API call overwrites the
   stored value. Confirm whether callers rely on the non-null initial value
   before changing this. */
%define %pointer_handle(TYPE, NAME)

%{
  static TYPE* NAME ##_handle() { %}
  %{ TYPE* NAME = new TYPE; *NAME = (TYPE)operator new(sizeof(int*)); return NAME; %}
  %{}
%}

TYPE *NAME##_handle();

%enddef

/* void** helpers, presumably used for handle out-parameters such as
   DatasetHandle* (confirm the handle typedefs in c_api.h):
   - new_voidpp / delete_voidpp: allocate and free the void* slot;
   - voidpp_value: read the void* stored in the slot;
   - voidpp_handle: allocate a slot that is pre-filled with a non-null value. */
%pointer_manipulation(void*, voidpp)
%pointer_dereference(void*, voidpp)
%pointer_handle(void*, voidpp)