Commit c77153a1 authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

add NumModelPerIteration and NumberOfTotalModel in C_API (#1613)

* added NumberOfTotalModel and NumModelPerIteration to C_API and python-package

* fixed tests

* added tests for current_iteration, num_trees, num_model_per_iteration methods

* break huge line in test

* hotfix
parent 8c6ef946
...@@ -465,6 +465,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle); ...@@ -465,6 +465,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRollbackOneIter(BoosterHandle handle);
*/ */
LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration); LIGHTGBM_C_EXPORT int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration);
/*!
* \brief Get number of tree per iteration
* \param out_tree_per_iteration number of tree per iteration
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration);
/*!
* \brief Get number of weak sub-models
* \param out_models number of weak sub-models
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models);
/*! /*!
* \brief Get number of eval * \brief Get number of eval
* \param out_len total number of eval results * \param out_len total number of eval results
......
...@@ -1737,6 +1737,34 @@ class Booster(object): ...@@ -1737,6 +1737,34 @@ class Booster(object):
ctypes.byref(out_cur_iter))) ctypes.byref(out_cur_iter)))
return out_cur_iter.value return out_cur_iter.value
def num_model_per_iteration(self):
"""Get number of models per iteration.
Returns
-------
model_per_iter : int
The number of models per iteration.
"""
model_per_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterNumModelPerIteration(
self.handle,
ctypes.byref(model_per_iter)))
return model_per_iter.value
def num_trees(self):
"""Get number of weak sub-models.
Returns
-------
num_trees : int
The number of weak sub-models.
"""
num_trees = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterNumberOfTotalModel(
self.handle,
ctypes.byref(num_trees)))
return num_trees.value
def eval(self, data, name, feval=None): def eval(self, data, name, feval=None):
"""Evaluate for data. """Evaluate for data.
......
...@@ -1017,6 +1017,20 @@ int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) { ...@@ -1017,6 +1017,20 @@ int LGBM_BoosterGetCurrentIteration(BoosterHandle handle, int* out_iteration) {
API_END(); API_END();
} }
int LGBM_BoosterNumModelPerIteration(BoosterHandle handle, int* out_tree_per_iteration) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_tree_per_iteration = ref_booster->GetBoosting()->NumModelPerIteration();
API_END();
}
int LGBM_BoosterNumberOfTotalModel(BoosterHandle handle, int* out_models) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_models = ref_booster->GetBoosting()->NumberOfTotalModel();
API_END();
}
int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) { int LGBM_BoosterGetEvalCounts(BoosterHandle handle, int* out_len) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
......
...@@ -253,5 +253,12 @@ def test_booster(): ...@@ -253,5 +253,12 @@ def test_booster():
c_str(''), c_str(''),
ctypes.byref(num_preb), ctypes.byref(num_preb),
preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double))) preb.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))
LIB.LGBM_BoosterPredictForFile(booster2, c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')), 0, 0, 50, c_str(''), c_str('preb.txt')) LIB.LGBM_BoosterPredictForFile(
booster2,
c_str(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/binary_classification/binary.test')),
0,
0,
50,
c_str(''),
c_str('preb.txt'))
LIB.LGBM_BoosterFree(booster2) LIB.LGBM_BoosterFree(booster2)
...@@ -33,6 +33,11 @@ class TestBasic(unittest.TestCase): ...@@ -33,6 +33,11 @@ class TestBasic(unittest.TestCase):
bst.update() bst.update()
if i % 10 == 0: if i % 10 == 0:
print(bst.eval_train(), bst.eval_valid()) print(bst.eval_train(), bst.eval_valid())
self.assertEqual(bst.current_iteration(), 30)
self.assertEqual(bst.num_trees(), 30)
self.assertEqual(bst.num_model_per_iteration(), 1)
bst.save_model("model.txt") bst.save_model("model.txt")
pred_from_matr = bst.predict(X_test) pred_from_matr = bst.predict(X_test)
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
......
...@@ -583,14 +583,14 @@ class TestEngine(unittest.TestCase): ...@@ -583,14 +583,14 @@ class TestEngine(unittest.TestCase):
stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0) stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0)
stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0) stacked_features = np.concatenate((stacked_features, np.ones(9, dtype=np.float32).reshape((1, 9))), axis=0)
# test sliced 2d matrix # test sliced 2d matrix
sliced_features = stacked_features[2:102, 2: 7] sliced_features = stacked_features[2:102, 2:7]
assert np.all(sliced_features == features) self.assertTrue(np.all(sliced_features == features))
sliced_pred = train_and_get_predictions(sliced_features, sliced_labels) sliced_pred = train_and_get_predictions(sliced_features, sliced_labels)
np.testing.assert_almost_equal(origin_pred, sliced_pred) np.testing.assert_almost_equal(origin_pred, sliced_pred)
# test sliced CSR # test sliced CSR
stacked_csr = csr_matrix(stacked_features) stacked_csr = csr_matrix(stacked_features)
sliced_csr = stacked_csr[2:102, 2: 7] sliced_csr = stacked_csr[2:102, 2:7]
assert np.all(sliced_csr == features) self.assertTrue(np.all(sliced_csr == features))
sliced_pred = train_and_get_predictions(sliced_csr, sliced_labels) sliced_pred = train_and_get_predictions(sliced_csr, sliced_labels)
np.testing.assert_almost_equal(origin_pred, sliced_pred) np.testing.assert_almost_equal(origin_pred, sliced_pred)
...@@ -632,7 +632,7 @@ class TestEngine(unittest.TestCase): ...@@ -632,7 +632,7 @@ class TestEngine(unittest.TestCase):
'monotone_constraints': '1,-1' 'monotone_constraints': '1,-1'
} }
constrained_model = lgb.train(params, trainset) constrained_model = lgb.train(params, trainset)
assert is_correctly_constrained(constrained_model) self.assertTrue(is_correctly_constrained(constrained_model))
def test_refit(self): def test_refit(self):
X, y = load_breast_cancer(True) X, y = load_breast_cancer(True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment