Commit facefa02 authored by Davis King's avatar Davis King
Browse files

Fix random forest regression not doing quite the right thing.

parent fe803b56
...@@ -376,8 +376,6 @@ namespace dlib ...@@ -376,8 +376,6 @@ namespace dlib
std::vector<std::vector<internal_tree_node<feature_extractor>>> all_trees(num_trees); std::vector<std::vector<internal_tree_node<feature_extractor>>> all_trees(num_trees);
std::vector<std::vector<float>> all_leaves(num_trees); std::vector<std::vector<float>> all_leaves(num_trees);
const double sumy = sum(mat(y));
const size_t feats_per_node = std::max(1.0,std::round(fe.max_num_feats()*feature_subsampling_frac)); const size_t feats_per_node = std::max(1.0,std::round(fe.max_num_feats()*feature_subsampling_frac));
// Each tree couldn't have more than this many interior nodes. It might // Each tree couldn't have more than this many interior nodes. It might
...@@ -412,15 +410,18 @@ namespace dlib ...@@ -412,15 +410,18 @@ namespace dlib
// don't make any tree. Just average the things and be done. // don't make any tree. Just average the things and be done.
if (y.size() <= min_samples_per_leaf) if (y.size() <= min_samples_per_leaf)
{ {
leaves.push_back(sumy/y.size()); leaves.push_back(mean(mat(y)));
return; return;
} }
double sumy = 0;
// pick a random bootstrap of the data. // pick a random bootstrap of the data.
std::vector<std::pair<float,uint32_t>> idxs(y.size()); std::vector<std::pair<float,uint32_t>> idxs(y.size());
for (auto& idx : idxs) for (auto& idx : idxs) {
idx = std::make_pair(0.0f, static_cast<uint32_t>(rnd.get_integer(y.size()))); idx = std::make_pair(0.0f, static_cast<uint32_t>(rnd.get_integer(y.size())));
sumy += y[idx.second];
}
// We are going to use ranges_to_process as a stack that tracks which // We are going to use ranges_to_process as a stack that tracks which
// range of samples we are going to split next. // range of samples we are going to split next.
...@@ -702,7 +703,7 @@ namespace dlib ...@@ -702,7 +703,7 @@ namespace dlib
for (auto i = range.begin; i < range.end; ++i) for (auto i = range.begin; i < range.end; ++i)
idxs[i].first = fe.extract_feature_value(x[idxs[i].second], feat); idxs[i].first = fe.extract_feature_value(x[idxs[i].second], feat);
std::sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first); std::stable_sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first);
auto split = find_best_split(range, y, idxs); auto split = find_best_split(range, y, idxs);
...@@ -716,7 +717,7 @@ namespace dlib ...@@ -716,7 +717,7 @@ namespace dlib
// resort idxs based on winning feat // resort idxs based on winning feat
for (auto i = range.begin; i < range.end; ++i) for (auto i = range.begin; i < range.end; ++i)
idxs[i].first = fe.extract_feature_value(x[idxs[i].second], best.split_feature); idxs[i].first = fe.extract_feature_value(x[idxs[i].second], best.split_feature);
std::sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first); std::stable_sort(idxs.begin()+range.begin, idxs.begin()+range.end, compare_first);
return best; return best;
} }
......
...@@ -62,15 +62,23 @@ namespace ...@@ -62,15 +62,23 @@ namespace
DLIB_TEST(df.get_num_trees() == 1000); DLIB_TEST(df.get_num_trees() == 1000);
auto result = test_regression_function(df, samples, labels); auto result = test_regression_function(df, samples, labels);
// train: 2.239 0.987173 0.970669 1.1399 // train: 1.95064 0.990374 0.92738 1.04536
dlog << LINFO << "train: " << result; dlog << LINFO << "train: " << result;
DLIB_TEST_MSG(result(0) < 2.3, result(0)); DLIB_TEST_MSG(result(0) < 2.0, result(0));
// By construction, output values should be in the span of the training labels.
const double min_label = min(mat(labels));
const double max_label = max(mat(labels));
for (auto&& x : samples) {
double y = df(x);
DLIB_TEST(min_label <= y && y <= max_label);
}
running_stats<double> rs; running_stats<double> rs;
for (size_t i = 0; i < oobs.size(); ++i) for (size_t i = 0; i < oobs.size(); ++i)
rs.add(std::pow(oobs[i]-labels[i],2.0)); rs.add(std::pow(oobs[i]-labels[i],2.0));
dlog << LINFO << "OOB MSE: "<< rs.mean(); dlog << LINFO << "OOB MSE: "<< rs.mean();
DLIB_TEST_MSG(rs.mean() < 10.2, rs.mean()); DLIB_TEST_MSG(rs.mean() < 10.0, rs.mean());
print_spinner(); print_spinner();
...@@ -80,9 +88,9 @@ namespace ...@@ -80,9 +88,9 @@ namespace
deserialize(df2, ss); deserialize(df2, ss);
DLIB_TEST(df2.get_num_trees() == 1000); DLIB_TEST(df2.get_num_trees() == 1000);
result = test_regression_function(df2, samples, labels); result = test_regression_function(df2, samples, labels);
// train: 2.239 0.987173 0.970669 1.1399 // train: 1.95064 0.990374 0.92738 1.04536
dlog << LINFO << "serialized train results: " << result; dlog << LINFO << "serialized train results: " << result;
DLIB_TEST_MSG(result(0) < 2.3, result(0)); DLIB_TEST_MSG(result(0) < 2.0, result(0));
} }
} a; } a;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment