Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
a178b75b
"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bf1a604a5b5106d6b7f2aa07ea02be12115dcabc"
Commit
a178b75b
authored
Nov 22, 2016
by
Guolin Ke
Browse files
change some c_api interfaces for better compatibility
parent
6837efe7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
56 additions
and
11 deletions
+56
-11
include/LightGBM/boosting.h
include/LightGBM/boosting.h
+1
-0
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+5
-3
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+2
-0
src/boosting/boosting.cpp
src/boosting/boosting.cpp
+1
-1
src/c_api.cpp
src/c_api.cpp
+10
-4
src/io/dataset.cpp
src/io/dataset.cpp
+2
-0
src/io/metadata.cpp
src/io/metadata.cpp
+33
-0
tests/c_api_test/test.py
tests/c_api_test/test.py
+2
-3
No files found.
include/LightGBM/boosting.h
View file @
a178b75b
...
@@ -151,6 +151,7 @@ public:
...
@@ -151,6 +151,7 @@ public:
/*! \brief Disable copy */
/*! \brief Disable copy */
Boosting
(
const
Boosting
&
)
=
delete
;
Boosting
(
const
Boosting
&
)
=
delete
;
static
void
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
);
/*!
/*!
* \brief Create boosting object
* \brief Create boosting object
* \param type Type of boosting
* \param type Type of boosting
...
...
include/LightGBM/c_api.h
View file @
a178b75b
...
@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
...
@@ -165,7 +165,7 @@ DllExport int LGBM_DatasetSaveBinary(DatesetHandle handle,
* \param field_name field name, can be label, weight, group
* \param field_name field name, can be label, weight, group
* \param field_data pointer to vector
* \param field_data pointer to vector
* \param num_element number of element in field_data
* \param num_element number of element in field_data
* \param type float
_
32
:0,
int32
_t:1
* \param type float32
or
int32
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
DllExport
int
LGBM_DatasetSetField
(
DatesetHandle
handle
,
DllExport
int
LGBM_DatasetSetField
(
DatesetHandle
handle
,
...
@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
...
@@ -180,7 +180,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
* \param field_name field name
* \param field_name field name
* \param out_len used to set result length
* \param out_len used to set result length
* \param out_ptr pointer to the result
* \param out_ptr pointer to the result
* \param out_type float
_
32
:0,
int32
_t:1
* \param out_type float32
or
int32
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
DllExport
int
LGBM_DatasetGetField
(
DatesetHandle
handle
,
DllExport
int
LGBM_DatasetGetField
(
DatesetHandle
handle
,
...
@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
...
@@ -216,6 +216,7 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
* \param valid_names names of validation data sets
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2'
* \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of model
* \prama out handle of created Booster
* \prama out handle of created Booster
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
...
@@ -224,6 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
...
@@ -224,6 +225,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const
char
*
valid_names
[],
const
char
*
valid_names
[],
int
n_valid_datas
,
int
n_valid_datas
,
const
char
*
parameters
,
const
char
*
parameters
,
const
char
*
init_model_filename
,
BoosterHandle
*
out
);
BoosterHandle
*
out
);
/*!
/*!
...
@@ -232,7 +234,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
...
@@ -232,7 +234,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
* \param out handle of created Booster
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens
* \return 0 when success, -1 when failure happens
*/
*/
DllExport
int
LGBM_Booster
Load
FromModelfile
(
DllExport
int
LGBM_Booster
Create
FromModelfile
(
const
char
*
filename
,
const
char
*
filename
,
BoosterHandle
*
out
);
BoosterHandle
*
out
);
...
...
include/LightGBM/dataset.h
View file @
a178b75b
...
@@ -83,6 +83,8 @@ public:
...
@@ -83,6 +83,8 @@ public:
void
SetQueryBoundaries
(
const
data_size_t
*
query_boundaries
,
data_size_t
len
);
void
SetQueryBoundaries
(
const
data_size_t
*
query_boundaries
,
data_size_t
len
);
void
SetQueryId
(
const
data_size_t
*
query_id
,
data_size_t
len
);
/*!
/*!
* \brief Set initial scores
* \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score.
* \param init_score Initial scores, this class will manage memory for init_score.
...
...
src/boosting/boosting.cpp
View file @
a178b75b
...
@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
...
@@ -15,7 +15,7 @@ BoostingType GetBoostingTypeFromModelFile(const char* filename) {
return
BoostingType
::
kUnknow
;
return
BoostingType
::
kUnknow
;
}
}
void
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
void
Boosting
::
LoadFileToBoosting
(
Boosting
*
boosting
,
const
char
*
filename
)
{
if
(
boosting
!=
nullptr
)
{
if
(
boosting
!=
nullptr
)
{
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
TextReader
<
size_t
>
model_reader
(
filename
,
true
);
model_reader
.
ReadAllLines
();
model_reader
.
ReadAllLines
();
...
...
src/c_api.cpp
View file @
a178b75b
...
@@ -82,11 +82,12 @@ public:
...
@@ -82,11 +82,12 @@ public:
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
[
i
]));
Common
::
ConstPtrInVectorWrapper
<
Metric
>
(
valid_metrics_
[
i
]));
}
}
}
}
void
LoadModelFromFile
(
const
char
*
filename
)
{
Boosting
::
LoadFileToBoosting
(
boosting_
.
get
(),
filename
);
}
~
Booster
()
{
~
Booster
()
{
}
}
bool
TrainOneIter
()
{
bool
TrainOneIter
()
{
return
boosting_
->
TrainOneIter
(
nullptr
,
nullptr
,
false
);
return
boosting_
->
TrainOneIter
(
nullptr
,
nullptr
,
false
);
}
}
...
@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
...
@@ -414,6 +415,7 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const
char
*
valid_names
[],
const
char
*
valid_names
[],
int
n_valid_datas
,
int
n_valid_datas
,
const
char
*
parameters
,
const
char
*
parameters
,
const
char
*
init_model_filename
,
BoosterHandle
*
out
)
{
BoosterHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
const
Dataset
*
p_train_data
=
reinterpret_cast
<
const
Dataset
*>
(
train_data
);
const
Dataset
*
p_train_data
=
reinterpret_cast
<
const
Dataset
*>
(
train_data
);
...
@@ -423,11 +425,15 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
...
@@ -423,11 +425,15 @@ DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
p_valid_datas
.
emplace_back
(
reinterpret_cast
<
const
Dataset
*>
(
valid_datas
[
i
]));
p_valid_datas
.
emplace_back
(
reinterpret_cast
<
const
Dataset
*>
(
valid_datas
[
i
]));
p_valid_names
.
emplace_back
(
valid_names
[
i
]);
p_valid_names
.
emplace_back
(
valid_names
[
i
]);
}
}
*
out
=
new
Booster
(
p_train_data
,
p_valid_datas
,
p_valid_names
,
parameters
);
auto
ret
=
std
::
unique_ptr
<
Booster
>
(
new
Booster
(
p_train_data
,
p_valid_datas
,
p_valid_names
,
parameters
));
if
(
init_model_filename
!=
nullptr
)
{
ret
->
LoadModelFromFile
(
init_model_filename
);
}
*
out
=
ret
.
release
();
API_END
();
API_END
();
}
}
DllExport
int
LGBM_Booster
Load
FromModelfile
(
DllExport
int
LGBM_Booster
Create
FromModelfile
(
const
char
*
filename
,
const
char
*
filename
,
BoosterHandle
*
out
)
{
BoosterHandle
*
out
)
{
API_BEGIN
();
API_BEGIN
();
...
...
src/io/dataset.cpp
View file @
a178b75b
...
@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
...
@@ -78,6 +78,8 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name
=
Common
::
Trim
(
name
);
name
=
Common
::
Trim
(
name
);
if
(
name
==
std
::
string
(
"query"
)
||
name
==
std
::
string
(
"group"
))
{
if
(
name
==
std
::
string
(
"query"
)
||
name
==
std
::
string
(
"group"
))
{
metadata_
.
SetQueryBoundaries
(
field_data
,
num_element
);
metadata_
.
SetQueryBoundaries
(
field_data
,
num_element
);
}
else
if
(
name
==
std
::
string
(
"query_id"
)
||
name
==
std
::
string
(
"group_id"
))
{
metadata_
.
SetQueryId
(
field_data
,
num_element
);
}
else
{
}
else
{
return
false
;
return
false
;
}
}
...
...
src/io/metadata.cpp
View file @
a178b75b
...
@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
...
@@ -248,6 +248,39 @@ void Metadata::SetQueryBoundaries(const data_size_t* query_boundaries, data_size
LoadQueryWeights
();
LoadQueryWeights
();
}
}
void
Metadata
::
SetQueryId
(
const
data_size_t
*
query_id
,
data_size_t
len
)
{
if
(
num_data_
!=
len
)
{
Log
::
Fatal
(
"len of query id is not same with #data"
);
}
if
(
queries_
.
size
()
>
0
)
{
queries_
.
clear
();
}
queries_
=
std
::
vector
<
data_size_t
>
(
num_data_
);
for
(
data_size_t
i
=
0
;
i
<
num_weights_
;
++
i
)
{
queries_
[
i
]
=
query_id
[
i
];
}
// need convert query_id to boundaries
std
::
vector
<
data_size_t
>
tmp_buffer
;
data_size_t
last_qid
=
-
1
;
data_size_t
cur_cnt
=
0
;
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
++
i
)
{
if
(
last_qid
!=
queries_
[
i
])
{
if
(
cur_cnt
>
0
)
{
tmp_buffer
.
push_back
(
cur_cnt
);
}
cur_cnt
=
0
;
last_qid
=
queries_
[
i
];
}
++
cur_cnt
;
}
tmp_buffer
.
push_back
(
cur_cnt
);
query_boundaries_
=
std
::
vector
<
data_size_t
>
(
tmp_buffer
.
size
()
+
1
);
num_queries_
=
static_cast
<
data_size_t
>
(
tmp_buffer
.
size
());
query_boundaries_
[
0
]
=
0
;
for
(
size_t
i
=
0
;
i
<
tmp_buffer
.
size
();
++
i
)
{
query_boundaries_
[
i
+
1
]
=
query_boundaries_
[
i
]
+
tmp_buffer
[
i
];
}
queries_
.
clear
();
LoadQueryWeights
();
}
void
Metadata
::
LoadWeights
()
{
void
Metadata
::
LoadWeights
()
{
num_weights_
=
0
;
num_weights_
=
0
;
...
...
tests/c_api_test/test.py
View file @
a178b75b
...
@@ -178,7 +178,7 @@ def test_booster():
...
@@ -178,7 +178,7 @@ def test_booster():
name
=
[
c_str
(
'test'
)]
name
=
[
c_str
(
'test'
)]
booster
=
ctypes
.
c_void_p
()
booster
=
ctypes
.
c_void_p
()
LIB
.
LGBM_BoosterCreate
(
train
,
c_array
(
ctypes
.
c_void_p
,
test
),
c_array
(
ctypes
.
c_char_p
,
name
),
LIB
.
LGBM_BoosterCreate
(
train
,
c_array
(
ctypes
.
c_void_p
,
test
),
c_array
(
ctypes
.
c_char_p
,
name
),
len
(
test
),
c_str
(
"app=binary metric=auc num_leaves=31 verbose=0"
),
ctypes
.
byref
(
booster
))
len
(
test
),
c_str
(
"app=binary metric=auc num_leaves=31 verbose=0"
),
None
,
ctypes
.
byref
(
booster
))
is_finished
=
ctypes
.
c_int
(
0
)
is_finished
=
ctypes
.
c_int
(
0
)
for
i
in
range
(
100
):
for
i
in
range
(
100
):
LIB
.
LGBM_BoosterUpdateOneIter
(
booster
,
ctypes
.
byref
(
is_finished
))
LIB
.
LGBM_BoosterUpdateOneIter
(
booster
,
ctypes
.
byref
(
is_finished
))
...
@@ -191,7 +191,7 @@ def test_booster():
...
@@ -191,7 +191,7 @@ def test_booster():
test_free_dataset
(
train
)
test_free_dataset
(
train
)
test_free_dataset
(
test
[
0
])
test_free_dataset
(
test
[
0
])
booster2
=
ctypes
.
c_void_p
()
booster2
=
ctypes
.
c_void_p
()
LIB
.
LGBM_Booster
Load
FromModelfile
(
c_str
(
'model.txt'
),
ctypes
.
byref
(
booster2
))
LIB
.
LGBM_Booster
Create
FromModelfile
(
c_str
(
'model.txt'
),
ctypes
.
byref
(
booster2
))
data
=
[]
data
=
[]
inp
=
open
(
'../../examples/binary_classification/binary.test'
,
'r'
)
inp
=
open
(
'../../examples/binary_classification/binary.test'
,
'r'
)
for
line
in
inp
.
readlines
():
for
line
in
inp
.
readlines
():
...
@@ -214,4 +214,3 @@ def test_booster():
...
@@ -214,4 +214,3 @@ def test_booster():
test_dataset
()
test_dataset
()
test_booster
()
test_booster
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment