Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
72b54956
Commit
72b54956
authored
Dec 15, 2017
by
Guolin Ke
Browse files
fix bug in LGBM_NetworkInitWithFunctions
parent
159e9a1e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
34 deletions
+33
-34
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+3
-3
include/LightGBM/meta.h
include/LightGBM/meta.h
+1
-3
include/LightGBM/network.h
include/LightGBM/network.h
+6
-6
src/c_api.cpp
src/c_api.cpp
+13
-12
src/network/network.cpp
src/network/network.cpp
+10
-10
No files found.
include/LightGBM/c_api.h
View file @
72b54956
...
@@ -757,9 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
...
@@ -757,9 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
*/
LIGHTGBM_C_EXPORT
int
LGBM_NetworkFree
();
LIGHTGBM_C_EXPORT
int
LGBM_NetworkFree
();
LIGHTGBM_C_EXPORT
int
LGBM_NetworkInitWithFunctions
(
void
*
A
llreduce
FuncP
tr
,
LIGHTGBM_C_EXPORT
int
LGBM_NetworkInitWithFunctions
(
void
*
a
llreduce
_fun_p
tr
,
void
*
R
educe
S
catter
FuncP
tr
,
void
*
r
educe
_s
catter
_fun_p
tr
,
void
*
A
llgather
FuncP
tr
,
void
*
a
llgather
_fun_p
tr
,
int
num_machines
,
int
num_machines
,
int
rank
);
int
rank
);
...
...
include/LightGBM/meta.h
View file @
72b54956
...
@@ -23,8 +23,6 @@ const double kZeroThreshold = 1e-35f;
...
@@ -23,8 +23,6 @@ const double kZeroThreshold = 1e-35f;
using
ReduceFunction
=
std
::
function
<
void
(
const
char
*
,
char
*
,
int
)
>
;
using
ReduceFunction
=
std
::
function
<
void
(
const
char
*
,
char
*
,
int
)
>
;
typedef
void
(
*
ReduceFunctionInC
)(
const
char
*
,
char
*
,
int
);
using
PredictFunction
=
using
PredictFunction
=
std
::
function
<
void
(
const
std
::
vector
<
std
::
pair
<
int
,
double
>>&
,
double
*
output
)
>
;
std
::
function
<
void
(
const
std
::
vector
<
std
::
pair
<
int
,
double
>>&
,
double
*
output
)
>
;
...
@@ -32,7 +30,7 @@ using AllreduceFunction = std::function<void(char*, int, int, char*, const Reduc
...
@@ -32,7 +30,7 @@ using AllreduceFunction = std::function<void(char*, int, int, char*, const Reduc
using
ReduceScatterFunction
=
std
::
function
<
void
(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunction
&
)
>
;
using
ReduceScatterFunction
=
std
::
function
<
void
(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunction
&
)
>
;
using
AllgatherFunction
=
std
::
function
<
void
(
char
*
,
int
,
char
*
)
>
;
using
AllgatherFunction
=
std
::
function
<
void
(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
)
>
;
#define NO_SPECIFIC (-1)
#define NO_SPECIFIC (-1)
...
...
include/LightGBM/network.h
View file @
72b54956
...
@@ -191,9 +191,9 @@ public:
...
@@ -191,9 +191,9 @@ public:
/*! \brief set variables and function ptrs */
/*! \brief set variables and function ptrs */
static
void
SetRank
(
int
rank
)
{
rank_
=
rank
;}
static
void
SetRank
(
int
rank
)
{
rank_
=
rank
;}
static
void
SetNumMachines
(
int
num_machines
)
{
num_machines_
=
num_machines
;
}
static
void
SetNumMachines
(
int
num_machines
)
{
num_machines_
=
num_machines
;
}
static
void
SetAllReduceFunction
(
AllreduceFunction
A
llreduce
FuncPtr
)
{
A
llreduce
FuncPtr
_
=
A
llreduce
FuncPtr
;}
static
void
SetAllReduceFunction
(
AllreduceFunction
a
llreduce
_ext_fun
)
{
a
llreduce
_ext_fun
_
=
a
llreduce
_ext_fun
;}
static
void
SetReduceScatterFunction
(
ReduceScatterFunction
R
educe
S
catter
FuncPtr
)
{
R
educe
S
catter
FuncPtr
_
=
R
educe
S
catter
FuncPtr
;
}
static
void
SetReduceScatterFunction
(
ReduceScatterFunction
r
educe
_s
catter
_ext_fun
)
{
r
educe
_s
catter
_ext_fun
_
=
r
educe
_s
catter
_ext_fun
;
}
static
void
SetAllgatherFunction
(
AllgatherFunction
A
llgather
FuncPtr
)
{
A
llgather
FuncPtr
_
=
A
llgather
FuncPtr
;
}
static
void
SetAllgatherFunction
(
AllgatherFunction
a
llgather
_ext_fun
)
{
a
llgather
_ext_fun
_
=
a
llgather
_ext_fun
;
}
private:
private:
/*! \brief Number of all machines */
/*! \brief Number of all machines */
...
@@ -215,9 +215,9 @@ private:
...
@@ -215,9 +215,9 @@ private:
/*! \brief Size of buffer_ */
/*! \brief Size of buffer_ */
static
THREAD_LOCAL
int
buffer_size_
;
static
THREAD_LOCAL
int
buffer_size_
;
/*! \brief Funcs*/
/*! \brief Funcs*/
static
THREAD_LOCAL
AllreduceFunction
A
llreduce
FuncPtr
_
;
static
THREAD_LOCAL
AllreduceFunction
a
llreduce
_ext_fun
_
;
static
THREAD_LOCAL
ReduceScatterFunction
R
educe
S
catter
FuncPtr
_
;
static
THREAD_LOCAL
ReduceScatterFunction
r
educe
_s
catter
_ext_fun
_
;
static
THREAD_LOCAL
AllgatherFunction
A
llgather
FuncPtr
_
;
static
THREAD_LOCAL
AllgatherFunction
a
llgather
_ext_fun
_
;
};
};
inline
int
Network
::
rank
()
{
inline
int
Network
::
rank
()
{
...
...
src/c_api.cpp
View file @
72b54956
...
@@ -1220,26 +1220,27 @@ int LGBM_NetworkFree() {
...
@@ -1220,26 +1220,27 @@ int LGBM_NetworkFree() {
API_END
();
API_END
();
}
}
int
LGBM_NetworkInitWithFunctions
(
void
*
A
llreduce
FuncP
tr
,
int
LGBM_NetworkInitWithFunctions
(
void
*
a
llreduce
_fun_p
tr
,
void
*
R
educe
S
catter
FuncP
tr
,
void
*
r
educe
_s
catter
_fun_p
tr
,
void
*
A
llgather
FuncP
tr
,
void
*
a
llgather
_fun_p
tr
,
int
num_machines
,
int
num_machines
,
int
rank
)
{
int
rank
)
{
API_BEGIN
();
API_BEGIN
();
typedef
void
(
*
ReduceFunctionPtr
)(
const
char
*
input
,
char
*
output
,
int
array_size
);
if
(
num_machines
>
1
)
{
if
(
num_machines
>
1
)
{
auto
allreduce_fun
=
[
A
llreduce
FuncP
tr
](
char
*
arg1
,
int
arg2
,
int
arg3
,
char
*
arg4
,
const
ReduceFunction
&
fun
c
)
{
auto
allreduce_fun
=
[
a
llreduce
_fun_p
tr
](
char
*
arg1
,
int
arg2
,
int
arg3
,
char
*
arg4
,
const
ReduceFunction
&
reduce_
fun
)
{
auto
ptr
=
*
fun
c
.
target
<
ReduceFunction
InC
>
();
auto
reduce_fun_ptr
=
*
reduce_
fun
.
target
<
ReduceFunction
Ptr
>
();
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
int
,
char
*
,
const
ReduceFunction
InC
&
))
A
llreduce
FuncP
tr
;
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
int
,
char
*
,
const
ReduceFunction
Ptr
&
))
a
llreduce
_fun_p
tr
;
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
ptr
);
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
reduce_fun_
ptr
);
};
};
Network
::
SetAllReduceFunction
(
allreduce_fun
);
Network
::
SetAllReduceFunction
(
allreduce_fun
);
auto
reduce_scatter_fun
=
[
R
educe
S
catter
FuncP
tr
](
char
*
arg1
,
int
arg2
,
const
int
*
arg3
,
const
int
*
arg4
,
char
*
arg5
,
const
ReduceFunction
&
fun
c
)
{
auto
reduce_scatter_fun
=
[
r
educe
_s
catter
_fun_p
tr
](
char
*
arg1
,
int
arg2
,
const
int
*
arg3
,
const
int
*
arg4
,
char
*
arg5
,
const
ReduceFunction
&
reduce_
fun
)
{
auto
ptr
=
*
fun
c
.
target
<
ReduceFunction
InC
>
();
auto
reduce_fun_ptr
=
*
reduce_
fun
.
target
<
ReduceFunction
Ptr
>
();
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunction
InC
&
))
R
educe
S
catter
FuncP
tr
;
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunction
Ptr
&
))
r
educe
_s
catter
_fun_p
tr
;
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
ptr
);
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
reduce_fun_
ptr
);
};
};
Network
::
SetReduceScatterFunction
(
reduce_scatter_fun
);
Network
::
SetReduceScatterFunction
(
reduce_scatter_fun
);
Network
::
SetAllgatherFunction
((
void
(
*
)(
char
*
,
int
,
char
*
))
A
llgather
FuncP
tr
);
Network
::
SetAllgatherFunction
((
void
(
*
)(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
))
a
llgather
_fun_p
tr
);
Network
::
SetNumMachines
(
num_machines
);
Network
::
SetNumMachines
(
num_machines
);
Network
::
SetRank
(
rank
);
Network
::
SetRank
(
rank
);
}
}
...
...
src/network/network.cpp
View file @
72b54956
...
@@ -19,9 +19,9 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
...
@@ -19,9 +19,9 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL
std
::
vector
<
int
>
Network
::
block_len_
;
THREAD_LOCAL
std
::
vector
<
int
>
Network
::
block_len_
;
THREAD_LOCAL
int
Network
::
buffer_size_
;
THREAD_LOCAL
int
Network
::
buffer_size_
;
THREAD_LOCAL
std
::
vector
<
char
>
Network
::
buffer_
;
THREAD_LOCAL
std
::
vector
<
char
>
Network
::
buffer_
;
THREAD_LOCAL
AllreduceFunction
Network
::
A
llreduce
FuncPtr
_
=
NULL
;
THREAD_LOCAL
AllreduceFunction
Network
::
a
llreduce
_ext_fun
_
=
NULL
;
THREAD_LOCAL
ReduceScatterFunction
Network
::
R
educe
S
catter
FuncPtr
_
=
NULL
;
THREAD_LOCAL
ReduceScatterFunction
Network
::
r
educe
_s
catter
_ext_fun
_
=
NULL
;
THREAD_LOCAL
AllgatherFunction
Network
::
A
llgather
FuncPtr
_
=
NULL
;
THREAD_LOCAL
AllgatherFunction
Network
::
a
llgather
_ext_fun
_
=
NULL
;
void
Network
::
Init
(
NetworkConfig
config
)
{
void
Network
::
Init
(
NetworkConfig
config
)
{
...
@@ -49,8 +49,8 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
...
@@ -49,8 +49,8 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
if
(
num_machines_
<=
1
)
{
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
}
if
(
A
llreduce
FuncPtr
_
!=
NULL
)
{
if
(
a
llreduce
_ext_fun
_
!=
NULL
)
{
return
A
llreduce
FuncPtr
_
(
input
,
input_size
,
type_size
,
output
,
reducer
);
return
a
llreduce
_ext_fun
_
(
input
,
input_size
,
type_size
,
output
,
reducer
);
}
}
int
count
=
input_size
/
type_size
;
int
count
=
input_size
/
type_size
;
// if small package or small count , do it by all gather.(reduce the communication times.)
// if small package or small count , do it by all gather.(reduce the communication times.)
...
@@ -106,9 +106,6 @@ void Network::Allgather(char* input, int send_size, char* output) {
...
@@ -106,9 +106,6 @@ void Network::Allgather(char* input, int send_size, char* output) {
Log
::
Fatal
(
"Please initilize the network interface first"
);
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
}
if
(
num_machines_
<=
1
)
{
return
;
}
if
(
num_machines_
<=
1
)
{
return
;
}
if
(
AllgatherFuncPtr_
!=
NULL
)
{
return
AllgatherFuncPtr_
(
input
,
send_size
,
output
);
}
// assign blocks
// assign blocks
block_start_
[
0
]
=
0
;
block_start_
[
0
]
=
0
;
block_len_
[
0
]
=
send_size
;
block_len_
[
0
]
=
send_size
;
...
@@ -124,6 +121,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
...
@@ -124,6 +121,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
if
(
num_machines_
<=
1
)
{
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
}
if
(
allgather_ext_fun_
!=
NULL
)
{
return
allgather_ext_fun_
(
input
,
all_size
,
block_start
,
block_len
,
output
);
}
int
write_pos
=
0
;
int
write_pos
=
0
;
// use output as receive buffer
// use output as receive buffer
std
::
memcpy
(
output
,
input
,
block_len
[
rank_
]);
std
::
memcpy
(
output
,
input
,
block_len
[
rank_
]);
...
@@ -159,8 +159,8 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
...
@@ -159,8 +159,8 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
if
(
num_machines_
<=
1
)
{
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
}
if
(
R
educe
S
catter
FuncPtr
_
!=
NULL
)
{
if
(
r
educe
_s
catter
_ext_fun
_
!=
NULL
)
{
return
R
educe
S
catter
FuncPtr
_
(
input
,
input_size
,
block_start
,
block_len
,
output
,
reducer
);
return
r
educe
_s
catter
_ext_fun
_
(
input
,
input_size
,
block_start
,
block_len
,
output
,
reducer
);
}
}
if
(
recursive_halving_map_
.
need_pairwise
)
{
if
(
recursive_halving_map_
.
need_pairwise
)
{
for
(
int
i
=
1
;
i
<
num_machines_
;
++
i
)
{
for
(
int
i
=
1
;
i
<
num_machines_
;
++
i
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment