Unverified Commit 414c7b44 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix bug when nan bin is most freq bin (#2811)

* fix bug when nan bin is most freq bin

* fix naming

* fix bug

* add test

* Apply suggestions from code review

* fix more bugs
parent 60710c70
......@@ -137,11 +137,11 @@ class DenseBin: public Bin {
VAL_T th = static_cast<VAL_T>(threshold + min_bin);
const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
VAL_T t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_zero_bin -= 1;
t_most_freq_bin -= 1;
}
data_size_t lte_count = 0;
......@@ -159,17 +159,31 @@ class DenseBin: public Bin {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (t_most_freq_bin == maxb) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (t_most_freq_bin == bin || bin < minb || bin > maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
} else {
......@@ -194,7 +208,7 @@ class DenseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = data_[idx];
if (bin == t_default_bin) {
if (bin == t_zero_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
......
......@@ -142,11 +142,11 @@ class Dense4bitsBin : public Bin {
uint8_t th = static_cast<uint8_t>(threshold + min_bin);
const uint8_t minb = static_cast<uint8_t>(min_bin);
const uint8_t maxb = static_cast<uint8_t>(max_bin);
uint8_t t_default_bin = static_cast<uint8_t>(min_bin + default_bin);
uint8_t t_zero_bin = static_cast<uint8_t>(min_bin + default_bin);
uint8_t t_most_freq_bin = static_cast<uint8_t>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_zero_bin -= 1;
t_most_freq_bin -= 1;
}
data_size_t lte_count = 0;
......@@ -164,17 +164,31 @@ class Dense4bitsBin : public Bin {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (t_most_freq_bin == maxb) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (t_most_freq_bin == bin || bin < minb || bin > maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
} else {
......@@ -199,7 +213,7 @@ class Dense4bitsBin : public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const uint8_t bin = (data_[idx >> 1] >> ((idx & 1) << 2)) & 0xf;
if (bin == t_default_bin) {
if (bin == t_zero_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
......
......@@ -202,11 +202,11 @@ class SparseBin: public Bin {
VAL_T th = static_cast<VAL_T>(threshold + min_bin);
const VAL_T minb = static_cast<VAL_T>(min_bin);
const VAL_T maxb = static_cast<VAL_T>(max_bin);
VAL_T t_default_bin = static_cast<VAL_T>(min_bin + default_bin);
VAL_T t_zero_bin = static_cast<VAL_T>(min_bin + default_bin);
VAL_T t_most_freq_bin = static_cast<VAL_T>(min_bin + most_freq_bin);
if (most_freq_bin == 0) {
th -= 1;
t_default_bin -= 1;
t_zero_bin -= 1;
t_most_freq_bin -= 1;
}
data_size_t lte_count = 0;
......@@ -225,17 +225,31 @@ class SparseBin: public Bin {
missing_default_indices = lte_indices;
missing_default_count = &lte_count;
}
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
if (t_most_freq_bin == maxb) {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (t_most_freq_bin == bin || bin < minb || bin > maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
} else {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin == maxb) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
} else if (bin > th) {
gt_indices[gt_count++] = idx;
} else {
lte_indices[lte_count++] = idx;
}
}
}
} else {
......@@ -260,7 +274,7 @@ class SparseBin: public Bin {
for (data_size_t i = 0; i < num_data; ++i) {
const data_size_t idx = data_indices[i];
const VAL_T bin = iterator.InnerRawGet(idx);
if (bin == t_default_bin) {
if (bin == t_zero_bin) {
missing_default_indices[(*missing_default_count)++] = idx;
} else if (bin < minb || bin > maxb || t_most_freq_bin == bin) {
default_indices[(*default_count)++] = idx;
......
......@@ -139,6 +139,31 @@ class TestEngine(unittest.TestCase):
self.assertLess(ret, 0.005)
self.assertAlmostEqual(evals_result['valid_0']['l2'][-1], ret, places=5)
def test_missing_value_handle_more_na(self):
X_train = np.ones((100, 1))
y_train = np.ones(100)
trues = random.sample(range(100), 80)
for idx in trues:
X_train[idx, 0] = np.nan
y_train[idx] = 0
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_train, y_train)
params = {
'metric': 'l2',
'verbose': -1,
'boost_from_average': False
}
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=20,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result)
ret = mean_squared_error(y_train, gbm.predict(X_train))
self.assertLess(ret, 0.005)
self.assertAlmostEqual(evals_result['valid_0']['l2'][-1], ret, places=5)
def test_missing_value_handle_na(self):
x = [0, 1, 2, 3, 4, 5, 6, 7, np.nan]
y = [1, 1, 1, 1, 0, 0, 0, 0, 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment