Unverified Commit a119639a, authored by Guolin Ke, committed by GitHub
Browse files

fix the objective init issues in distributed mode (#2420)

* fix bug

* fix include
parent 02374923
...@@ -188,7 +188,6 @@ class Network { ...@@ -188,7 +188,6 @@ class Network {
}); });
return global; return global;
} }
template<class T> template<class T>
static T GlobalSyncUpByMax(T& local) { static T GlobalSyncUpByMax(T& local) {
T global = local; T global = local;
...@@ -214,25 +213,30 @@ class Network { ...@@ -214,25 +213,30 @@ class Network {
} }
template<class T> template<class T>
static T GlobalSyncUpByMean(T& local) { static T GlobalSyncUpBySum(T& local) {
T global = (T)0; T global = (T)0;
Allreduce(reinterpret_cast<char*>(&local), Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local), sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global), reinterpret_cast<char*>(&global),
[](const char* src, char* dst, int type_size, comm_size_t len) { [](const char* src, char* dst, int type_size, comm_size_t len) {
comm_size_t used_size = 0; comm_size_t used_size = 0;
const T *p1; const T* p1;
T *p2; T* p2;
while (used_size < len) { while (used_size < len) {
p1 = reinterpret_cast<const T *>(src); p1 = reinterpret_cast<const T*>(src);
p2 = reinterpret_cast<T *>(dst); p2 = reinterpret_cast<T*>(dst);
*p2 += *p1; *p2 += *p1;
src += type_size; src += type_size;
dst += type_size; dst += type_size;
used_size += type_size; used_size += type_size;
} }
}); });
return static_cast<T>(global / num_machines_); return static_cast<T>(global);
}
template<class T>
static T GlobalSyncUpByMean(T& local) {
return static_cast<T>(GlobalSyncUpBySum(local) / num_machines_);
} }
template<class T> template<class T>
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_ #ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_ #define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <string> #include <string>
...@@ -72,6 +73,11 @@ class BinaryLogloss: public ObjectiveFunction { ...@@ -72,6 +73,11 @@ class BinaryLogloss: public ObjectiveFunction {
++cnt_negative; ++cnt_negative;
} }
} }
num_pos_data_ = cnt_positive;
if (Network::num_machines() > 1) {
cnt_positive = Network::GlobalSyncUpBySum(cnt_positive);
cnt_negative = Network::GlobalSyncUpBySum(cnt_negative);
}
need_train_ = true; need_train_ = true;
if (cnt_negative == 0 || cnt_positive == 0) { if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Contains only one class"); Log::Warning("Contains only one class");
...@@ -96,7 +102,6 @@ class BinaryLogloss: public ObjectiveFunction { ...@@ -96,7 +102,6 @@ class BinaryLogloss: public ObjectiveFunction {
} }
} }
label_weights_[1] *= scale_pos_weight_; label_weights_[1] *= scale_pos_weight_;
num_pos_data_ = cnt_positive;
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_ #ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_ #define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h> #include <LightGBM/objective_function.h>
#include <string> #include <string>
...@@ -66,6 +67,12 @@ class MulticlassSoftmax: public ObjectiveFunction { ...@@ -66,6 +67,12 @@ class MulticlassSoftmax: public ObjectiveFunction {
if (weights_ == nullptr) { if (weights_ == nullptr) {
sum_weight = num_data_; sum_weight = num_data_;
} }
if (Network::num_machines() > 1) {
sum_weight = Network::GlobalSyncUpBySum(sum_weight);
for (int i = 0; i < num_class_; ++i) {
class_init_probs_[i] = Network::GlobalSyncUpBySum(class_init_probs_[i]);
}
}
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
class_init_probs_[i] /= sum_weight; class_init_probs_[i] /= sum_weight;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment