Commit 4f77bd28 authored by Guolin Ke
Browse files

update to v2.

parent 13d4581b
...@@ -6,21 +6,21 @@ import unittest ...@@ -6,21 +6,21 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer, dump_svmlight_file
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
class TestBasic(unittest.TestCase): class TestBasic(unittest.TestCase):
def test(self): def test(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=1) X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=2)
train_data = lgb.Dataset(X_train, max_bin=255, label=y_train) train_data = lgb.Dataset(X_train, max_bin=255, label=y_train)
valid_data = train_data.create_valid(X_test, label=y_test) valid_data = train_data.create_valid(X_test, label=y_test)
params = { params = {
"objective": "binary", "objective": "binary",
"metric": "auc", "metric": "auc",
"min_data": 1, "min_data": 10,
"num_leaves": 15, "num_leaves": 15,
"verbose": -1 "verbose": -1
} }
...@@ -36,7 +36,7 @@ class TestBasic(unittest.TestCase): ...@@ -36,7 +36,7 @@ class TestBasic(unittest.TestCase):
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
tname = f.name tname = f.name
with open(tname, "w+b") as f: with open(tname, "w+b") as f:
np.savetxt(f, X_test, delimiter=',') dump_svmlight_file(X_test, y_test, f)
pred_from_file = bst.predict(tname) pred_from_file = bst.predict(tname)
os.remove(tname) os.remove(tname)
self.assertEqual(len(pred_from_matr), len(pred_from_file)) self.assertEqual(len(pred_from_matr), len(pred_from_file))
...@@ -49,7 +49,7 @@ class TestBasic(unittest.TestCase): ...@@ -49,7 +49,7 @@ class TestBasic(unittest.TestCase):
for preds in zip(pred_from_matr, pred_from_model_file): for preds in zip(pred_from_matr, pred_from_model_file):
self.assertEqual(*preds) self.assertEqual(*preds)
# check pmml # check pmml
os.system('python ../../pmml/pmml.py model.txt') # os.system('python ../../pmml/pmml.py model.txt')
print("----------------------------------------------------------------------") print("----------------------------------------------------------------------")
......
...@@ -32,7 +32,7 @@ class template(object): ...@@ -32,7 +32,7 @@ class template(object):
@staticmethod @staticmethod
def test_template(params={'objective': 'regression', 'metric': 'l2'}, def test_template(params={'objective': 'regression', 'metric': 'l2'},
X_y=load_boston(True), feval=mean_squared_error, X_y=load_boston(True), feval=mean_squared_error,
num_round=100, init_model=None, custom_eval=None, num_round=150, init_model=None, custom_eval=None,
early_stopping_rounds=10, early_stopping_rounds=10,
return_data=False, return_model=False): return_data=False, return_model=False):
params['verbose'], params['seed'] = -1, 42 params['verbose'], params['seed'] = -1, 42
...@@ -153,49 +153,6 @@ class TestEngine(unittest.TestCase): ...@@ -153,49 +153,6 @@ class TestEngine(unittest.TestCase):
for ret in other_ret: for ret in other_ret:
self.assertAlmostEqual(ret_origin, ret, places=5) self.assertAlmostEqual(ret_origin, ret, places=5)
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100), # int
"C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
"D": np.random.permutation([True, False] * 150)}) # bool
y = np.random.permutation([0, 1] * 150)
X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20),
"B": np.random.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30)})
for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category')
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1
}
lgb_train = lgb.Dataset(X, y)
gbm0 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False)
pred0 = list(gbm0.predict(X_test))
lgb_train = lgb.Dataset(X, y)
gbm1 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=[0])
pred1 = list(gbm1.predict(X_test))
lgb_train = lgb.Dataset(X, y)
gbm2 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A'])
pred2 = list(gbm2.predict(X_test))
lgb_train = lgb.Dataset(X, y)
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test))
lgb_train = lgb.Dataset(X, y)
gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test))
self.assertListEqual(pred0, pred1)
self.assertListEqual(pred0, pred2)
self.assertListEqual(pred0, pred3)
self.assertListEqual(pred0, pred4)
print("----------------------------------------------------------------------") print("----------------------------------------------------------------------")
print("running test_engine.py") print("running test_engine.py")
......
...@@ -196,7 +196,7 @@ ...@@ -196,7 +196,7 @@
<ClInclude Include="..\include\LightGBM\c_api.h" /> <ClInclude Include="..\include\LightGBM\c_api.h" />
<ClInclude Include="..\include\LightGBM\dataset.h" /> <ClInclude Include="..\include\LightGBM\dataset.h" />
<ClInclude Include="..\include\LightGBM\dataset_loader.h" /> <ClInclude Include="..\include\LightGBM\dataset_loader.h" />
<ClInclude Include="..\include\LightGBM\feature.h" /> <ClInclude Include="..\include\LightGBM\feature_group.h" />
<ClInclude Include="..\include\LightGBM\meta.h" /> <ClInclude Include="..\include\LightGBM\meta.h" />
<ClInclude Include="..\include\LightGBM\metric.h" /> <ClInclude Include="..\include\LightGBM\metric.h" />
<ClInclude Include="..\include\LightGBM\network.h" /> <ClInclude Include="..\include\LightGBM\network.h" />
...@@ -213,6 +213,7 @@ ...@@ -213,6 +213,7 @@
<ClInclude Include="..\src\application\predictor.hpp" /> <ClInclude Include="..\src\application\predictor.hpp" />
<ClInclude Include="..\src\boosting\gbdt.h" /> <ClInclude Include="..\src\boosting\gbdt.h" />
<ClInclude Include="..\src\boosting\dart.hpp" /> <ClInclude Include="..\src\boosting\dart.hpp" />
<ClInclude Include="..\src\boosting\goss.hpp" />
<ClInclude Include="..\src\boosting\score_updater.hpp" /> <ClInclude Include="..\src\boosting\score_updater.hpp" />
<ClInclude Include="..\src\io\dense_bin.hpp" /> <ClInclude Include="..\src\io\dense_bin.hpp" />
<ClInclude Include="..\src\io\ordered_sparse_bin.hpp" /> <ClInclude Include="..\src\io\ordered_sparse_bin.hpp" />
......
...@@ -96,15 +96,9 @@ ...@@ -96,15 +96,9 @@
<ClInclude Include="..\src\treelearner\data_partition.hpp"> <ClInclude Include="..\src\treelearner\data_partition.hpp">
<Filter>src\treelearner</Filter> <Filter>src\treelearner</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\treelearner\feature_histogram.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\src\treelearner\leaf_splits.hpp"> <ClInclude Include="..\src\treelearner\leaf_splits.hpp">
<Filter>src\treelearner</Filter> <Filter>src\treelearner</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\src\treelearner\split_info.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\application.h"> <ClInclude Include="..\include\LightGBM\application.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
...@@ -120,9 +114,6 @@ ...@@ -120,9 +114,6 @@
<ClInclude Include="..\include\LightGBM\dataset.h"> <ClInclude Include="..\include\LightGBM\dataset.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\include\LightGBM\feature.h">
<Filter>include\LightGBM</Filter>
</ClInclude>
<ClInclude Include="..\include\LightGBM\meta.h"> <ClInclude Include="..\include\LightGBM\meta.h">
<Filter>include\LightGBM</Filter> <Filter>include\LightGBM</Filter>
</ClInclude> </ClInclude>
...@@ -171,6 +162,18 @@ ...@@ -171,6 +162,18 @@
<ClInclude Include="..\src\boosting\dart.hpp"> <ClInclude Include="..\src\boosting\dart.hpp">
<Filter>src\boosting</Filter> <Filter>src\boosting</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="..\include\LightGBM\feature_group.h">
<Filter>include\LightGBM</Filter>
</ClInclude>
<ClInclude Include="..\src\treelearner\feature_histogram.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\src\treelearner\split_info.hpp">
<Filter>src\treelearner</Filter>
</ClInclude>
<ClInclude Include="..\src\boosting\goss.hpp">
<Filter>src\boosting</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="..\src\application\application.cpp"> <ClCompile Include="..\src\application\application.cpp">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment