"src/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "e754f23affbe27ae442a8152e0f81d9eda33edc1"
Commit 1bade1e2 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix bug in data/voting parallel.

parent 4d6ff287
......@@ -22,10 +22,7 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
rank_ = Network::rank();
num_machines_ = Network::num_machines();
// allocate buffer for communication
size_t buffer_size = 0;
for (int i = 0; i < num_features_; ++i) {
buffer_size += train_data_->FeatureNumBin(i) * sizeof(HistogramBinEntry);
}
size_t buffer_size = train_data_->NumTotalBin() * sizeof(HistogramBinEntry);
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
......@@ -54,7 +51,11 @@ void DataParallelTreeLearner::BeforeTrain() {
if (is_feature_used_[i]) {
int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
feature_distribution[cur_min_machine].push_back(i);
num_bins_distributed[cur_min_machine] += train_data_->FeatureNumBin(i);
auto num_bin = train_data_->FeatureNumBin(i);
if (train_data_->FeatureBinMapper(i)->GetDefaultBin() == 0) {
num_bin -= 1;
}
num_bins_distributed[cur_min_machine] += num_bin;
}
is_feature_aggregated_[i] = false;
}
......@@ -68,7 +69,11 @@ void DataParallelTreeLearner::BeforeTrain() {
for (int i = 0; i < num_machines_; ++i) {
block_len_[i] = 0;
for (auto fid : feature_distribution[i]) {
block_len_[i] += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry);
auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
block_len_[i] += num_bin * sizeof(HistogramBinEntry);
}
reduce_scatter_size_ += block_len_[i];
}
......@@ -83,7 +88,11 @@ void DataParallelTreeLearner::BeforeTrain() {
for (int i = 0; i < num_machines_; ++i) {
for (auto fid : feature_distribution[i]) {
buffer_write_start_pos_[fid] = bin_size;
bin_size += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry);
auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
}
}
......@@ -91,7 +100,11 @@ void DataParallelTreeLearner::BeforeTrain() {
bin_size = 0;
for (auto fid : feature_distribution[rank_]) {
buffer_read_start_pos_[fid] = bin_size;
bin_size += train_data_->FeatureNumBin(fid) * sizeof(HistogramBinEntry);
auto num_bin = train_data_->FeatureNumBin(fid);
if (train_data_->FeatureBinMapper(fid)->GetDefaultBin() == 0) {
num_bin -= 1;
}
bin_size += num_bin * sizeof(HistogramBinEntry);
}
// sync global data sumup info
......@@ -157,7 +170,7 @@ void DataParallelTreeLearner::FindBestThresholds() {
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_->LeafIndex()),
smaller_leaf_histogram_array_[feature_index].RawData());
SplitInfo smaller_split;
// find best threshold for smaller child
......
......@@ -368,6 +368,12 @@ void VotingParallelTreeLearner::FindBestThresholds() {
// restore from buffer
smaller_leaf_histogram_array_global_[feature_index].FromMemory(
output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold
smaller_leaf_histogram_array_global_[feature_index].FindBestThreshold(
smaller_leaf_splits_global_->sum_gradients(),
......@@ -383,6 +389,12 @@ void VotingParallelTreeLearner::FindBestThresholds() {
SplitInfo larger_split;
// restore from buffer
larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
train_data_->FixHistogram(feature_index,
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_histogram_array_global_[feature_index].RawData());
// find best threshold
larger_leaf_histogram_array_global_[feature_index].FindBestThreshold(
larger_leaf_splits_global_->sum_gradients(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment