Commit bc0579c8 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

add init_score & test cpp and python result consistency (#1007)

* add init_score & test cpp and python result consistency

* try fix common.h

* Fix tests (#3)

* update atof

* fix bug

* fix tests.

* fix bug

* fix dtypes

* fix categorical feature override

* fix protobuf on vs build (#1004)

* [optional] support protobuf

* fix windows/LightGBM.vcxproj

* add doc

* fix doc

* fix vs support (#2)

* fix vs support

* fix cmake

* fix #1012

* [python] add network config api  (#1019)

* add network

* update doc

* add float tolerance in bin finder.

* fix a bug

* update tests

* add double tolerance on tree model

* fix tests

* simplify the double comparison

* fix lightsvm zero base

* move double tolerance to the bin finder.

* fix pylint

* clean test.sh

* add sklearn test

* remove underline

* clean codes

* set random_state=None

* add last line

* fix doc

* rename file

* try fix test
parent 04d4811b
...@@ -377,3 +377,6 @@ lightgbm.model ...@@ -377,3 +377,6 @@ lightgbm.model
# VSCode # VSCode
.vscode .vscode
# duplicate version file
python-package/lightgbm/VERSION.txt
0.039
0.187
0.831
0.767
0.351
0.377
0.534
0.000
0.241
0.208
0.250
0.806
0.280
0.192
0.504
0.866
0.241
0.079
0.356
0.748
0.551
0.817
0.960
0.793
0.604
0.493
0.040
0.984
0.383
0.152
0.667
0.284
0.586
0.587
0.446
0.836
0.265
0.449
0.538
0.664
0.784
0.395
0.646
0.151
0.933
0.383
0.730
0.020
0.205
0.487
0.878
0.527
0.930
0.484
0.490
0.120
0.803
0.247
0.900
0.911
0.943
0.520
0.677
0.779
0.131
0.601
0.034
0.498
0.155
0.183
0.365
0.432
0.623
0.074
0.504
0.183
0.574
0.637
0.557
0.738
0.336
0.765
0.433
0.484
0.648
0.018
0.654
0.619
0.310
0.086
0.091
0.923
0.689
0.127
0.357
0.592
0.836
0.044
0.237
0.890
0.009
0.201
0.959
0.613
0.262
0.067
0.028
0.245
0.881
0.416
0.720
0.918
0.408
0.191
0.517
0.908
0.804
0.066
0.693
0.572
0.907
0.122
0.534
0.879
0.410
0.482
0.070
0.278
0.325
0.945
0.283
0.461
0.671
0.162
0.486
0.739
0.867
0.626
0.669
0.126
0.946
0.133
0.775
0.265
0.934
0.720
0.754
0.219
0.443
0.618
0.770
0.104
0.962
0.890
0.270
0.823
0.518
0.462
0.314
0.581
0.730
0.411
0.629
0.699
0.711
0.052
0.860
0.458
0.262
0.242
0.483
0.887
0.378
0.750
0.097
0.476
0.992
0.770
0.211
0.501
0.234
0.410
0.780
0.771
0.228
0.922
0.593
0.380
0.502
0.605
0.560
0.486
0.505
0.176
0.813
0.542
0.131
0.766
0.932
0.947
0.369
0.136
0.518
0.113
0.934
0.184
0.253
0.407
0.383
0.795
0.456
0.171
0.267
0.509
0.147
0.612
0.566
0.715
0.938
0.912
0.946
0.245
0.132
0.302
0.895
0.972
0.859
0.110
0.947
0.423
0.009
0.442
0.046
0.544
0.339
0.473
0.613
0.869
0.662
0.434
0.819
0.906
0.120
0.532
0.285
0.047
0.669
0.863
0.163
0.812
0.853
0.914
0.265
0.904
0.321
0.552
0.051
0.044
0.720
0.444
0.256
0.190
0.670
0.000
0.806
0.079
0.191
0.386
0.485
0.355
0.321
0.964
0.642
0.023
0.430
0.875
0.301
0.095
0.758
0.606
0.570
0.054
0.140
0.623
0.208
0.504
0.545
0.284
0.948
0.842
0.722
0.078
0.106
0.493
0.161
0.978
0.159
0.487
0.364
0.639
0.129
0.430
0.275
0.888
0.041
0.914
0.833
0.298
0.789
0.031
0.967
0.527
0.303
0.363
0.066
0.989
0.039
0.655
0.443
0.949
0.246
0.532
0.482
0.703
0.068
0.194
0.215
0.738
0.189
0.573
0.215
0.862
0.942
0.518
0.352
0.234
0.050
0.269
0.654
0.534
0.944
0.396
0.694
0.489
0.513
0.268
0.455
0.471
0.707
0.941
0.329
0.042
0.496
0.544
0.168
0.760
0.985
0.946
0.197
0.875
0.704
0.454
0.541
0.850
0.480
0.373
0.493
0.579
0.189
0.901
0.674
0.633
0.099
0.604
0.121
0.079
0.527
0.403
0.589
0.089
0.431
0.175
0.987
0.561
0.687
0.325
0.095
0.976
0.286
0.424
0.650
0.025
0.810
0.537
0.278
0.062
0.162
0.895
0.686
0.250
0.066
0.691
0.572
0.405
0.364
0.217
0.670
0.971
0.176
0.597
0.424
0.447
0.254
0.825
0.485
0.543
0.305
0.182
0.086
0.714
0.196
0.690
0.390
0.416
0.469
0.368
0.101
0.310
0.664
0.666
0.286
0.460
0.193
0.210
0.023
0.897
0.211
0.228
0.280
0.127
0.639
0.075
0.134
0.645
0.340
0.708
0.557
0.256
0.651
0.116
0.536
0.437
0.268
0.604
0.871
0.999
0.608
0.405
0.225
0.257
0.479
0.367
0.914
0.368
0.373
0.384
0.837
0.651
0.614
0.334
0.818
0.038
0.871
0.513
0.398
0.497
0.667
0.013
0.872
0.447
0.343
0.138
0.439
0.496
0.404
0.679
0.421
0.961
0.599
0.807
0.109
0.397
0.337
0.569
0.861
0.078
0.073
0.850
0.213
0.669
This diff is collapsed.
...@@ -160,6 +160,21 @@ inline static const char* Atoi(const char* p, int* out) { ...@@ -160,6 +160,21 @@ inline static const char* Atoi(const char* p, int* out) {
return p; return p;
} }
template<class T>
inline static double Pow(T base, int power) {
  // Computes base**power exactly the way repeated multiplication would,
  // for any int exponent (negative exponents invert the result).
  //
  // Implementation notes:
  // - The exponent is widened to long long before negation: the previous
  //   recursive version evaluated -power, which is undefined behavior when
  //   power == INT_MIN.
  // - The running base is held in double so that integer T does not overflow
  //   on base*base during the squaring steps.
  // - Iterative exponentiation-by-squaring: O(log |power|) multiplies, no
  //   recursion depth to worry about.
  long long p = power;
  const bool invert = (p < 0);
  if (invert) {
    p = -p;  // safe: p is long long, so -INT_MIN is representable
  }
  double b = static_cast<double>(base);
  double result = 1.0;
  while (p > 0) {
    if (p & 1LL) {
      result *= b;
    }
    b *= b;
    p >>= 1;
  }
  return invert ? 1.0 / result : result;
}
inline static const char* Atof(const char* p, double* out) { inline static const char* Atof(const char* p, double* out) {
int frac; int frac;
double sign, value, scale; double sign, value, scale;
...@@ -168,7 +183,6 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -168,7 +183,6 @@ inline static const char* Atof(const char* p, double* out) {
while (*p == ' ') { while (*p == ' ') {
++p; ++p;
} }
// Get sign, if any. // Get sign, if any.
sign = 1.0; sign = 1.0;
if (*p == '-') { if (*p == '-') {
...@@ -187,13 +201,15 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -187,13 +201,15 @@ inline static const char* Atof(const char* p, double* out) {
// Get digits after decimal point, if any. // Get digits after decimal point, if any.
if (*p == '.') { if (*p == '.') {
double pow10 = 10.0; double right = 0.0;
int nn = 0;
++p; ++p;
while (*p >= '0' && *p <= '9') { while (*p >= '0' && *p <= '9') {
value += (*p - '0') / pow10; right = (*p - '0') + right * 10.0;
pow10 *= 10.0; ++nn;
++p; ++p;
} }
value += right / Pow(10.0, nn);
} }
// Handle exponent, if any. // Handle exponent, if any.
...@@ -250,8 +266,6 @@ inline static const char* Atof(const char* p, double* out) { ...@@ -250,8 +266,6 @@ inline static const char* Atof(const char* p, double* out) {
return p; return p;
} }
inline bool AtoiAndCheck(const char* p, int* out) { inline bool AtoiAndCheck(const char* p, int* out) {
const char* after = Atoi(p, out); const char* after = Atoi(p, out);
if (*after != '\0') { if (*after != '\0') {
...@@ -632,6 +646,15 @@ inline bool FindInBitset(const uint32_t* bits, int n, T pos) { ...@@ -632,6 +646,15 @@ inline bool FindInBitset(const uint32_t* bits, int n, T pos) {
return (bits[i1] >> i2) & 1; return (bits[i1] >> i2) & 1;
} }
inline static bool CheckDoubleEqualOrdered(double a, double b) {
  // Ordered near-equality: returns true when b does not exceed the next
  // representable double above a, i.e. the two values differ by at most
  // one ULP (callers pass them in ascending order).
  return b <= std::nextafter(a, INFINITY);
}
inline static double GetDoubleUpperBound(double a) {
  // Smallest representable double strictly greater than a.
  // (Removed a stray duplicated semicolon on the return statement.)
  return std::nextafter(a, INFINITY);
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
...@@ -131,7 +131,7 @@ def param_dict_to_str(data): ...@@ -131,7 +131,7 @@ def param_dict_to_str(data):
pairs.append(str(key) + '=' + ','.join(map(str, val))) pairs.append(str(key) + '=' + ','.join(map(str, val)))
elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val): elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val):
pairs.append(str(key) + '=' + str(val)) pairs.append(str(key) + '=' + str(val))
else: elif val is not None:
raise TypeError('Unknown type of parameter:%s, got:%s' raise TypeError('Unknown type of parameter:%s, got:%s'
% (key, type(val).__name__)) % (key, type(val).__name__))
return ' '.join(pairs) return ' '.join(pairs)
...@@ -555,8 +555,8 @@ class _InnerPredictor(object): ...@@ -555,8 +555,8 @@ class _InnerPredictor(object):
class Dataset(object): class Dataset(object):
"""Dataset in LightGBM.""" """Dataset in LightGBM."""
def __init__(self, data, label=None, max_bin=255, reference=None, def __init__(self, data, label=None, max_bin=None, reference=None,
weight=None, group=None, silent=False, weight=None, group=None, init_score=None, silent=False,
feature_name='auto', categorical_feature='auto', params=None, feature_name='auto', categorical_feature='auto', params=None,
free_raw_data=True): free_raw_data=True):
"""Constract Dataset. """Constract Dataset.
...@@ -566,9 +566,9 @@ class Dataset(object): ...@@ -566,9 +566,9 @@ class Dataset(object):
data : string, numpy array or scipy.sparse data : string, numpy array or scipy.sparse
Data source of Dataset. Data source of Dataset.
If string, it represents the path to txt file. If string, it represents the path to txt file.
label : list or numpy 1-D array, optional (default=None) label : list, numpy 1-D array or None, optional (default=None)
Label of the data. Label of the data.
max_bin : int, optional (default=255) max_bin : int or None, optional (default=None)
Max number of discrete bins for features. Max number of discrete bins for features.
reference : Dataset or None, optional (default=None) reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference. If this is Dataset for validation, training data should be used as reference.
...@@ -576,6 +576,8 @@ class Dataset(object): ...@@ -576,6 +576,8 @@ class Dataset(object):
Weight for each instance. Weight for each instance.
group : list, numpy 1-D array or None, optional (default=None) group : list, numpy 1-D array or None, optional (default=None)
Group/query size for Dataset. Group/query size for Dataset.
init_score : list, numpy 1-D array or None, optional (default=None)
Init score for Dataset.
silent : bool, optional (default=False) silent : bool, optional (default=False)
Whether to print messages during construction. Whether to print messages during construction.
feature_name : list of strings or 'auto', optional (default="auto") feature_name : list of strings or 'auto', optional (default="auto")
...@@ -598,6 +600,7 @@ class Dataset(object): ...@@ -598,6 +600,7 @@ class Dataset(object):
self.reference = reference self.reference = reference
self.weight = weight self.weight = weight
self.group = group self.group = group
self.init_score = init_score
self.silent = silent self.silent = silent
self.feature_name = feature_name self.feature_name = feature_name
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
...@@ -616,8 +619,8 @@ class Dataset(object): ...@@ -616,8 +619,8 @@ class Dataset(object):
_safe_call(_LIB.LGBM_DatasetFree(self.handle)) _safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None self.handle = None
def _lazy_init(self, data, label=None, max_bin=255, reference=None, def _lazy_init(self, data, label=None, max_bin=None, reference=None,
weight=None, group=None, predictor=None, weight=None, group=None, init_score=None, predictor=None,
silent=False, feature_name='auto', silent=False, feature_name='auto',
categorical_feature='auto', params=None): categorical_feature='auto', params=None):
if data is None: if data is None:
...@@ -633,7 +636,8 @@ class Dataset(object): ...@@ -633,7 +636,8 @@ class Dataset(object):
params = {} if params is None else params params = {} if params is None else params
self.max_bin = max_bin self.max_bin = max_bin
self.predictor = predictor self.predictor = predictor
params["max_bin"] = max_bin if self.max_bin is not None:
params["max_bin"] = self.max_bin
if "verbosity" in params: if "verbosity" in params:
params.setdefault("verbose", params.pop("verbosity")) params.setdefault("verbose", params.pop("verbosity"))
if silent: if silent:
...@@ -655,6 +659,10 @@ class Dataset(object): ...@@ -655,6 +659,10 @@ class Dataset(object):
raise TypeError("Wrong type({}) or unknown name({}) in categorical_feature" raise TypeError("Wrong type({}) or unknown name({}) in categorical_feature"
.format(type(name).__name__, name)) .format(type(name).__name__, name))
if categorical_indices: if categorical_indices:
if "categorical_feature" in params or "categorical_column" in params:
warnings.warn('categorical_feature in param dict is overrided.')
params.pop("categorical_feature", None)
params.pop("categorical_column", None)
params['categorical_column'] = sorted(categorical_indices) params['categorical_column'] = sorted(categorical_indices)
params_str = param_dict_to_str(params) params_str = param_dict_to_str(params)
...@@ -697,7 +705,11 @@ class Dataset(object): ...@@ -697,7 +705,11 @@ class Dataset(object):
if group is not None: if group is not None:
self.set_group(group) self.set_group(group)
# load init score # load init score
if isinstance(self.predictor, _InnerPredictor): if init_score is not None:
self.set_init_score(init_score)
if self.predictor is not None:
warnings.warn("The prediction of init_model will be overrided by init_score.")
elif isinstance(self.predictor, _InnerPredictor):
init_score = self.predictor.predict(data, init_score = self.predictor.predict(data,
raw_score=True, raw_score=True,
data_has_header=self.data_has_header, data_has_header=self.data_has_header,
...@@ -802,7 +814,7 @@ class Dataset(object): ...@@ -802,7 +814,7 @@ class Dataset(object):
if self.used_indices is None: if self.used_indices is None:
"""create valid""" """create valid"""
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference, self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, reference=self.reference,
weight=self.weight, group=self.group, predictor=self._predictor, weight=self.weight, group=self.group, init_score=self.init_score, predictor=self._predictor,
silent=self.silent, feature_name=self.feature_name, params=self.params) silent=self.silent, feature_name=self.feature_name, params=self.params)
else: else:
"""construct subset""" """construct subset"""
...@@ -820,15 +832,15 @@ class Dataset(object): ...@@ -820,15 +832,15 @@ class Dataset(object):
else: else:
"""create train""" """create train"""
self._lazy_init(self.data, label=self.label, max_bin=self.max_bin, self._lazy_init(self.data, label=self.label, max_bin=self.max_bin,
weight=self.weight, group=self.group, predictor=self._predictor, weight=self.weight, group=self.group, init_score=self.init_score,
silent=self.silent, feature_name=self.feature_name, predictor=self._predictor, silent=self.silent, feature_name=self.feature_name,
categorical_feature=self.categorical_feature, params=self.params) categorical_feature=self.categorical_feature, params=self.params)
if self.free_raw_data: if self.free_raw_data:
self.data = None self.data = None
return self return self
def create_valid(self, data, label=None, weight=None, group=None, def create_valid(self, data, label=None, weight=None, group=None,
silent=False, params=None): init_score=None, silent=False, params=None):
"""Create validation data align with current Dataset. """Create validation data align with current Dataset.
Parameters Parameters
...@@ -842,6 +854,8 @@ class Dataset(object): ...@@ -842,6 +854,8 @@ class Dataset(object):
Weight for each instance. Weight for each instance.
group : list, numpy 1-D array or None, optional (default=None) group : list, numpy 1-D array or None, optional (default=None)
Group/query size for Dataset. Group/query size for Dataset.
init_score : list, numpy 1-D array or None, optional (default=None)
Init score for Dataset.
silent : bool, optional (default=False) silent : bool, optional (default=False)
Whether to print messages during construction. Whether to print messages during construction.
params: dict or None, optional (default=None) params: dict or None, optional (default=None)
...@@ -853,8 +867,8 @@ class Dataset(object): ...@@ -853,8 +867,8 @@ class Dataset(object):
Returns self. Returns self.
""" """
ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self, ret = Dataset(data, label=label, max_bin=self.max_bin, reference=self,
weight=weight, group=group, silent=silent, params=params, weight=weight, group=group, init_score=init_score,
free_raw_data=self.free_raw_data) silent=silent, params=params, free_raw_data=self.free_raw_data)
ret._predictor = self._predictor ret._predictor = self._predictor
ret.pandas_categorical = self.pandas_categorical ret.pandas_categorical = self.pandas_categorical
return ret return ret
......
...@@ -95,13 +95,13 @@ def train(params, train_set, num_boost_round=100, ...@@ -95,13 +95,13 @@ def train(params, train_set, num_boost_round=100,
"""create predictor first""" """create predictor first"""
for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]: for alias in ["num_boost_round", "num_iterations", "num_iteration", "num_tree", "num_trees", "num_round", "num_rounds"]:
if alias in params: if alias in params:
num_boost_round = int(params.pop(alias))
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break break
for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]: for alias in ["early_stopping_round", "early_stopping_rounds", "early_stopping"]:
if alias in params: if alias in params and params[alias] is not None:
early_stopping_rounds = int(params.pop(alias))
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias)) warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break break
if isinstance(init_model, string_type): if isinstance(init_model, string_type):
......
...@@ -142,7 +142,7 @@ class LGBMModel(_LGBMModelBase): ...@@ -142,7 +142,7 @@ class LGBMModel(_LGBMModelBase):
subsample_for_bin=200000, objective=None, subsample_for_bin=200000, objective=None,
min_split_gain=0., min_child_weight=1e-3, min_child_samples=20, min_split_gain=0., min_child_weight=1e-3, min_child_samples=20,
subsample=1., subsample_freq=1, colsample_bytree=1., subsample=1., subsample_freq=1, colsample_bytree=1.,
reg_alpha=0., reg_lambda=0., random_state=0, reg_alpha=0., reg_lambda=0., random_state=None,
n_jobs=-1, silent=True, **kwargs): n_jobs=-1, silent=True, **kwargs):
"""Construct a gradient boosting model. """Construct a gradient boosting model.
...@@ -185,8 +185,9 @@ class LGBMModel(_LGBMModelBase): ...@@ -185,8 +185,9 @@ class LGBMModel(_LGBMModelBase):
L1 regularization term on weights. L1 regularization term on weights.
reg_lambda : float, optional (default=0.) reg_lambda : float, optional (default=0.)
L2 regularization term on weights. L2 regularization term on weights.
random_state : int, optional (default=0) random_state : int or None, optional (default=None)
Random number seed. Random number seed.
Will use default seeds in c++ code if set to None.
n_jobs : int, optional (default=-1) n_jobs : int, optional (default=-1)
Number of parallel threads. Number of parallel threads.
silent : bool, optional (default=True) silent : bool, optional (default=True)
......
...@@ -79,10 +79,13 @@ namespace LightGBM { ...@@ -79,10 +79,13 @@ namespace LightGBM {
for (int i = 0; i < num_distinct_values - 1; ++i) { for (int i = 0; i < num_distinct_values - 1; ++i) {
cur_cnt_inbin += counts[i]; cur_cnt_inbin += counts[i];
if (cur_cnt_inbin >= min_data_in_bin) { if (cur_cnt_inbin >= min_data_in_bin) {
bin_upper_bound.push_back((distinct_values[i] + distinct_values[i + 1]) / 2); auto val = Common::GetDoubleUpperBound((distinct_values[i] + distinct_values[i + 1]) / 2.0);
if (bin_upper_bound.empty() || !Common::CheckDoubleEqualOrdered(bin_upper_bound.back(), val)) {
bin_upper_bound.push_back(val);
cur_cnt_inbin = 0; cur_cnt_inbin = 0;
} }
} }
}
cur_cnt_inbin += counts[num_distinct_values - 1]; cur_cnt_inbin += counts[num_distinct_values - 1];
bin_upper_bound.push_back(std::numeric_limits<double>::infinity()); bin_upper_bound.push_back(std::numeric_limits<double>::infinity());
} else { } else {
...@@ -131,12 +134,15 @@ namespace LightGBM { ...@@ -131,12 +134,15 @@ namespace LightGBM {
} }
++bin_cnt; ++bin_cnt;
// update bin upper bound // update bin upper bound
bin_upper_bound.resize(bin_cnt); bin_upper_bound.clear();
for (int i = 0; i < bin_cnt - 1; ++i) { for (int i = 0; i < bin_cnt - 1; ++i) {
bin_upper_bound[i] = (upper_bounds[i] + lower_bounds[i + 1]) / 2.0f; auto val = Common::GetDoubleUpperBound((upper_bounds[i] + lower_bounds[i + 1]) / 2.0);
if (bin_upper_bound.empty() || !Common::CheckDoubleEqualOrdered(bin_upper_bound.back(), val)) {
bin_upper_bound.push_back(val);
}
} }
// last bin upper bound // last bin upper bound
bin_upper_bound[bin_cnt - 1] = std::numeric_limits<double>::infinity(); bin_upper_bound.push_back(std::numeric_limits<double>::infinity());
} }
return bin_upper_bound; return bin_upper_bound;
} }
...@@ -241,7 +247,7 @@ namespace LightGBM { ...@@ -241,7 +247,7 @@ namespace LightGBM {
} }
for (int i = 1; i < num_sample_values; ++i) { for (int i = 1; i < num_sample_values; ++i) {
if (values[i] != values[i - 1]) { if (!Common::CheckDoubleEqualOrdered(values[i - 1], values[i])) {
if (values[i - 1] < 0.0f && values[i] > 0.0f) { if (values[i - 1] < 0.0f && values[i] > 0.0f) {
distinct_values.push_back(0.0f); distinct_values.push_back(0.0f);
counts.push_back(zero_cnt); counts.push_back(zero_cnt);
...@@ -249,6 +255,8 @@ namespace LightGBM { ...@@ -249,6 +255,8 @@ namespace LightGBM {
distinct_values.push_back(values[i]); distinct_values.push_back(values[i]);
counts.push_back(1); counts.push_back(1);
} else { } else {
// use the large value
distinct_values.back() = values[i];
++counts.back(); ++counts.back();
} }
} }
......
# coding: utf-8
# pylint: skip-file
import os
import unittest
import lightgbm as lgb
import numpy as np
from sklearn.datasets import load_svmlight_file
class FileLoader(object):
    """Loads an example dataset and its CLI config, and checks that Python-API
    predictions agree with the C++ CLI predictions on the same data.

    Parameters
    ----------
    directory : str
        Example directory, relative to this test file.
    prefix : str
        Common filename prefix of the dataset files (e.g. ``'binary'``).
    config_file : str, optional (default='train.conf')
        CLI config file to read training parameters from.
    """

    def __init__(self, directory, prefix, config_file='train.conf'):
        directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), directory)
        self.directory = directory
        self.prefix = prefix
        self.params = {'gpu_use_dp': True}
        with open(os.path.join(directory, config_file), 'r') as f:
            for line in f.readlines():
                line = line.strip()
                if line and not line.startswith('#'):
                    # Split on the first '=' only: values themselves may
                    # contain '=' (e.g. paths/options), which would make a
                    # plain split('=') raise on unpacking.
                    key, value = [token.strip() for token in line.split('=', 1)]
                    if 'early_stopping' not in key:  # disable early_stopping
                        self.params[key] = value

    def load_dataset(self, suffix, is_sparse=False):
        """Return (X, y, filename) for the dataset file ``prefix + suffix``."""
        filename = os.path.join(self.directory, self.prefix + suffix)
        if is_sparse:
            X, Y = load_svmlight_file(filename, dtype=np.float64, zero_based=True)
            return X, Y, filename
        else:
            # Dense text format: first column is the label.
            mat = np.loadtxt(filename, dtype=np.float64)
            return mat[:, 1:], mat[:, 0], filename

    def load_field(self, suffix):
        """Load an auxiliary field file (weight/query/init score) as an array."""
        return np.loadtxt(os.path.join(self.directory, self.prefix + suffix))

    def load_cpp_result(self, result_file='LightGBM_predict_result.txt'):
        """Load predictions previously produced by the C++ CLI."""
        return np.loadtxt(os.path.join(self.directory, result_file))

    def train_predict_check(self, lgb_train, X_test, X_test_fn, sk_pred):
        """Train a Booster and assert that in-memory, file-based and sklearn
        predictions all agree to 5 decimal places."""
        gbm = lgb.train(self.params, lgb_train)
        y_pred = gbm.predict(X_test)
        cpp_pred = gbm.predict(X_test_fn)
        np.testing.assert_array_almost_equal(y_pred, cpp_pred, decimal=5)
        np.testing.assert_array_almost_equal(y_pred, sk_pred, decimal=5)
class TestEngine(unittest.TestCase):
    """End-to-end consistency tests: Python training on the bundled example
    datasets must match sklearn-wrapper and file-based C++ predictions."""

    def test_binary(self):
        loader = FileLoader('../../examples/binary_classification', 'binary')
        X_train, y_train, _ = loader.load_dataset('.train')
        X_test, _, test_fn = loader.load_dataset('.test')
        weights = loader.load_field('.train.weight')
        train_set = lgb.Dataset(X_train, y_train, params=loader.params, weight=weights)
        clf = lgb.LGBMClassifier(**loader.params)
        clf.fit(X_train, y_train, sample_weight=weights)
        expected = clf.predict_proba(X_test)[:, 1]
        loader.train_predict_check(train_set, X_test, test_fn, expected)

    def test_multiclass(self):
        loader = FileLoader('../../examples/multiclass_classification', 'multiclass')
        X_train, y_train, _ = loader.load_dataset('.train')
        X_test, _, test_fn = loader.load_dataset('.test')
        train_set = lgb.Dataset(X_train, y_train)
        clf = lgb.LGBMClassifier(**loader.params)
        clf.fit(X_train, y_train)
        expected = clf.predict_proba(X_test)
        loader.train_predict_check(train_set, X_test, test_fn, expected)

    def test_regression(self):
        loader = FileLoader('../../examples/regression', 'regression')
        X_train, y_train, _ = loader.load_dataset('.train')
        X_test, _, test_fn = loader.load_dataset('.test')
        init_scores = loader.load_field('.train.init')
        train_set = lgb.Dataset(X_train, y_train, init_score=init_scores)
        reg = lgb.LGBMRegressor(**loader.params)
        reg.fit(X_train, y_train, init_score=init_scores)
        expected = reg.predict(X_test)
        loader.train_predict_check(train_set, X_test, test_fn, expected)

    def test_lambdarank(self):
        loader = FileLoader('../../examples/lambdarank', 'rank')
        X_train, y_train, _ = loader.load_dataset('.train', is_sparse=True)
        X_test, _, test_fn = loader.load_dataset('.test', is_sparse=True)
        groups = loader.load_field('.train.query')
        train_set = lgb.Dataset(X_train, y_train, group=groups)
        ranker = lgb.LGBMRanker(**loader.params)
        ranker.fit(X_train, y_train, group=groups)
        expected = ranker.predict(X_test)
        loader.train_predict_check(train_set, X_test, test_fn, expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment