Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
e951a3d7
Commit
e951a3d7
authored
Nov 29, 2017
by
ww
Committed by
Guolin Ke
Nov 29, 2017
Browse files
Network interface with c_api (#1067)
parent
38b65e5f
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
66 additions
and
2 deletions
+66
-2
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+8
-0
include/LightGBM/meta.h
include/LightGBM/meta.h
+8
-0
include/LightGBM/network.h
include/LightGBM/network.h
+10
-1
src/c_api.cpp
src/c_api.cpp
+26
-0
src/network/network.cpp
src/network/network.cpp
+14
-1
No files found.
include/LightGBM/c_api.h
View file @
e951a3d7
#ifndef LIGHTGBM_C_API_H_
#define LIGHTGBM_C_API_H_
#include <LightGBM/meta.h>
#include <cstdint>
#include <exception>
#include <stdexcept>
...
...
@@ -754,6 +757,11 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT
int
LGBM_NetworkFree
();
LIGHTGBM_C_EXPORT
int
LGBM_GetFuncions
(
void
*
AllreduceFuncPtr
,
void
*
ReduceScatterFuncPtr
,
void
*
AllgatherFuncPtr
,
int
num_machines
,
int
rank
);
// exception handle and error msg
static
char
*
LastErrorMsg
()
{
static
THREAD_LOCAL
char
err_msg
[
512
]
=
"Everything is fine"
;
return
err_msg
;
}
...
...
include/LightGBM/meta.h
View file @
e951a3d7
...
...
@@ -23,9 +23,17 @@ const double kZeroAsMissingValueRange = 1e-20f;
using
ReduceFunction
=
std
::
function
<
void
(
const
char
*
,
char
*
,
int
)
>
;
typedef
void
(
*
ReduceFunctionInC
)(
const
char
*
,
char
*
,
int
);
using
PredictFunction
=
std
::
function
<
void
(
const
std
::
vector
<
std
::
pair
<
int
,
double
>>&
,
double
*
output
)
>
;
using
AllreduceFunction
=
std
::
function
<
void
(
char
*
,
int
,
int
,
char
*
,
const
ReduceFunction
&
)
>
;
using
ReduceScatterFunction
=
std
::
function
<
void
(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunction
&
)
>
;
using
AllgatherFunction
=
std
::
function
<
void
(
char
*
,
int
,
char
*
)
>
;
#define NO_SPECIFIC (-1)
#if (_MSC_VER <= 1800)
...
...
include/LightGBM/network.h
View file @
e951a3d7
...
...
@@ -188,6 +188,12 @@ public:
});
return
global
;
}
/*! \brief set variables and function ptrs */
static
void
SetRank
(
int
rank
)
{
rank_
=
rank
;}
static
void
SetNumMachines
(
int
num_machines
)
{
num_machines_
=
num_machines
;
}
static
void
SetAllReduce
(
AllreduceFunction
AllreduceFuncPtr
)
{
AllreduceFuncPtr_
=
AllreduceFuncPtr
;}
static
void
SetReduceScatter
(
ReduceScatterFunction
ReduceScatterFuncPtr
)
{
ReduceScatterFuncPtr_
=
ReduceScatterFuncPtr
;
}
static
void
SetAllgather
(
AllgatherFunction
AllgatherFuncPtr
)
{
AllgatherFuncPtr_
=
AllgatherFuncPtr
;
}
private:
/*! \brief Number of all machines */
...
...
@@ -208,7 +214,10 @@ private:
static
THREAD_LOCAL
std
::
vector
<
char
>
buffer_
;
/*! \brief Size of buffer_ */
static
THREAD_LOCAL
int
buffer_size_
;
/*! \brief Funcs*/
static
THREAD_LOCAL
AllreduceFunction
AllreduceFuncPtr_
;
static
THREAD_LOCAL
ReduceScatterFunction
ReduceScatterFuncPtr_
;
static
THREAD_LOCAL
AllgatherFunction
AllgatherFuncPtr_
;
};
inline
int
Network
::
rank
()
{
...
...
src/c_api.cpp
View file @
e951a3d7
...
...
@@ -1220,6 +1220,32 @@ int LGBM_NetworkFree() {
API_END
();
}
int
LGBM_GetFuncions
(
void
*
AllreduceFuncPtr
,
void
*
ReduceScatterFuncPtr
,
void
*
AllgatherFuncPtr
,
int
num_machines
,
int
rank
)
{
API_BEGIN
();
if
(
num_machines
>
1
)
{
auto
func1
=
[
AllreduceFuncPtr
](
char
*
arg1
,
int
arg2
,
int
arg3
,
char
*
arg4
,
const
ReduceFunction
&
func
)
{
auto
ptr
=
*
func
.
target
<
ReduceFunctionInC
>
();
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
int
,
char
*
,
const
ReduceFunctionInC
&
))
AllreduceFuncPtr
;
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
ptr
);
};
Network
::
SetAllReduce
(
func1
);
auto
func2
=
[
ReduceScatterFuncPtr
](
char
*
arg1
,
int
arg2
,
const
int
*
arg3
,
const
int
*
arg4
,
char
*
arg5
,
const
ReduceFunction
&
func
)
{
auto
ptr
=
*
func
.
target
<
ReduceFunctionInC
>
();
auto
tmp
=
(
void
(
*
)(
char
*
,
int
,
const
int
*
,
const
int
*
,
char
*
,
const
ReduceFunctionInC
&
))
ReduceScatterFuncPtr
;
return
tmp
(
arg1
,
arg2
,
arg3
,
arg4
,
arg5
,
ptr
);
};
Network
::
SetReduceScatter
(
func2
);
Network
::
SetAllgather
((
void
(
*
)(
char
*
,
int
,
char
*
))
AllgatherFuncPtr
);
Network
::
SetNumMachines
(
num_machines
);
Network
::
SetRank
(
rank
);
}
API_END
();
}
// ---- start of some help functions
std
::
function
<
std
::
vector
<
double
>
(
int
row_idx
)
>
...
...
src/network/network.cpp
View file @
e951a3d7
...
...
@@ -19,6 +19,10 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL
std
::
vector
<
int
>
Network
::
block_len_
;
THREAD_LOCAL
int
Network
::
buffer_size_
;
THREAD_LOCAL
std
::
vector
<
char
>
Network
::
buffer_
;
THREAD_LOCAL
AllreduceFunction
Network
::
AllreduceFuncPtr_
=
NULL
;
THREAD_LOCAL
ReduceScatterFunction
Network
::
ReduceScatterFuncPtr_
=
NULL
;
THREAD_LOCAL
AllgatherFunction
Network
::
AllgatherFuncPtr_
=
NULL
;
void
Network
::
Init
(
NetworkConfig
config
)
{
if
(
config
.
num_machines
>
1
)
{
...
...
@@ -45,6 +49,9 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
if
(
AllreduceFuncPtr_
!=
NULL
)
{
return
AllreduceFuncPtr_
(
input
,
input_size
,
type_size
,
output
,
reducer
);
}
int
count
=
input_size
/
type_size
;
// if small package or small count , do it by all gather.(reduce the communication times.)
if
(
count
<
num_machines_
||
input_size
<
4096
)
{
...
...
@@ -99,6 +106,9 @@ void Network::Allgather(char* input, int send_size, char* output) {
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
if
(
num_machines_
<=
1
)
{
return
;
}
if
(
AllgatherFuncPtr_
!=
NULL
)
{
return
AllgatherFuncPtr_
(
input
,
send_size
,
output
);
}
// assign blocks
block_start_
[
0
]
=
0
;
block_len_
[
0
]
=
send_size
;
...
...
@@ -145,10 +155,13 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
std
::
reverse
<
char
*>
(
output
+
block_start
[
rank_
],
output
+
all_size
);
}
void
Network
::
ReduceScatter
(
char
*
input
,
int
,
const
int
*
block_start
,
const
int
*
block_len
,
char
*
output
,
const
ReduceFunction
&
reducer
)
{
void
Network
::
ReduceScatter
(
char
*
input
,
int
input_size
,
const
int
*
block_start
,
const
int
*
block_len
,
char
*
output
,
const
ReduceFunction
&
reducer
)
{
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
if
(
ReduceScatterFuncPtr_
!=
NULL
)
{
return
ReduceScatterFuncPtr_
(
input
,
input_size
,
block_start
,
block_len
,
output
,
reducer
);
}
if
(
recursive_halving_map_
.
need_pairwise
)
{
for
(
int
i
=
1
;
i
<
num_machines_
;
++
i
)
{
int
out_rank
=
(
rank_
+
i
)
%
num_machines_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment