Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
a119639a
Unverified
Commit
a119639a
authored
Sep 19, 2019
by
Guolin Ke
Committed by
GitHub
Sep 19, 2019
Browse files
fix the objective init issues in distributed mode (#2420)
* fix bug * fix include
parent
02374923
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
19 deletions
+35
-19
include/LightGBM/network.h
include/LightGBM/network.h
+22
-18
src/objective/binary_objective.hpp
src/objective/binary_objective.hpp
+6
-1
src/objective/multiclass_objective.hpp
src/objective/multiclass_objective.hpp
+7
-0
No files found.
include/LightGBM/network.h
View file @
a119639a
...
@@ -188,7 +188,6 @@ class Network {
...
@@ -188,7 +188,6 @@ class Network {
});
});
return
global
;
return
global
;
}
}
template
<
class
T
>
template
<
class
T
>
static
T
GlobalSyncUpByMax
(
T
&
local
)
{
static
T
GlobalSyncUpByMax
(
T
&
local
)
{
T
global
=
local
;
T
global
=
local
;
...
@@ -214,25 +213,30 @@ class Network {
...
@@ -214,25 +213,30 @@ class Network {
}
}
template
<
class
T
>
template
<
class
T
>
static
T
GlobalSyncUpBy
Mean
(
T
&
local
)
{
static
T
GlobalSyncUpBy
Sum
(
T
&
local
)
{
T
global
=
(
T
)
0
;
T
global
=
(
T
)
0
;
Allreduce
(
reinterpret_cast
<
char
*>
(
&
local
),
Allreduce
(
reinterpret_cast
<
char
*>
(
&
local
),
sizeof
(
local
),
sizeof
(
local
),
sizeof
(
local
),
sizeof
(
local
),
reinterpret_cast
<
char
*>
(
&
global
),
reinterpret_cast
<
char
*>
(
&
global
),
[](
const
char
*
src
,
char
*
dst
,
int
type_size
,
comm_size_t
len
)
{
[](
const
char
*
src
,
char
*
dst
,
int
type_size
,
comm_size_t
len
)
{
comm_size_t
used_size
=
0
;
comm_size_t
used_size
=
0
;
const
T
*
p1
;
const
T
*
p1
;
T
*
p2
;
T
*
p2
;
while
(
used_size
<
len
)
{
while
(
used_size
<
len
)
{
p1
=
reinterpret_cast
<
const
T
*>
(
src
);
p1
=
reinterpret_cast
<
const
T
*>
(
src
);
p2
=
reinterpret_cast
<
T
*>
(
dst
);
p2
=
reinterpret_cast
<
T
*>
(
dst
);
*
p2
+=
*
p1
;
*
p2
+=
*
p1
;
src
+=
type_size
;
src
+=
type_size
;
dst
+=
type_size
;
dst
+=
type_size
;
used_size
+=
type_size
;
used_size
+=
type_size
;
}
}
});
});
return
static_cast
<
T
>
(
global
/
num_machines_
);
return
static_cast
<
T
>
(
global
);
}
template
<
class
T
>
static
T
GlobalSyncUpByMean
(
T
&
local
)
{
return
static_cast
<
T
>
(
GlobalSyncUpBySum
(
local
)
/
num_machines_
);
}
}
template
<
class
T
>
template
<
class
T
>
...
...
src/objective/binary_objective.hpp
View file @
a119639a
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/objective_function.h>
#include <string>
#include <string>
...
@@ -72,6 +73,11 @@ class BinaryLogloss: public ObjectiveFunction {
...
@@ -72,6 +73,11 @@ class BinaryLogloss: public ObjectiveFunction {
++
cnt_negative
;
++
cnt_negative
;
}
}
}
}
num_pos_data_
=
cnt_positive
;
if
(
Network
::
num_machines
()
>
1
)
{
cnt_positive
=
Network
::
GlobalSyncUpBySum
(
cnt_positive
);
cnt_negative
=
Network
::
GlobalSyncUpBySum
(
cnt_negative
);
}
need_train_
=
true
;
need_train_
=
true
;
if
(
cnt_negative
==
0
||
cnt_positive
==
0
)
{
if
(
cnt_negative
==
0
||
cnt_positive
==
0
)
{
Log
::
Warning
(
"Contains only one class"
);
Log
::
Warning
(
"Contains only one class"
);
...
@@ -96,7 +102,6 @@ class BinaryLogloss: public ObjectiveFunction {
...
@@ -96,7 +102,6 @@ class BinaryLogloss: public ObjectiveFunction {
}
}
}
}
label_weights_
[
1
]
*=
scale_pos_weight_
;
label_weights_
[
1
]
*=
scale_pos_weight_
;
num_pos_data_
=
cnt_positive
;
}
}
void
GetGradients
(
const
double
*
score
,
score_t
*
gradients
,
score_t
*
hessians
)
const
override
{
void
GetGradients
(
const
double
*
score
,
score_t
*
gradients
,
score_t
*
hessians
)
const
override
{
...
...
src/objective/multiclass_objective.hpp
View file @
a119639a
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/objective_function.h>
#include <string>
#include <string>
...
@@ -66,6 +67,12 @@ class MulticlassSoftmax: public ObjectiveFunction {
...
@@ -66,6 +67,12 @@ class MulticlassSoftmax: public ObjectiveFunction {
if
(
weights_
==
nullptr
)
{
if
(
weights_
==
nullptr
)
{
sum_weight
=
num_data_
;
sum_weight
=
num_data_
;
}
}
if
(
Network
::
num_machines
()
>
1
)
{
sum_weight
=
Network
::
GlobalSyncUpBySum
(
sum_weight
);
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
class_init_probs_
[
i
]
=
Network
::
GlobalSyncUpBySum
(
class_init_probs_
[
i
]);
}
}
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_class_
;
++
i
)
{
class_init_probs_
[
i
]
/=
sum_weight
;
class_init_probs_
[
i
]
/=
sum_weight
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment