Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
162509ae
Commit
162509ae
authored
Dec 16, 2017
by
Guolin Ke
Browse files
fix network init with extern functions.
parent
b65f6e65
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
22 deletions
+33
-22
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+3
-4
include/LightGBM/network.h
include/LightGBM/network.h
+4
-5
src/c_api.cpp
src/c_api.cpp
+4
-8
src/network/network.cpp
src/network/network.cpp
+22
-5
No files found.
include/LightGBM/c_api.h
View file @
162509ae
...
...
@@ -757,10 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT
int
LGBM_NetworkFree
();
LIGHTGBM_C_EXPORT
int
LGBM_NetworkInitWithFunctions
(
void
*
reduce_scatter_fun_ptr
,
void
*
allgather_fun_ptr
,
int
num_machines
,
int
rank
);
LIGHTGBM_C_EXPORT
int
LGBM_NetworkInitWithFunctions
(
int
num_machines
,
int
rank
,
void
*
reduce_scatter_ext_fun
,
void
*
allgather_ext_fun
);
// exception handle and error msg
static
char
*
LastErrorMsg
()
{
static
THREAD_LOCAL
char
err_msg
[
512
]
=
"Everything is fine"
;
return
err_msg
;
}
...
...
include/LightGBM/network.h
View file @
162509ae
...
...
@@ -73,6 +73,10 @@ public:
* \param config Config of network setting
*/
static
void
Init
(
NetworkConfig
config
);
/*!
* \brief Initialize the network state from externally supplied collective functions
* \param num_machines Total number of machines; nothing is changed unless this is greater than 1
* \param rank Rank of the local machine
* \param reduce_scatter_ext_fun External function that ReduceScatter calls when it is not null
* \param allgather_ext_fun External function that Allgather calls when it is not null
*/
static
void
Init
(
int
num_machines
,
int
rank
,
ReduceScatterFunction
reduce_scatter_ext_fun
,
AllgatherFunction
allgather_ext_fun
);
/*! \brief Free this static class */
static
void
Dispose
();
/*! \brief Get rank of this machine */
...
...
@@ -188,11 +192,6 @@ public:
});
return
global
;
}
/*! \brief set variables and function ptrs */
/*! \brief Overwrite the stored rank with the given value */
static void SetRank(int rank) {
  rank_ = rank;
}
/*! \brief Overwrite the stored number of machines with the given value */
static void SetNumMachines(int num_machines) {
  num_machines_ = num_machines;
}
/*! \brief Store the external reduce-scatter function pointer */
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) {
  reduce_scatter_ext_fun_ = reduce_scatter_ext_fun;
}
/*! \brief Store the external allgather function pointer */
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) {
  allgather_ext_fun_ = allgather_ext_fun;
}
private:
/*! \brief Number of all machines */
...
...
src/c_api.cpp
View file @
162509ae
...
...
@@ -1220,16 +1220,12 @@ int LGBM_NetworkFree() {
API_END
();
}
int
LGBM_NetworkInitWithFunctions
(
void
*
reduce_scatter_fun_ptr
,
void
*
allgather_fun_ptr
,
int
num_machines
,
int
rank
)
{
int
LGBM_NetworkInitWithFunctions
(
int
num_machines
,
int
rank
,
void
*
reduce_scatter_ext_fun
,
void
*
allgather_ext_fun
)
{
API_BEGIN
();
if
(
num_machines
>
1
)
{
Network
::
SetReduceScatterFunction
((
ReduceScatterFunction
)
reduce_scatter_fun_ptr
);
Network
::
SetAllgatherFunction
((
AllgatherFunction
)
allgather_fun_ptr
);
Network
::
SetNumMachines
(
num_machines
);
Network
::
SetRank
(
rank
);
Network
::
Init
(
num_machines
,
rank
,
(
ReduceScatterFunction
)
reduce_scatter_ext_fun
,
(
AllgatherFunction
)
allgather_ext_fun
);
}
API_END
();
}
...
...
src/network/network.cpp
View file @
162509ae
...
...
@@ -17,10 +17,10 @@ THREAD_LOCAL BruckMap Network::bruck_map_;
THREAD_LOCAL
RecursiveHalvingMap
Network
::
recursive_halving_map_
;
THREAD_LOCAL
std
::
vector
<
comm_size_t
>
Network
::
block_start_
;
THREAD_LOCAL
std
::
vector
<
comm_size_t
>
Network
::
block_len_
;
THREAD_LOCAL
comm_size_t
Network
::
buffer_size_
;
THREAD_LOCAL
comm_size_t
Network
::
buffer_size_
=
0
;
THREAD_LOCAL
std
::
vector
<
char
>
Network
::
buffer_
;
THREAD_LOCAL
ReduceScatterFunction
Network
::
reduce_scatter_ext_fun_
=
NULL
;
THREAD_LOCAL
AllgatherFunction
Network
::
allgather_ext_fun_
=
NULL
;
THREAD_LOCAL
ReduceScatterFunction
Network
::
reduce_scatter_ext_fun_
=
nullptr
;
THREAD_LOCAL
AllgatherFunction
Network
::
allgather_ext_fun_
=
nullptr
;
void
Network
::
Init
(
NetworkConfig
config
)
{
...
...
@@ -38,10 +38,27 @@ void Network::Init(NetworkConfig config) {
}
}
/*!
 * \brief Initialize the network state using externally supplied collective functions.
 *
 * When num_machines is 1 or less the call changes nothing. Otherwise it records
 * the rank and machine count, sizes the per-machine block tables to
 * num_machines entries, sets the buffer to 1024 * 1024 bytes, stores both
 * external function pointers and logs the resulting rank and machine count.
 *
 * \param num_machines Total number of machines
 * \param rank Rank of the local machine
 * \param reduce_scatter_ext_fun External reduce-scatter function to store
 * \param allgather_ext_fun External allgather function to store
 */
void Network::Init(int num_machines, int rank,
                   ReduceScatterFunction reduce_scatter_ext_fun,
                   AllgatherFunction allgather_ext_fun) {
  // Nothing to set up unless more than one machine takes part.
  if (num_machines <= 1) {
    return;
  }

  rank_ = rank;
  num_machines_ = num_machines;

  // One start/length slot per machine.
  block_start_ = std::vector<comm_size_t>(num_machines_);
  block_len_ = std::vector<comm_size_t>(num_machines_);

  buffer_size_ = 1024 * 1024;
  buffer_.resize(buffer_size_);

  reduce_scatter_ext_fun_ = reduce_scatter_ext_fun;
  allgather_ext_fun_ = allgather_ext_fun;

  Log::Info("Local rank: %d, total number of machines: %d", rank_, num_machines_);
}
/*!
 * \brief Reset the static network state.
 *
 * Sets the machine count to 1 and the rank to 0, replaces the linkers object
 * with a newly constructed Linkers, and clears both external function pointers.
 */
void Network::Dispose() {
  // Single-machine values.
  num_machines_ = 1;
  rank_ = 0;

  // Swap in a fresh Linkers instance; reset() releases the previous one.
  linkers_.reset(new Linkers());

  // Forget any externally supplied collective functions.
  reduce_scatter_ext_fun_ = nullptr;
  allgather_ext_fun_ = nullptr;
}
void
Network
::
Allreduce
(
char
*
input
,
comm_size_t
input_size
,
int
type_size
,
char
*
output
,
const
ReduceFunction
&
reducer
)
{
...
...
@@ -117,7 +134,7 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
if
(
allgather_ext_fun_
!=
NULL
)
{
if
(
allgather_ext_fun_
!=
nullptr
)
{
return
allgather_ext_fun_
(
input
,
block_len
[
rank_
],
block_start
,
block_len
,
num_machines_
,
output
,
all_size
);
}
comm_size_t
write_pos
=
0
;
...
...
@@ -155,7 +172,7 @@ void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size,
if
(
num_machines_
<=
1
)
{
Log
::
Fatal
(
"Please initilize the network interface first"
);
}
if
(
reduce_scatter_ext_fun_
!=
NULL
)
{
if
(
reduce_scatter_ext_fun_
!=
nullptr
)
{
return
reduce_scatter_ext_fun_
(
input
,
input_size
,
type_size
,
block_start
,
block_len
,
num_machines_
,
output
,
output_size
,
reducer
);
}
if
(
recursive_halving_map_
.
need_pairwise
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment