Unverified Commit 44d37184 authored by Belinda Trotta's avatar Belinda Trotta Committed by GitHub
Browse files

Use double precision in threaded calculation of linear tree coefficients (fixes #5226) (#5368)

parent e0af160a
...@@ -109,6 +109,7 @@ include_directories(${EIGEN_DIR}) ...@@ -109,6 +109,7 @@ include_directories(${EIGEN_DIR})
# See https://gitlab.com/libeigen/eigen/-/blob/master/COPYING.README # See https://gitlab.com/libeigen/eigen/-/blob/master/COPYING.README
add_definitions(-DEIGEN_MPL2_ONLY) add_definitions(-DEIGEN_MPL2_ONLY)
add_definitions(-DEIGEN_DONT_PARALLELIZE)
if(__BUILD_FOR_R) if(__BUILD_FOR_R)
find_package(LibR REQUIRED) find_package(LibR REQUIRED)
......
...@@ -1713,7 +1713,7 @@ LGB_CPPFLAGS="" ...@@ -1713,7 +1713,7 @@ LGB_CPPFLAGS=""
# Eigen # # Eigen #
######### #########
LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY" LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY -DEIGEN_DONT_PARALLELIZE"
############### ###############
# MM_PREFETCH # # MM_PREFETCH #
......
...@@ -35,7 +35,7 @@ LGB_CPPFLAGS="" ...@@ -35,7 +35,7 @@ LGB_CPPFLAGS=""
# Eigen # # Eigen #
######### #########
LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY" LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY -DEIGEN_DONT_PARALLELIZE"
############### ###############
# MM_PREFETCH # # MM_PREFETCH #
......
...@@ -19,7 +19,7 @@ LGB_CPPFLAGS="" ...@@ -19,7 +19,7 @@ LGB_CPPFLAGS=""
# Eigen # # Eigen #
######### #########
LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY" LGB_CPPFLAGS="${LGB_CPPFLAGS} -DEIGEN_MPL2_ONLY -DEIGEN_DONT_PARALLELIZE"
############### ###############
# MM_PREFETCH # # MM_PREFETCH #
......
...@@ -47,8 +47,8 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav ...@@ -47,8 +47,8 @@ void LinearTreeLearner::InitLinear(const Dataset* train_data, const int max_leav
// store only upper triangular half of matrix as an array, in row-major order // store only upper triangular half of matrix as an array, in row-major order
// this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression) // this requires (max_num_feat + 1) * (max_num_feat + 2) / 2 entries (including the constant terms of the regression)
// we add another 8 to ensure cache lines are not shared among processors // we add another 8 to ensure cache lines are not shared among processors
XTHX_.push_back(std::vector<float>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0)); XTHX_.push_back(std::vector<double>((max_num_feat + 1) * (max_num_feat + 2) / 2 + 8, 0));
XTg_.push_back(std::vector<float>(max_num_feat + 9, 0.0)); XTg_.push_back(std::vector<double>(max_num_feat + 9, 0.0));
} }
XTHX_by_thread_.clear(); XTHX_by_thread_.clear();
XTg_by_thread_.clear(); XTg_by_thread_.clear();
......
...@@ -118,10 +118,10 @@ class LinearTreeLearner: public SerialTreeLearner { ...@@ -118,10 +118,10 @@ class LinearTreeLearner: public SerialTreeLearner {
/*! \brief map dataset to leaves */ /*! \brief map dataset to leaves */
mutable std::vector<int> leaf_map_; mutable std::vector<int> leaf_map_;
/*! \brief temporary storage for calculating linear model coefficients */ /*! \brief temporary storage for calculating linear model coefficients */
mutable std::vector<std::vector<float>> XTHX_; mutable std::vector<std::vector<double>> XTHX_;
mutable std::vector<std::vector<float>> XTg_; mutable std::vector<std::vector<double>> XTg_;
mutable std::vector<std::vector<std::vector<float>>> XTHX_by_thread_; mutable std::vector<std::vector<std::vector<double>>> XTHX_by_thread_;
mutable std::vector<std::vector<std::vector<float>>> XTg_by_thread_; mutable std::vector<std::vector<std::vector<double>>> XTg_by_thread_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -84,7 +84,7 @@ def test_binary_linear(): ...@@ -84,7 +84,7 @@ def test_binary_linear():
X_test, _, X_test_fn = fd.load_dataset('.test') X_test, _, X_test_fn = fd.load_dataset('.test')
weight_train = fd.load_field('.train.weight') weight_train = fd.load_field('.train.weight')
lgb_train = lgb.Dataset(X_train, y_train, params=fd.params, weight=weight_train) lgb_train = lgb.Dataset(X_train, y_train, params=fd.params, weight=weight_train)
gbm = lgb.LGBMClassifier(**fd.params, n_jobs=0) gbm = lgb.LGBMClassifier(**fd.params)
gbm.fit(X_train, y_train, sample_weight=weight_train) gbm.fit(X_train, y_train, sample_weight=weight_train)
sk_pred = gbm.predict_proba(X_test)[:, 1] sk_pred = gbm.predict_proba(X_test)[:, 1]
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
......
...@@ -3046,6 +3046,26 @@ def test_interaction_constraints(): ...@@ -3046,6 +3046,26 @@ def test_interaction_constraints():
train_data, num_boost_round=10) train_data, num_boost_round=10)
def test_linear_trees_num_threads():
    # Training a linear-tree model must give the same predictions
    # regardless of how many threads are used (regression test for
    # the float->double accumulation fix in LinearTreeLearner).
    np.random.seed(0)
    features = np.arange(0, 1000, 0.1)
    # noisy linear target built from the 1-D feature vector
    target = 2 * features + np.random.normal(0, 0.1, len(features))
    features = features[:, np.newaxis]
    train_set = lgb.Dataset(features, label=target)
    params = {
        'verbose': -1,
        'objective': 'regression',
        'seed': 0,
        'linear_tree': True,
        'num_threads': 2,
    }
    predictions = []
    for thread_count in (2, 4):
        params['num_threads'] = thread_count
        booster = lgb.train(params, train_set, num_boost_round=100)
        predictions.append(booster.predict(features))
    np.testing.assert_allclose(predictions[0], predictions[1])
def test_linear_trees(tmp_path): def test_linear_trees(tmp_path):
# check that setting linear_tree=True fits better than ordinary trees when data has linear relationship # check that setting linear_tree=True fits better than ordinary trees when data has linear relationship
np.random.seed(0) np.random.seed(0)
......
...@@ -101,7 +101,7 @@ ...@@ -101,7 +101,7 @@
</PropertyGroup> </PropertyGroup>
<ItemDefinitionGroup> <ItemDefinitionGroup>
<ClCompile> <ClCompile>
<PreprocessorDefinitions>EIGEN_MPL2_ONLY</PreprocessorDefinitions> <PreprocessorDefinitions>EIGEN_MPL2_ONLY;EIGEN_DONT_PARALLELIZE</PreprocessorDefinitions>
</ClCompile> </ClCompile>
</ItemDefinitionGroup> </ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'"> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment