Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
1466f907
Commit
1466f907
authored
Dec 05, 2016
by
Guolin Ke
Committed by
GitHub
Dec 05, 2016
Browse files
Categorical feature support (#108)
Categorical feature support (#108)
parent
531352f6
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
697 additions
and
291 deletions
+697
-291
include/LightGBM/bin.h
include/LightGBM/bin.h
+60
-17
include/LightGBM/c_api.h
include/LightGBM/c_api.h
+12
-0
include/LightGBM/config.h
include/LightGBM/config.h
+16
-4
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+9
-1
include/LightGBM/dataset_loader.h
include/LightGBM/dataset_loader.h
+9
-6
include/LightGBM/feature.h
include/LightGBM/feature.h
+5
-3
include/LightGBM/tree.h
include/LightGBM/tree.h
+48
-6
include/LightGBM/utils/common.h
include/LightGBM/utils/common.h
+54
-48
python-package/lightgbm/basic.py
python-package/lightgbm/basic.py
+37
-1
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+34
-8
python-package/lightgbm/sklearn.py
python-package/lightgbm/sklearn.py
+14
-3
src/application/application.cpp
src/application/application.cpp
+4
-4
src/application/predictor.hpp
src/application/predictor.hpp
+1
-1
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+36
-6
src/boosting/gbdt.h
src/boosting/gbdt.h
+2
-0
src/c_api.cpp
src/c_api.cpp
+19
-6
src/io/bin.cpp
src/io/bin.cpp
+178
-79
src/io/config.cpp
src/io/config.cpp
+4
-3
src/io/dataset_loader.cpp
src/io/dataset_loader.cpp
+130
-93
src/io/dense_bin.hpp
src/io/dense_bin.hpp
+25
-2
No files found.
include/LightGBM/bin.h
View file @
1466f907
...
@@ -5,9 +5,14 @@
...
@@ -5,9 +5,14 @@
#include <vector>
#include <vector>
#include <functional>
#include <functional>
#include <unordered_map>
namespace
LightGBM
{
namespace
LightGBM
{
/*!
* \brief Kind of bins held by a BinMapper.
*        NumericalBin: bins are described by upper bounds on the feature value.
*        CategoricalBin: each bin corresponds to one integer category value.
*/
enum
BinType
{
NumericalBin
,
CategoricalBin
};
/*! \brief Store data for one histogram bin */
/*! \brief Store data for one histogram bin */
struct
HistogramBinEntry
{
struct
HistogramBinEntry
{
...
@@ -55,9 +60,20 @@ public:
...
@@ -55,9 +60,20 @@ public:
if
(
num_bin_
!=
other
.
num_bin_
)
{
if
(
num_bin_
!=
other
.
num_bin_
)
{
return
false
;
return
false
;
}
}
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
if
(
bin_type_
!=
other
.
bin_type_
)
{
if
(
bin_upper_bound_
[
i
]
!=
other
.
bin_upper_bound_
[
i
])
{
return
false
;
return
false
;
}
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
if
(
bin_upper_bound_
[
i
]
!=
other
.
bin_upper_bound_
[
i
])
{
return
false
;
}
}
}
else
{
for
(
int
i
=
0
;
i
<
num_bin_
;
i
++
)
{
if
(
bin_2_categorical_
[
i
]
!=
other
.
bin_2_categorical_
[
i
])
{
return
false
;
}
}
}
}
}
return
true
;
return
true
;
...
@@ -80,7 +96,11 @@ public:
...
@@ -80,7 +96,11 @@ public:
* \return Feature value of this bin
* \return Feature value of this bin
*/
*/
inline
double
BinToValue
(
unsigned
int
bin
)
const
{
inline
double
BinToValue
(
unsigned
int
bin
)
const
{
return
bin_upper_bound_
[
bin
];
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
return
bin_upper_bound_
[
bin
];
}
else
{
return
bin_2_categorical_
[
bin
];
}
}
}
/*!
/*!
* \brief Get sizes in byte of this object
* \brief Get sizes in byte of this object
...
@@ -97,8 +117,9 @@ public:
...
@@ -97,8 +117,9 @@ public:
* \brief Construct feature value to bin mapper according feature values
* \brief Construct feature value to bin mapper according feature values
* \param values (Sampled) values of this feature
* \param values (Sampled) values of this feature
* \param max_bin The maximal number of bin
* \param max_bin The maximal number of bin
* \param bin_type Type of this bin
*/
*/
void
FindBin
(
std
::
vector
<
double
>*
values
,
size_t
total_sample_cnt
,
int
max_bin
);
void
FindBin
(
std
::
vector
<
double
>*
values
,
size_t
total_sample_cnt
,
int
max_bin
,
BinType
bin_type
);
/*!
/*!
* \brief Use specific number of bin to calculate the size of this class
* \brief Use specific number of bin to calculate the size of this class
...
@@ -119,6 +140,7 @@ public:
...
@@ -119,6 +140,7 @@ public:
*/
*/
void
CopyFrom
(
const
char
*
buffer
);
void
CopyFrom
(
const
char
*
buffer
);
/*! \brief Get the type of this bin mapper (numerical or categorical) */
inline
BinType
bin_type
()
const
{
return
bin_type_
;
}
private:
private:
/*! \brief Number of bins */
/*! \brief Number of bins */
int
num_bin_
;
int
num_bin_
;
...
@@ -128,6 +150,12 @@ private:
...
@@ -128,6 +150,12 @@ private:
bool
is_trival_
;
bool
is_trival_
;
/*! \brief Sparse rate of this bins( num_bin0/num_data ) */
/*! \brief Sparse rate of this bins( num_bin0/num_data ) */
double
sparse_rate_
;
double
sparse_rate_
;
/*! \brief Type of this bin */
BinType
bin_type_
;
/*! \brief Mapper from categorical to bin */
std
::
unordered_map
<
int
,
unsigned
int
>
categorical_2_bin_
;
/*! \brief Mapper from bin to categorical */
std
::
vector
<
int
>
bin_2_categorical_
;
};
};
/*!
/*!
...
@@ -257,7 +285,8 @@ public:
...
@@ -257,7 +285,8 @@ public:
* \return The number of less than or equal data.
* \return The number of less than or equal data.
*/
*/
virtual
data_size_t
Split
(
virtual
data_size_t
Split
(
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
data_size_t
*
lte_indices
,
data_size_t
*
gt_indices
)
const
=
0
;
data_size_t
*
lte_indices
,
data_size_t
*
gt_indices
)
const
=
0
;
/*!
/*!
...
@@ -280,44 +309,58 @@ public:
...
@@ -280,44 +309,58 @@ public:
* \param is_enable_sparse True if enable sparse feature
* \param is_enable_sparse True if enable sparse feature
* \param is_sparse Will set to true if this bin is sparse
* \param is_sparse Will set to true if this bin is sparse
* \param default_bin Default bin for zeros value
* \param default_bin Default bin for zeros value
* \param bin_type type of bin
* \return The bin data object
* \return The bin data object
*/
*/
static
Bin
*
CreateBin
(
data_size_t
num_data
,
int
num_bin
,
static
Bin
*
CreateBin
(
data_size_t
num_data
,
int
num_bin
,
double
sparse_rate
,
bool
is_enable_sparse
,
bool
*
is_sparse
,
int
default_bin
);
double
sparse_rate
,
bool
is_enable_sparse
,
bool
*
is_sparse
,
int
default_bin
,
BinType
bin_type
);
/*!
/*!
* \brief Create object for bin data of one feature, used for dense feature
* \brief Create object for bin data of one feature, used for dense feature
* \param num_data Total number of data
* \param num_data Total number of data
* \param num_bin Number of bin
* \param num_bin Number of bin
* \param default_bin Default bin for zeros value
* \param default_bin Default bin for zeros value
* \param bin_type type of bin
* \return The bin data object
* \return The bin data object
*/
*/
static
Bin
*
CreateDenseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
);
static
Bin
*
CreateDenseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
,
BinType
bin_type
);
/*!
/*!
* \brief Create object for bin data of one feature, used for sparse feature
* \brief Create object for bin data of one feature, used for sparse feature
* \param num_data Total number of data
* \param num_data Total number of data
* \param num_bin Number of bin
* \param num_bin Number of bin
* \param default_bin Default bin for zeros value
* \param default_bin Default bin for zeros value
* \param bin_type type of bin
* \return The bin data object
* \return The bin data object
*/
*/
static
Bin
*
CreateSparseBin
(
data_size_t
num_data
,
static
Bin
*
CreateSparseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
);
int
num_bin
,
int
default_bin
,
BinType
bin_type
);
};
};
inline
unsigned
int
BinMapper
::
ValueToBin
(
double
value
)
const
{
inline
unsigned
int
BinMapper
::
ValueToBin
(
double
value
)
const
{
// binary search to find bin
// binary search to find bin
int
l
=
0
;
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
int
r
=
num_bin_
-
1
;
int
l
=
0
;
while
(
l
<
r
)
{
int
r
=
num_bin_
-
1
;
int
m
=
(
r
+
l
-
1
)
/
2
;
while
(
l
<
r
)
{
if
(
value
<=
bin_upper_bound_
[
m
])
{
int
m
=
(
r
+
l
-
1
)
/
2
;
r
=
m
;
if
(
value
<=
bin_upper_bound_
[
m
])
{
r
=
m
;
}
else
{
l
=
m
+
1
;
}
}
return
l
;
}
else
{
int
int_value
=
static_cast
<
int
>
(
value
);
if
(
categorical_2_bin_
.
count
(
int_value
))
{
return
categorical_2_bin_
.
at
(
int_value
);
}
else
{
}
else
{
l
=
m
+
1
;
return
num_bin_
-
1
;
}
}
}
}
return
l
;
}
}
}
// namespace LightGBM
}
// namespace LightGBM
...
...
include/LightGBM/c_api.h
View file @
1466f907
...
@@ -153,6 +153,18 @@ DllExport int LGBM_DatasetGetSubset(
...
@@ -153,6 +153,18 @@ DllExport int LGBM_DatasetGetSubset(
const
char
*
parameters
,
const
char
*
parameters
,
DatesetHandle
*
out
);
DatesetHandle
*
out
);
/*!
* \brief save feature names to Dataset
* \param handle handle
* \param feature_names feature names
* \param num_feature_names number of feature names
* \return 0 when succeed, -1 when failure happens
*/
DllExport
int
LGBM_DatasetSetFeatureNames
(
DatesetHandle
handle
,
const
char
**
feature_names
,
int64_t
num_feature_names
);
/*!
/*!
* \brief free space for dataset
* \brief free space for dataset
* \return 0 when succeed, -1 when failure happens
* \return 0 when succeed, -1 when failure happens
...
...
include/LightGBM/config.h
View file @
1466f907
...
@@ -114,14 +114,21 @@ public:
...
@@ -114,14 +114,21 @@ public:
* And add an prefix "name:" while using column name */
* And add an prefix "name:" while using column name */
std
::
string
label_column
=
""
;
std
::
string
label_column
=
""
;
/*! \brief Index or column name of weight, < 0 means not used
/*! \brief Index or column name of weight, < 0 means not used
* And add an prefix "name:" while using column name */
* And add a prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std
::
string
weight_column
=
""
;
std
::
string
weight_column
=
""
;
/*! \brief Index or column name of group, < 0 means not used */
/*! \brief Index or column name of group/query id, < 0 means not used
* And add a prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std
::
string
group_column
=
""
;
std
::
string
group_column
=
""
;
/*! \brief ignored features, separate by ','
/*! \brief ignored features, separate by ','
* e.g. name:column_name1,column_name2 */
* And add a prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std
::
string
ignore_column
=
""
;
std
::
string
ignore_column
=
""
;
/*! \brief Specify categorical columns. Note: only integer-typed categorical values are supported
* And add a prefix "name:" while using column name
* Note: when using Index, it doesn't count the label index */
std
::
string
categorical_column
=
""
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
void
Set
(
const
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
params
)
override
;
};
};
...
@@ -368,8 +375,13 @@ struct ParameterAlias {
...
@@ -368,8 +375,13 @@ struct ParameterAlias {
{
"query_column"
,
"group_column"
},
{
"query_column"
,
"group_column"
},
{
"ignore_feature"
,
"ignore_column"
},
{
"ignore_feature"
,
"ignore_column"
},
{
"blacklist"
,
"ignore_column"
},
{
"blacklist"
,
"ignore_column"
},
{
"categorical_feature"
,
"categorical_column"
},
{
"cat_column"
,
"categorical_column"
},
{
"cat_feature"
,
"categorical_column"
},
{
"predict_raw_score"
,
"is_predict_raw_score"
},
{
"predict_raw_score"
,
"is_predict_raw_score"
},
{
"predict_leaf_index"
,
"is_predict_leaf_index"
},
{
"predict_leaf_index"
,
"is_predict_leaf_index"
},
{
"raw_score"
,
"is_predict_raw_score"
},
{
"leaf_index"
,
"is_predict_leaf_index"
},
{
"min_split_gain"
,
"min_gain_to_split"
},
{
"min_split_gain"
,
"min_gain_to_split"
},
{
"reg_alpha"
,
"lambda_l1"
},
{
"reg_alpha"
,
"lambda_l1"
},
{
"reg_lambda"
,
"lambda_l2"
},
{
"reg_lambda"
,
"lambda_l2"
},
...
...
include/LightGBM/dataset.h
View file @
1466f907
...
@@ -378,7 +378,15 @@ public:
...
@@ -378,7 +378,15 @@ public:
inline
int
label_idx
()
const
{
return
label_idx_
;
}
inline
int
label_idx
()
const
{
return
label_idx_
;
}
/*! \brief Get names of current data set */
/*! \brief Get names of current data set */
inline
std
::
vector
<
std
::
string
>
feature_names
()
const
{
return
feature_names_
;
}
inline
const
std
::
vector
<
std
::
string
>&
feature_names
()
const
{
return
feature_names_
;
}
inline
void
set_feature_names
(
const
std
::
vector
<
std
::
string
>&
feature_names
)
{
if
(
feature_names
.
size
()
!=
static_cast
<
size_t
>
(
num_total_features_
))
{
Log
::
Warning
(
"size of feature_names error, should equal with total number of features"
);
return
;
}
feature_names_
=
std
::
vector
<
std
::
string
>
(
feature_names
);
}
/*! \brief Get Number of data */
/*! \brief Get Number of data */
inline
data_size_t
num_data
()
const
{
return
num_data_
;
}
inline
data_size_t
num_data
()
const
{
return
num_data_
;
}
...
...
include/LightGBM/dataset_loader.h
View file @
1466f907
...
@@ -8,12 +8,10 @@ namespace LightGBM {
...
@@ -8,12 +8,10 @@ namespace LightGBM {
class
DatasetLoader
{
class
DatasetLoader
{
public:
public:
DatasetLoader
(
const
IOConfig
&
io_config
,
const
PredictFunction
&
predict_fun
);
DatasetLoader
(
const
IOConfig
&
io_config
,
const
PredictFunction
&
predict_fun
,
const
char
*
filename
);
~
DatasetLoader
();
~
DatasetLoader
();
void
SetHeader
(
const
char
*
filename
);
Dataset
*
LoadFromFile
(
const
char
*
filename
,
int
rank
,
int
num_machines
);
Dataset
*
LoadFromFile
(
const
char
*
filename
,
int
rank
,
int
num_machines
);
Dataset
*
LoadFromFile
(
const
char
*
filename
)
{
Dataset
*
LoadFromFile
(
const
char
*
filename
)
{
...
@@ -32,6 +30,9 @@ public:
...
@@ -32,6 +30,9 @@ public:
DatasetLoader
(
const
DatasetLoader
&
)
=
delete
;
DatasetLoader
(
const
DatasetLoader
&
)
=
delete
;
private:
private:
void
SetHeader
(
const
char
*
filename
);
void
CheckDataset
(
const
Dataset
*
dataset
);
void
CheckDataset
(
const
Dataset
*
dataset
);
std
::
vector
<
std
::
string
>
LoadTextDataToMemory
(
const
char
*
filename
,
const
Metadata
&
metadata
,
int
rank
,
int
num_machines
,
int
*
num_global_data
,
std
::
vector
<
data_size_t
>*
used_data_indices
);
std
::
vector
<
std
::
string
>
LoadTextDataToMemory
(
const
char
*
filename
,
const
Metadata
&
metadata
,
int
rank
,
int
num_machines
,
int
*
num_global_data
,
std
::
vector
<
data_size_t
>*
used_data_indices
);
...
@@ -57,15 +58,17 @@ private:
...
@@ -57,15 +58,17 @@ private:
/*! \brief prediction function for initial model */
/*! \brief prediction function for initial model */
const
PredictFunction
&
predict_fun_
;
const
PredictFunction
&
predict_fun_
;
/*! \brief index of label column */
/*! \brief index of label column */
int
label_idx_
=
0
;
int
label_idx_
;
/*! \brief index of weight column */
/*! \brief index of weight column */
int
weight_idx_
=
NO_SPECIFIC
;
int
weight_idx_
;
/*! \brief index of group column */
/*! \brief index of group column */
int
group_idx_
=
NO_SPECIFIC
;
int
group_idx_
;
/*! \brief Mapper from real feature index to used index*/
/*! \brief Mapper from real feature index to used index*/
std
::
unordered_set
<
int
>
ignore_features_
;
std
::
unordered_set
<
int
>
ignore_features_
;
/*! \brief store feature names */
/*! \brief store feature names */
std
::
vector
<
std
::
string
>
feature_names_
;
std
::
vector
<
std
::
string
>
feature_names_
;
/*! \brief Mapper from real feature index to used index*/
std
::
unordered_set
<
int
>
categorical_features_
;
};
};
...
...
include/LightGBM/feature.h
View file @
1466f907
...
@@ -27,7 +27,7 @@ public:
...
@@ -27,7 +27,7 @@ public:
:
bin_mapper_
(
bin_mapper
)
{
:
bin_mapper_
(
bin_mapper
)
{
feature_index_
=
feature_idx
;
feature_index_
=
feature_idx
;
bin_data_
.
reset
(
Bin
::
CreateBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_data_
.
reset
(
Bin
::
CreateBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_mapper_
->
sparse_rate
(),
is_enable_sparse
,
&
is_sparse_
,
bin_mapper_
->
ValueToBin
(
0
)));
bin_mapper_
->
sparse_rate
(),
is_enable_sparse
,
&
is_sparse_
,
bin_mapper_
->
ValueToBin
(
0
)
,
bin_mapper_
->
bin_type
()
));
}
}
/*!
/*!
* \brief Constructor from memory
* \brief Constructor from memory
...
@@ -52,9 +52,9 @@ public:
...
@@ -52,9 +52,9 @@ public:
num_data
=
static_cast
<
data_size_t
>
(
local_used_indices
.
size
());
num_data
=
static_cast
<
data_size_t
>
(
local_used_indices
.
size
());
}
}
if
(
is_sparse_
)
{
if
(
is_sparse_
)
{
bin_data_
.
reset
(
Bin
::
CreateSparseBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_mapper_
->
ValueToBin
(
0
)));
bin_data_
.
reset
(
Bin
::
CreateSparseBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_mapper_
->
ValueToBin
(
0
)
,
bin_mapper_
->
bin_type
()
));
}
else
{
}
else
{
bin_data_
.
reset
(
Bin
::
CreateDenseBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_mapper_
->
ValueToBin
(
0
)));
bin_data_
.
reset
(
Bin
::
CreateDenseBin
(
num_data
,
bin_mapper_
->
num_bin
(),
bin_mapper_
->
ValueToBin
(
0
)
,
bin_mapper_
->
bin_type
()
));
}
}
// get bin data
// get bin data
bin_data_
->
LoadFromMemory
(
memory_ptr
,
local_used_indices
);
bin_data_
->
LoadFromMemory
(
memory_ptr
,
local_used_indices
);
...
@@ -90,6 +90,8 @@ public:
...
@@ -90,6 +90,8 @@ public:
inline
const
BinMapper
*
bin_mapper
()
const
{
return
bin_mapper_
.
get
();
}
inline
const
BinMapper
*
bin_mapper
()
const
{
return
bin_mapper_
.
get
();
}
/*! \brief Number of bin of this feature */
/*! \brief Number of bin of this feature */
inline
int
num_bin
()
const
{
return
bin_mapper_
->
num_bin
();
}
inline
int
num_bin
()
const
{
return
bin_mapper_
->
num_bin
();
}
/*! \brief Bin type of this feature, as reported by its bin mapper */
inline
BinType
bin_type
()
const
{
return
bin_mapper_
->
bin_type
();
}
/*! \brief Get bin data of this feature */
/*! \brief Get bin data of this feature */
inline
const
Bin
*
bin_data
()
const
{
return
bin_data_
.
get
();
}
inline
const
Bin
*
bin_data
()
const
{
return
bin_data_
.
get
();
}
/*!
/*!
...
...
include/LightGBM/tree.h
View file @
1466f907
...
@@ -35,17 +35,20 @@ public:
...
@@ -35,17 +35,20 @@ public:
* \brief Performing a split on tree leaves.
* \brief Performing a split on tree leaves.
* \param leaf Index of leaf to be split
* \param leaf Index of leaf to be split
* \param feature Index of feature; the converted index after removing useless features
* \param feature Index of feature; the converted index after removing useless features
* \param bin_type type of this feature, numerical or categorical
* \param threshold Threshold(bin) of split
* \param threshold Threshold(bin) of split
* \param real_feature Index of feature, the original index on data
* \param real_feature Index of feature, the original index on data
* \param threshold_double Threshold on feature value
* \param threshold_double Threshold on feature value
* \param left_value Model Left child output
* \param left_value Model Left child output
* \param right_value Model Right child output
* \param right_value Model Right child output
* \param left_cnt Count of left child
* \param right_cnt Count of right child
* \param gain Split gain
* \param gain Split gain
* \return The index of new leaf.
* \return The index of new leaf.
*/
*/
int
Split
(
int
leaf
,
int
feature
,
unsigned
int
threshold
,
int
real_feature
,
int
Split
(
int
leaf
,
int
feature
,
BinType
bin_type
,
unsigned
int
threshold
,
int
real_feature
,
double
threshold_double
,
double
left_value
,
double
threshold_double
,
double
left_value
,
double
right_value
,
double
gain
);
double
right_value
,
data_size_t
left_cnt
,
data_size_t
right_cnt
,
double
gain
);
/*! \brief Get the output of one leave */
/*! \brief Get the output of one leave */
inline
double
LeafOutput
(
int
leaf
)
const
{
return
leaf_value_
[
leaf
];
}
inline
double
LeafOutput
(
int
leaf
)
const
{
return
leaf_value_
[
leaf
];
}
...
@@ -104,6 +107,35 @@ public:
...
@@ -104,6 +107,35 @@ public:
/*! \brief Serialize this object to json*/
/*! \brief Serialize this object to json*/
std
::
string
ToJSON
();
std
::
string
ToJSON
();
/*!
* \brief Decision rule for a categorical ('is') split.
*        Both arguments are converted to int before being compared.
* \param fval Feature value (or bin) of the record
* \param threshold Split threshold
* \return true if the two values are equal after conversion to int
*/
template<typename T>
static bool CategoricalDecision(T fval, T threshold) {
  const int lhs = static_cast<int>(fval);
  const int rhs = static_cast<int>(threshold);
  return lhs == rhs;
}
/*!
* \brief Decision rule for a numerical ('<=') split.
* \param fval Feature value (or bin) of the record
* \param threshold Split threshold
* \return true if fval is no greater than threshold
*/
template<typename T>
static bool NumericalDecision(T fval, T threshold) {
  return fval <= threshold;
}
/*!
* \brief Get the textual name of a decision type.
* \param type Decision type code; 0 maps to "no_greater", any other value to "is"
* \return Name of the decision type
*/
static const char* GetDecisionTypeName(int8_t type) {
  return (type == 0) ? "no_greater" : "is";
}
static
std
::
vector
<
std
::
function
<
bool
(
unsigned
int
,
unsigned
int
)
>>
inner_decision_funs
;
static
std
::
vector
<
std
::
function
<
bool
(
double
,
double
)
>>
decision_funs
;
private:
private:
/*!
/*!
* \brief Find leaf index of which record belongs by data
* \brief Find leaf index of which record belongs by data
...
@@ -141,15 +173,21 @@ private:
...
@@ -141,15 +173,21 @@ private:
std
::
vector
<
unsigned
int
>
threshold_in_bin_
;
std
::
vector
<
unsigned
int
>
threshold_in_bin_
;
/*! \brief A non-leaf node's split threshold in feature value */
/*! \brief A non-leaf node's split threshold in feature value */
std
::
vector
<
double
>
threshold_
;
std
::
vector
<
double
>
threshold_
;
/*! \brief Decision type, 0 for '<='(numerical feature), 1 for 'is'(categorical feature) */
std
::
vector
<
int8_t
>
decision_type_
;
/*! \brief A non-leaf node's split gain */
/*! \brief A non-leaf node's split gain */
std
::
vector
<
double
>
split_gain_
;
std
::
vector
<
double
>
split_gain_
;
/*! \brief Output of internal nodes(save internal output for per inference feature importance calc) */
std
::
vector
<
double
>
internal_value_
;
// used for leaf node
// used for leaf node
/*! \brief The parent of leaf */
/*! \brief The parent of leaf */
std
::
vector
<
int
>
leaf_parent_
;
std
::
vector
<
int
>
leaf_parent_
;
/*! \brief Output of leaves */
/*! \brief Output of leaves */
std
::
vector
<
double
>
leaf_value_
;
std
::
vector
<
double
>
leaf_value_
;
/*! \brief DataCount of leaves */
std
::
vector
<
data_size_t
>
leaf_count_
;
/*! \brief Output of non-leaf nodes */
std
::
vector
<
double
>
internal_value_
;
/*! \brief DataCount of non-leaf nodes */
std
::
vector
<
data_size_t
>
internal_count_
;
/*! \brief Depth for leaves */
/*! \brief Depth for leaves */
std
::
vector
<
int
>
leaf_depth_
;
std
::
vector
<
int
>
leaf_depth_
;
};
};
...
@@ -169,7 +207,9 @@ inline int Tree::GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterat
...
@@ -169,7 +207,9 @@ inline int Tree::GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterat
data_size_t
data_idx
)
const
{
data_size_t
data_idx
)
const
{
int
node
=
0
;
int
node
=
0
;
while
(
node
>=
0
)
{
while
(
node
>=
0
)
{
if
(
iterators
[
split_feature_
[
node
]]
->
Get
(
data_idx
)
<=
threshold_in_bin_
[
node
])
{
if
(
inner_decision_funs
[
decision_type_
[
node
]](
iterators
[
split_feature_
[
node
]]
->
Get
(
data_idx
),
threshold_in_bin_
[
node
]))
{
node
=
left_child_
[
node
];
node
=
left_child_
[
node
];
}
else
{
}
else
{
node
=
right_child_
[
node
];
node
=
right_child_
[
node
];
...
@@ -181,7 +221,9 @@ inline int Tree::GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterat
...
@@ -181,7 +221,9 @@ inline int Tree::GetLeaf(const std::vector<std::unique_ptr<BinIterator>>& iterat
inline
int
Tree
::
GetLeaf
(
const
double
*
feature_values
)
const
{
inline
int
Tree
::
GetLeaf
(
const
double
*
feature_values
)
const
{
int
node
=
0
;
int
node
=
0
;
while
(
node
>=
0
)
{
while
(
node
>=
0
)
{
if
(
feature_values
[
split_feature_real_
[
node
]]
<=
threshold_
[
node
])
{
if
(
decision_funs
[
decision_type_
[
node
]](
feature_values
[
split_feature_real_
[
node
]],
threshold_
[
node
]))
{
node
=
left_child_
[
node
];
node
=
left_child_
[
node
];
}
else
{
}
else
{
node
=
right_child_
[
node
];
node
=
right_child_
[
node
];
...
...
include/LightGBM/utils/common.h
View file @
1466f907
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <cmath>
#include <cmath>
#include <functional>
#include <functional>
#include <memory>
#include <memory>
#include <type_traits>
namespace
LightGBM
{
namespace
LightGBM
{
...
@@ -230,22 +231,17 @@ inline static const char* SkipReturn(const char* p) {
...
@@ -230,22 +231,17 @@ inline static const char* SkipReturn(const char* p) {
return
p
;
return
p
;
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
T2
>
inline
static
std
::
string
ArrayToString
(
const
T
*
arr
,
int
n
,
char
delimiter
)
{
inline
static
std
::
vector
<
T2
>
ArrayCast
(
const
std
::
vector
<
T
>&
arr
)
{
if
(
n
<=
0
)
{
std
::
vector
<
T2
>
ret
;
return
std
::
string
(
""
);
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
++
i
)
{
ret
.
push_back
(
static_cast
<
T2
>
(
arr
[
i
]));
}
}
std
::
stringstream
str_buf
;
return
ret
;
str_buf
<<
arr
[
0
];
for
(
int
i
=
1
;
i
<
n
;
++
i
)
{
str_buf
<<
delimiter
;
str_buf
<<
arr
[
i
];
}
return
str_buf
.
str
();
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
static
std
::
string
ArrayToString
(
std
::
vector
<
T
>
arr
,
char
delimiter
)
{
inline
static
std
::
string
ArrayToString
(
const
std
::
vector
<
T
>
&
arr
,
char
delimiter
)
{
if
(
arr
.
size
()
<=
0
)
{
if
(
arr
.
size
()
<=
0
)
{
return
std
::
string
(
""
);
return
std
::
string
(
""
);
}
}
...
@@ -258,55 +254,43 @@ inline static std::string ArrayToString(std::vector<T> arr, char delimiter) {
...
@@ -258,55 +254,43 @@ inline static std::string ArrayToString(std::vector<T> arr, char delimiter) {
return
str_buf
.
str
();
return
str_buf
.
str
();
}
}
inline
static
void
StringToIntArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
,
int
*
out
)
{
template
<
typename
T
>
inline
static
std
::
vector
<
T
>
StringToArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
if
(
strs
.
size
()
!=
n
)
{
if
(
strs
.
size
()
!=
n
)
{
Log
::
Fatal
(
"StringToIntArray error, size doesn't match."
);
Log
::
Fatal
(
"StringToIntArray error, size doesn't match."
);
}
}
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
std
::
vector
<
T
>
ret
(
n
);
strs
[
i
]
=
Trim
(
strs
[
i
]);
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
{
Atoi
(
strs
[
i
].
c_str
(),
&
out
[
i
]);
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
}
ret
[
i
]
=
static_cast
<
T
>
(
std
::
stod
(
strs
[
i
]));
}
}
}
else
{
for
(
size_t
i
=
0
;
i
<
n
;
++
i
)
{
inline
static
void
StringToDoubleArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
,
double
*
out
)
{
ret
[
i
]
=
static_cast
<
T
>
(
std
::
stol
(
strs
[
i
]));
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
}
if
(
strs
.
size
()
!=
n
)
{
Log
::
Fatal
(
"StringToDoubleArray error, size doesn't match."
);
}
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
strs
[
i
]
=
Trim
(
strs
[
i
]);
Atof
(
strs
[
i
].
c_str
(),
&
out
[
i
]);
}
}
inline
static
std
::
vector
<
double
>
StringToDoubleArray
(
const
std
::
string
&
str
,
char
delimiter
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
std
::
vector
<
double
>
ret
;
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
strs
[
i
]
=
Trim
(
strs
[
i
]);
double
val
=
0.0
f
;
Atof
(
strs
[
i
].
c_str
(),
&
val
);
ret
.
push_back
(
val
);
}
}
return
ret
;
return
ret
;
}
}
inline
static
std
::
vector
<
int
>
StringToIntArray
(
const
std
::
string
&
str
,
char
delimiter
)
{
template
<
typename
T
>
inline
static
std
::
vector
<
T
>
StringToArray
(
const
std
::
string
&
str
,
char
delimiter
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
std
::
vector
<
int
>
ret
;
std
::
vector
<
T
>
ret
;
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
if
(
std
::
is_same
<
T
,
float
>::
value
||
std
::
is_same
<
T
,
double
>::
value
)
{
strs
[
i
]
=
Trim
(
strs
[
i
]);
for
(
const
auto
&
s
:
strs
)
{
int
val
=
0
;
ret
.
push_back
(
static_cast
<
T
>
(
std
::
stod
(
s
)));
Atoi
(
strs
[
i
].
c_str
(),
&
val
);
}
ret
.
push_back
(
val
);
}
else
{
for
(
const
auto
&
s
:
strs
)
{
ret
.
push_back
(
static_cast
<
T
>
(
std
::
stol
(
s
)));
}
}
}
return
ret
;
return
ret
;
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
static
std
::
string
Join
(
const
std
::
vector
<
T
>&
strs
,
char
delimiter
)
{
inline
static
std
::
string
Join
(
const
std
::
vector
<
T
>&
strs
,
const
char
*
delimiter
)
{
if
(
strs
.
size
()
<=
0
)
{
if
(
strs
.
size
()
<=
0
)
{
return
std
::
string
(
""
);
return
std
::
string
(
""
);
}
}
...
@@ -320,7 +304,7 @@ inline static std::string Join(const std::vector<T>& strs, char delimiter) {
...
@@ -320,7 +304,7 @@ inline static std::string Join(const std::vector<T>& strs, char delimiter) {
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
static
std
::
string
Join
(
const
std
::
vector
<
T
>&
strs
,
size_t
start
,
size_t
end
,
char
delimiter
)
{
inline
static
std
::
string
Join
(
const
std
::
vector
<
T
>&
strs
,
size_t
start
,
size_t
end
,
const
char
*
delimiter
)
{
if
(
end
-
start
<=
0
)
{
if
(
end
-
start
<=
0
)
{
return
std
::
string
(
""
);
return
std
::
string
(
""
);
}
}
...
@@ -375,6 +359,28 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<
...
@@ -375,6 +359,28 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<
return
ret
;
return
ret
;
}
}
/*!
* \brief Sort the tail [start, keys.size()) of two parallel vectors by key.
*        keys[i] and values[i] are treated as one pair; pairs are ordered by key only
*        and elements before start are left untouched.
* \param keys Keys to sort by; reordered in place
* \param values Values paired with keys; must hold at least keys.size() elements
* \param start Index of the first element that takes part in the sort
* \param is_reverse false for ascending key order, true for descending
*/
template<typename T1, typename T2>
inline void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
  std::vector<std::pair<T1, T2>> arr;
  if (start < keys.size()) {
    arr.reserve(keys.size() - start);
  }
  for (size_t i = start; i < keys.size(); ++i) {
    arr.emplace_back(keys[i], values[i]);
  }
  if (!is_reverse) {
    std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr is 0-based and only holds the tail, so arr[j] belongs at position start + j.
  // Indexing arr with the absolute position would misplace the pairs whenever start > 0.
  for (size_t j = 0; j < arr.size(); ++j) {
    keys[start + j] = arr[j].first;
    values[start + j] = arr[j].second;
  }
}
}
// namespace Common
}
// namespace Common
}
// namespace LightGBM
}
// namespace LightGBM
...
...
python-package/lightgbm/basic.py
View file @
1466f907
...
@@ -418,7 +418,8 @@ class Dataset(object):
...
@@ -418,7 +418,8 @@ class Dataset(object):
def
__init__
(
self
,
data
,
label
=
None
,
max_bin
=
255
,
reference
=
None
,
def
__init__
(
self
,
data
,
label
=
None
,
max_bin
=
255
,
reference
=
None
,
weight
=
None
,
group
=
None
,
predictor
=
None
,
weight
=
None
,
group
=
None
,
predictor
=
None
,
silent
=
False
,
params
=
None
):
silent
=
False
,
feature_name
=
None
,
categorical_feature
=
None
,
params
=
None
):
"""
"""
Dataset used in LightGBM.
Dataset used in LightGBM.
...
@@ -439,6 +440,11 @@ class Dataset(object):
...
@@ -439,6 +440,11 @@ class Dataset(object):
group/query size for dataset
group/query size for dataset
silent : boolean, optional
silent : boolean, optional
Whether print messages during construction
Whether print messages during construction
feature_name : list of str
feature names
categorical_feature : list of str/int
categorical features , int type to use index,
str type to use feature names (feature_name cannot be None)
params: dict, optional
params: dict, optional
other parameters
other parameters
"""
"""
...
@@ -461,6 +467,23 @@ class Dataset(object):
...
@@ -461,6 +467,23 @@ class Dataset(object):
params
[
"verbose"
]
=
0
params
[
"verbose"
]
=
0
elif
"verbose"
not
in
params
:
elif
"verbose"
not
in
params
:
params
[
"verbose"
]
=
1
params
[
"verbose"
]
=
1
"""get categorical features"""
if
categorical_feature
is
not
None
:
categorical_indices
=
[]
feature_dict
=
{}
if
feature_name
is
not
None
:
feature_dict
=
dict
((
name
,
i
)
for
i
,
name
in
enumerate
(
feature_name
))
for
name
in
categorical_feature
:
if
is_str
(
name
)
and
name
in
feature_dict
:
categorical_indices
.
append
(
feature_dict
[
name
])
elif
isinstance
(
name
,
int
):
categorical_indices
.
append
(
name
)
else
:
raise
TypeError
(
"unknown type({}) or unknown name({}) in categorical_feature"
.
format
(
type
(
name
).
__name__
,
name
))
params
[
'categorical_column'
]
=
categorical_indices
params_str
=
param_dict_to_str
(
params
)
params_str
=
param_dict_to_str
(
params
)
"""process for reference dataset"""
"""process for reference dataset"""
ref_dataset
=
None
ref_dataset
=
None
...
@@ -513,6 +536,8 @@ class Dataset(object):
...
@@ -513,6 +536,8 @@ class Dataset(object):
new_init_score
[
j
*
num_data
+
i
]
=
init_score
[
i
*
self
.
predictor
.
num_class
+
j
]
new_init_score
[
j
*
num_data
+
i
]
=
init_score
[
i
*
self
.
predictor
.
num_class
+
j
]
init_score
=
new_init_score
init_score
=
new_init_score
self
.
set_init_score
(
init_score
)
self
.
set_init_score
(
init_score
)
# set feature names
self
.
set_feature_name
(
feature_name
)
def
create_valid
(
self
,
data
,
label
=
None
,
weight
=
None
,
group
=
None
,
def
create_valid
(
self
,
data
,
label
=
None
,
weight
=
None
,
group
=
None
,
silent
=
False
,
params
=
None
):
silent
=
False
,
params
=
None
):
...
@@ -559,6 +584,17 @@ class Dataset(object):
...
@@ -559,6 +584,17 @@ class Dataset(object):
raise
ValueError
(
"label should not be None"
)
raise
ValueError
(
"label should not be None"
)
return
ret
return
ret
def set_feature_name(self, feature_name):
    """
    Set the feature names of this Dataset.

    Parameters
    ----------
    feature_name : list of str or None
        One name per feature. When None, nothing is done.

    Raises
    ------
    ValueError
        If the number of names differs from ``self.num_feature()``.
    """
    if feature_name is not None:
        name_count = len(feature_name)
        if name_count != self.num_feature():
            raise ValueError("size of feature_name error")
        # Keep the converted names in a local list so they stay referenced
        # for the whole duration of the native call below.
        # NOTE(review): c_str / c_array are helpers defined elsewhere in this
        # module; presumably they build ctypes values - confirm.
        converted_names = list(map(c_str, feature_name))
        name_array = c_array(ctypes.c_char_p, converted_names)
        _safe_call(_LIB.LGBM_DatasetSetFeatureNames(
            self.handle,
            name_array,
            name_count))
def
__init_from_np2d
(
self
,
mat
,
params_str
,
ref_dataset
):
def
__init_from_np2d
(
self
,
mat
,
params_str
,
ref_dataset
):
"""
"""
Initialize data from a 2-D numpy matrix.
Initialize data from a 2-D numpy matrix.
...
...
python-package/lightgbm/engine.py
View file @
1466f907
...
@@ -8,7 +8,8 @@ from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
...
@@ -8,7 +8,8 @@ from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from
.
import
callback
from
.
import
callback
def
_construct_dataset
(
X_y
,
reference
=
None
,
def
_construct_dataset
(
X_y
,
reference
=
None
,
params
=
None
,
other_fields
=
None
,
params
=
None
,
other_fields
=
None
,
feature_name
=
None
,
categorical_feature
=
None
,
predictor
=
None
):
predictor
=
None
):
if
'max_bin'
in
params
:
if
'max_bin'
in
params
:
max_bin
=
int
(
params
[
'max_bin'
])
max_bin
=
int
(
params
[
'max_bin'
])
...
@@ -34,7 +35,10 @@ def _construct_dataset(X_y, reference=None,
...
@@ -34,7 +35,10 @@ def _construct_dataset(X_y, reference=None,
if
reference
is
None
:
if
reference
is
None
:
ret
=
Dataset
(
data
,
label
=
label
,
max_bin
=
max_bin
,
ret
=
Dataset
(
data
,
label
=
label
,
max_bin
=
max_bin
,
weight
=
weight
,
group
=
group
,
weight
=
weight
,
group
=
group
,
predictor
=
predictor
,
params
=
params
)
predictor
=
predictor
,
feature_name
=
feature_name
,
categorical_feature
=
categorical_feature
,
params
=
params
)
else
:
else
:
ret
=
reference
.
create_valid
(
data
,
label
=
label
,
weight
=
weight
,
ret
=
reference
.
create_valid
(
data
,
label
=
label
,
weight
=
weight
,
group
=
group
,
params
=
params
)
group
=
group
,
params
=
params
)
...
@@ -46,6 +50,7 @@ def train(params, train_data, num_boost_round=100,
...
@@ -46,6 +50,7 @@ def train(params, train_data, num_boost_round=100,
valid_datas
=
None
,
valid_names
=
None
,
valid_datas
=
None
,
valid_names
=
None
,
fobj
=
None
,
feval
=
None
,
init_model
=
None
,
fobj
=
None
,
feval
=
None
,
init_model
=
None
,
train_fields
=
None
,
valid_fields
=
None
,
train_fields
=
None
,
valid_fields
=
None
,
feature_name
=
None
,
categorical_feature
=
None
,
early_stopping_rounds
=
None
,
evals_result
=
None
,
early_stopping_rounds
=
None
,
evals_result
=
None
,
verbose_eval
=
True
,
learning_rates
=
None
,
callbacks
=
None
):
verbose_eval
=
True
,
learning_rates
=
None
,
callbacks
=
None
):
"""Train with given parameters.
"""Train with given parameters.
...
@@ -76,6 +81,11 @@ def train(params, train_data, num_boost_round=100,
...
@@ -76,6 +81,11 @@ def train(params, train_data, num_boost_round=100,
other data file in training data.
\
other data file in training data.
\
e.g. valid_fields[0]['weight'] is weight data for first valid data
e.g. valid_fields[0]['weight'] is weight data for first valid data
support fields: weight, group, init_score
support fields: weight, group, init_score
feature_name : list of str
feature names
categorical_feature : list of str/int
categorical features, int type to use index,
str type to use feature names (feature_name cannot be None)
early_stopping_rounds: int
early_stopping_rounds: int
Activates early stopping.
Activates early stopping.
Requires at least one validation data and one metric
Requires at least one validation data and one metric
...
@@ -125,7 +135,11 @@ def train(params, train_data, num_boost_round=100,
...
@@ -125,7 +135,11 @@ def train(params, train_data, num_boost_round=100,
if
isinstance
(
train_data
,
Dataset
):
if
isinstance
(
train_data
,
Dataset
):
train_set
=
train_data
train_set
=
train_data
else
:
else
:
train_set
=
_construct_dataset
(
train_data
,
None
,
params
,
train_fields
,
predictor
)
train_set
=
_construct_dataset
(
train_data
,
None
,
params
,
other_fields
=
train_fields
,
feature_name
=
feature_name
,
categorical_feature
=
categorical_feature
,
predictor
=
predictor
)
is_valid_contain_train
=
False
is_valid_contain_train
=
False
train_data_name
=
"training"
train_data_name
=
"training"
valid_sets
=
[]
valid_sets
=
[]
...
@@ -150,8 +164,10 @@ def train(params, train_data, num_boost_round=100,
...
@@ -150,8 +164,10 @@ def train(params, train_data, num_boost_round=100,
valid_data
,
valid_data
,
train_set
,
train_set
,
params
,
params
,
other_fields
,
other_fields
=
other_fields
,
predictor
)
feature_name
=
feature_name
,
categorical_feature
=
categorical_feature
,
predictor
=
predictor
)
valid_sets
.
append
(
valid_set
)
valid_sets
.
append
(
valid_set
)
if
valid_names
is
not
None
:
if
valid_names
is
not
None
:
name_valid_sets
.
append
(
valid_names
[
i
])
name_valid_sets
.
append
(
valid_names
[
i
])
...
@@ -303,8 +319,10 @@ def _agg_cv_result(raw_results):
...
@@ -303,8 +319,10 @@ def _agg_cv_result(raw_results):
return
results
return
results
def
cv
(
params
,
train_data
,
num_boost_round
=
10
,
nfold
=
5
,
stratified
=
False
,
def
cv
(
params
,
train_data
,
num_boost_round
=
10
,
nfold
=
5
,
stratified
=
False
,
metrics
=
(),
fobj
=
None
,
feval
=
None
,
train_fields
=
None
,
early_stopping_rounds
=
None
,
metrics
=
(),
fobj
=
None
,
feval
=
None
,
train_fields
=
None
,
fpreproc
=
None
,
verbose_eval
=
None
,
show_stdv
=
True
,
seed
=
0
,
feature_name
=
None
,
categorical_feature
=
None
,
early_stopping_rounds
=
None
,
fpreproc
=
None
,
verbose_eval
=
None
,
show_stdv
=
True
,
seed
=
0
,
callbacks
=
None
):
callbacks
=
None
):
"""Cross-validation with given paramaters.
"""Cross-validation with given paramaters.
...
@@ -331,6 +349,11 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
...
@@ -331,6 +349,11 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
train_fields : dict
train_fields : dict
other data file in training data. e.g. train_fields['weight'] is weight data
other data file in training data. e.g. train_fields['weight'] is weight data
support fields: weight, group, init_score
support fields: weight, group, init_score
feature_name : list of str
feature names
categorical_feature : list of str/int
categorical features, int type to use index,
str type to use feature names (feature_name cannot be None)
early_stopping_rounds: int
early_stopping_rounds: int
Activates early stopping. CV error needs to decrease at least
Activates early stopping. CV error needs to decrease at least
every <early_stopping_rounds> round(s) to continue.
every <early_stopping_rounds> round(s) to continue.
...
@@ -373,7 +396,10 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
...
@@ -373,7 +396,10 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
if
metrics
is
not
None
and
len
(
metrics
)
>
0
:
if
metrics
is
not
None
and
len
(
metrics
)
>
0
:
params
[
'metric'
].
extend
(
metrics
)
params
[
'metric'
].
extend
(
metrics
)
train_set
=
_construct_dataset
(
train_data
,
None
,
params
,
train_fields
)
train_set
=
_construct_dataset
(
train_data
,
None
,
params
,
other_fields
=
train_fields
,
feature_name
=
feature_name
,
categorical_feature
=
categorical_feature
)
results
=
{}
results
=
{}
cvfolds
=
_make_n_folds
(
train_set
,
nfold
,
params
,
seed
,
fpreproc
,
stratified
)
cvfolds
=
_make_n_folds
(
train_set
,
nfold
,
params
,
seed
,
fpreproc
,
stratified
)
...
...
python-package/lightgbm/sklearn.py
View file @
1466f907
...
@@ -197,7 +197,9 @@ class LGBMModel(LGBMModelBase):
...
@@ -197,7 +197,9 @@ class LGBMModel(LGBMModelBase):
def
fit
(
self
,
X
,
y
,
eval_set
=
None
,
eval_metric
=
None
,
def
fit
(
self
,
X
,
y
,
eval_set
=
None
,
eval_metric
=
None
,
early_stopping_rounds
=
None
,
verbose
=
True
,
early_stopping_rounds
=
None
,
verbose
=
True
,
train_fields
=
None
,
valid_fields
=
None
,
other_params
=
None
):
train_fields
=
None
,
valid_fields
=
None
,
feature_name
=
None
,
categorical_feature
=
None
,
other_params
=
None
):
"""
"""
Fit the gradient boosting model
Fit the gradient boosting model
...
@@ -225,6 +227,11 @@ class LGBMModel(LGBMModelBase):
...
@@ -225,6 +227,11 @@ class LGBMModel(LGBMModelBase):
other data file in training data.
\
other data file in training data.
\
e.g. valid_fields[0]['weight'] is weight data for first valid data
e.g. valid_fields[0]['weight'] is weight data for first valid data
support fields: weight, group, init_score
support fields: weight, group, init_score
feature_name : list of str
feature names
categorical_feature : list of str/int
categorical features, int type to use index,
str type to use feature names (feature_name cannot be None)
other_params: dict
other_params: dict
other parameters
other parameters
"""
"""
...
@@ -260,7 +267,8 @@ class LGBMModel(LGBMModelBase):
...
@@ -260,7 +267,8 @@ class LGBMModel(LGBMModelBase):
early_stopping_rounds
=
early_stopping_rounds
,
early_stopping_rounds
=
early_stopping_rounds
,
evals_result
=
evals_result
,
fobj
=
self
.
fobj
,
feval
=
feval
,
evals_result
=
evals_result
,
fobj
=
self
.
fobj
,
feval
=
feval
,
verbose_eval
=
verbose
,
train_fields
=
train_fields
,
verbose_eval
=
verbose
,
train_fields
=
train_fields
,
valid_fields
=
valid_fields
)
valid_fields
=
valid_fields
,
feature_name
=
feature_name
,
categorical_feature
=
categorical_feature
)
if
evals_result
:
if
evals_result
:
for
val
in
evals_result
.
items
():
for
val
in
evals_result
.
items
():
...
@@ -321,7 +329,9 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
...
@@ -321,7 +329,9 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
def
fit
(
self
,
X
,
y
,
eval_set
=
None
,
eval_metric
=
None
,
def
fit
(
self
,
X
,
y
,
eval_set
=
None
,
eval_metric
=
None
,
early_stopping_rounds
=
None
,
verbose
=
True
,
early_stopping_rounds
=
None
,
verbose
=
True
,
train_fields
=
None
,
valid_fields
=
None
,
other_params
=
None
):
train_fields
=
None
,
valid_fields
=
None
,
feature_name
=
None
,
categorical_feature
=
None
,
other_params
=
None
):
self
.
classes_
=
np
.
unique
(
y
)
self
.
classes_
=
np
.
unique
(
y
)
self
.
n_classes_
=
len
(
self
.
classes_
)
self
.
n_classes_
=
len
(
self
.
classes_
)
...
@@ -347,6 +357,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
...
@@ -347,6 +357,7 @@ class LGBMClassifier(LGBMModel, LGBMClassifierBase):
super
(
LGBMClassifier
,
self
).
fit
(
X
,
training_labels
,
eval_set
,
super
(
LGBMClassifier
,
self
).
fit
(
X
,
training_labels
,
eval_set
,
eval_metric
,
early_stopping_rounds
,
eval_metric
,
early_stopping_rounds
,
verbose
,
train_fields
,
valid_fields
,
verbose
,
train_fields
,
valid_fields
,
feature_name
,
categorical_feature
,
other_params
)
other_params
)
return
self
return
self
...
...
src/application/application.cpp
View file @
1466f907
...
@@ -105,12 +105,13 @@ void Application::LoadParameters(int argc, char** argv) {
...
@@ -105,12 +105,13 @@ void Application::LoadParameters(int argc, char** argv) {
void
Application
::
LoadData
()
{
void
Application
::
LoadData
()
{
auto
start_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
auto
start_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
std
::
unique_ptr
<
Predictor
>
predictor
;
// prediction is needed if using input initial model(continued train)
// prediction is needed if using input initial model(continued train)
PredictFunction
predict_fun
=
nullptr
;
PredictFunction
predict_fun
=
nullptr
;
// need to continue training
// need to continue training
if
(
boosting_
->
NumberOfTotalModel
()
>
0
)
{
if
(
boosting_
->
NumberOfTotalModel
()
>
0
)
{
P
redictor
p
redictor
(
boosting_
.
get
(),
true
,
false
);
p
redictor
.
reset
(
new
P
redictor
(
boosting_
.
get
(),
true
,
false
)
)
;
predict_fun
=
predictor
.
GetPredictFunction
();
predict_fun
=
predictor
->
GetPredictFunction
();
}
}
// sync up random seed for data partition
// sync up random seed for data partition
...
@@ -119,8 +120,7 @@ void Application::LoadData() {
...
@@ -119,8 +120,7 @@ void Application::LoadData() {
GlobalSyncUpByMin
<
int
>
(
config_
.
io_config
.
data_random_seed
);
GlobalSyncUpByMin
<
int
>
(
config_
.
io_config
.
data_random_seed
);
}
}
DatasetLoader
dataset_loader
(
config_
.
io_config
,
predict_fun
);
DatasetLoader
dataset_loader
(
config_
.
io_config
,
predict_fun
,
config_
.
io_config
.
data_filename
.
c_str
());
dataset_loader
.
SetHeader
(
config_
.
io_config
.
data_filename
.
c_str
());
// load Training data
// load Training data
if
(
config_
.
is_parallel_find_bin
)
{
if
(
config_
.
is_parallel_find_bin
)
{
// load data for parallel training
// load data for parallel training
...
...
src/application/predictor.hpp
View file @
1466f907
...
@@ -116,7 +116,7 @@ public:
...
@@ -116,7 +116,7 @@ public:
// parser
// parser
parser_fun
(
lines
[
i
].
c_str
(),
&
oneline_features
);
parser_fun
(
lines
[
i
].
c_str
(),
&
oneline_features
);
// predict
// predict
pred_result
[
i
]
=
Common
::
Join
<
double
>
(
predict_fun_
(
oneline_features
),
'
\t
'
);
pred_result
[
i
]
=
Common
::
Join
<
double
>
(
predict_fun_
(
oneline_features
),
"
\t
"
);
}
}
for
(
size_t
i
=
0
;
i
<
pred_result
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
pred_result
.
size
();
++
i
)
{
...
...
src/boosting/gbdt.cpp
View file @
1466f907
...
@@ -401,11 +401,18 @@ std::string GBDT::DumpModel() const {
...
@@ -401,11 +401,18 @@ std::string GBDT::DumpModel() const {
ss
<<
"
\"
num_class
\"
:"
<<
num_class_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
num_class
\"
:"
<<
num_class_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
label_index
\"
:"
<<
label_idx_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
label_index
\"
:"
<<
label_idx_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
max_feature_idx
\"
:"
<<
max_feature_idx_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
max_feature_idx
\"
:"
<<
max_feature_idx_
<<
","
<<
std
::
endl
;
if
(
object_function_
!=
nullptr
)
{
ss
<<
"
\"
objective
\"
:
\"
"
<<
object_function_
->
GetName
()
<<
"
\"
,"
<<
std
::
endl
;
}
ss
<<
"
\"
sigmoid
\"
:"
<<
sigmoid_
<<
","
<<
std
::
endl
;
ss
<<
"
\"
sigmoid
\"
:"
<<
sigmoid_
<<
","
<<
std
::
endl
;
// output feature names
auto
feature_names
=
std
::
ref
(
feature_names_
);
if
(
train_data_
!=
nullptr
)
{
feature_names
=
std
::
ref
(
train_data_
->
feature_names
());
}
ss
<<
"
\"
feature_names
\"
:[
\"
"
<<
Common
::
Join
(
feature_names
.
get
(),
"
\"
,
\"
"
)
<<
"
\"
],"
<<
std
::
endl
;
ss
<<
"
\"
tree_info
\"
:["
;
ss
<<
"
\"
tree_info
\"
:["
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
models_
.
size
());
++
i
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
models_
.
size
());
++
i
)
{
if
(
i
>
0
)
{
if
(
i
>
0
)
{
...
@@ -441,8 +448,14 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
...
@@ -441,8 +448,14 @@ void GBDT::SaveModelToFile(int num_iteration, const char* filename) const {
}
}
// output sigmoid parameter
// output sigmoid parameter
output_file
<<
"sigmoid="
<<
sigmoid_
<<
std
::
endl
;
output_file
<<
"sigmoid="
<<
sigmoid_
<<
std
::
endl
;
output_file
<<
std
::
endl
;
// output feature names
auto
feature_names
=
std
::
ref
(
feature_names_
);
if
(
train_data_
!=
nullptr
)
{
feature_names
=
std
::
ref
(
train_data_
->
feature_names
());
}
output_file
<<
"feature_names="
<<
Common
::
Join
(
feature_names
.
get
(),
" "
)
<<
std
::
endl
;
output_file
<<
std
::
endl
;
int
num_used_model
=
0
;
int
num_used_model
=
0
;
if
(
num_iteration
<=
0
)
{
if
(
num_iteration
<=
0
)
{
num_used_model
=
static_cast
<
int
>
(
models_
.
size
());
num_used_model
=
static_cast
<
int
>
(
models_
.
size
());
...
@@ -500,6 +513,19 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -500,6 +513,19 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
}
else
{
}
else
{
sigmoid_
=
-
1.0
f
;
sigmoid_
=
-
1.0
f
;
}
}
// get feature names
line
=
Common
::
FindFromLines
(
lines
,
"feature_names="
);
if
(
line
.
size
()
>
0
)
{
feature_names_
=
Common
::
Split
(
Common
::
Split
(
line
.
c_str
(),
'='
)[
1
].
c_str
(),
' '
);
if
(
feature_names_
.
size
()
!=
static_cast
<
size_t
>
(
max_feature_idx_
+
1
))
{
Log
::
Fatal
(
"Wrong size of feature_names"
);
return
;
}
}
else
{
Log
::
Fatal
(
"Model file doesn't contain feature names"
);
return
;
}
// get tree models
// get tree models
size_t
i
=
0
;
size_t
i
=
0
;
while
(
i
<
lines
.
size
())
{
while
(
i
<
lines
.
size
())
{
...
@@ -509,7 +535,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -509,7 +535,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
int
start
=
static_cast
<
int
>
(
i
);
int
start
=
static_cast
<
int
>
(
i
);
while
(
i
<
lines
.
size
()
&&
lines
[
i
].
find
(
"Tree="
)
==
std
::
string
::
npos
)
{
++
i
;
}
while
(
i
<
lines
.
size
()
&&
lines
[
i
].
find
(
"Tree="
)
==
std
::
string
::
npos
)
{
++
i
;
}
int
end
=
static_cast
<
int
>
(
i
);
int
end
=
static_cast
<
int
>
(
i
);
std
::
string
tree_str
=
Common
::
Join
<
std
::
string
>
(
lines
,
start
,
end
,
'
\n
'
);
std
::
string
tree_str
=
Common
::
Join
<
std
::
string
>
(
lines
,
start
,
end
,
"
\n
"
);
auto
new_tree
=
std
::
unique_ptr
<
Tree
>
(
new
Tree
(
tree_str
));
auto
new_tree
=
std
::
unique_ptr
<
Tree
>
(
new
Tree
(
tree_str
));
models_
.
push_back
(
std
::
move
(
new_tree
));
models_
.
push_back
(
std
::
move
(
new_tree
));
}
else
{
}
else
{
...
@@ -522,6 +548,10 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
...
@@ -522,6 +548,10 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
}
}
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
GBDT
::
FeatureImportance
()
const
{
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
GBDT
::
FeatureImportance
()
const
{
auto
feature_names
=
std
::
ref
(
feature_names_
);
if
(
train_data_
!=
nullptr
)
{
feature_names
=
std
::
ref
(
train_data_
->
feature_names
());
}
std
::
vector
<
size_t
>
feature_importances
(
max_feature_idx_
+
1
,
0
);
std
::
vector
<
size_t
>
feature_importances
(
max_feature_idx_
+
1
,
0
);
for
(
size_t
iter
=
0
;
iter
<
models_
.
size
();
++
iter
)
{
for
(
size_t
iter
=
0
;
iter
<
models_
.
size
();
++
iter
)
{
for
(
int
split_idx
=
0
;
split_idx
<
models_
[
iter
]
->
num_leaves
()
-
1
;
++
split_idx
)
{
for
(
int
split_idx
=
0
;
split_idx
<
models_
[
iter
]
->
num_leaves
()
-
1
;
++
split_idx
)
{
...
@@ -532,7 +562,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
...
@@ -532,7 +562,7 @@ std::vector<std::pair<size_t, std::string>> GBDT::FeatureImportance() const {
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
pairs
;
std
::
vector
<
std
::
pair
<
size_t
,
std
::
string
>>
pairs
;
for
(
size_t
i
=
0
;
i
<
feature_importances
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
feature_importances
.
size
();
++
i
)
{
if
(
feature_importances
[
i
]
>
0
)
{
if
(
feature_importances
[
i
]
>
0
)
{
pairs
.
emplace_back
(
feature_importances
[
i
],
train_data_
->
feature_names
()[
i
]
);
pairs
.
emplace_back
(
feature_importances
[
i
],
feature_names
.
get
().
at
(
i
)
);
}
}
}
}
// sort the importance
// sort the importance
...
...
src/boosting/gbdt.h
View file @
1466f907
...
@@ -298,6 +298,8 @@ protected:
...
@@ -298,6 +298,8 @@ protected:
double
shrinkage_rate_
;
double
shrinkage_rate_
;
/*! \brief Number of loaded initial models */
/*! \brief Number of loaded initial models */
int
num_init_iteration_
;
int
num_init_iteration_
;
/*! \brief Feature names */
std
::
vector
<
std
::
string
>
feature_names_
;
};
};
}
// namespace LightGBM
}
// namespace LightGBM
...
...
src/c_api.cpp
View file @
1466f907
...
@@ -221,8 +221,7 @@ DllExport int LGBM_DatasetCreateFromFile(const char* filename,
...
@@ -221,8 +221,7 @@ DllExport int LGBM_DatasetCreateFromFile(const char* filename,
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
IOConfig
io_config
;
IOConfig
io_config
;
io_config
.
Set
(
param
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
DatasetLoader
loader
(
io_config
,
nullptr
,
filename
);
loader
.
SetHeader
(
filename
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
*
out
=
loader
.
LoadFromFile
(
filename
);
*
out
=
loader
.
LoadFromFile
(
filename
);
}
else
{
}
else
{
...
@@ -244,7 +243,6 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data,
...
@@ -244,7 +243,6 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data,
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
IOConfig
io_config
;
IOConfig
io_config
;
io_config
.
Set
(
param
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_row_fun
=
RowFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
auto
get_row_fun
=
RowFunctionFromDenseMatric
(
data
,
nrow
,
ncol
,
data_type
,
is_row_major
);
if
(
reference
==
nullptr
)
{
if
(
reference
==
nullptr
)
{
...
@@ -262,6 +260,7 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data,
...
@@ -262,6 +260,7 @@ DllExport int LGBM_DatasetCreateFromMat(const void* data,
}
}
}
}
}
}
DatasetLoader
loader
(
io_config
,
nullptr
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
...
@@ -296,7 +295,6 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
...
@@ -296,7 +295,6 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
IOConfig
io_config
;
IOConfig
io_config
;
io_config
.
Set
(
param
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
auto
get_row_fun
=
RowFunctionFromCSR
(
indptr
,
indptr_type
,
indices
,
data
,
data_type
,
nindptr
,
nelem
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
nindptr
-
1
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
nindptr
-
1
);
...
@@ -324,6 +322,7 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
...
@@ -324,6 +322,7 @@ DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
}
}
}
}
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
DatasetLoader
loader
(
io_config
,
nullptr
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
...
@@ -358,7 +357,6 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
...
@@ -358,7 +357,6 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
auto
param
=
ConfigBase
::
Str2Map
(
parameters
);
IOConfig
io_config
;
IOConfig
io_config
;
io_config
.
Set
(
param
);
io_config
.
Set
(
param
);
DatasetLoader
loader
(
io_config
,
nullptr
);
std
::
unique_ptr
<
Dataset
>
ret
;
std
::
unique_ptr
<
Dataset
>
ret
;
auto
get_col_fun
=
ColumnFunctionFromCSC
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
);
auto
get_col_fun
=
ColumnFunctionFromCSC
(
col_ptr
,
col_ptr_type
,
indices
,
data
,
data_type
,
ncol_ptr
,
nelem
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
num_row
);
int32_t
nrow
=
static_cast
<
int32_t
>
(
num_row
);
...
@@ -374,6 +372,7 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
...
@@ -374,6 +372,7 @@ DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
auto
cur_col
=
get_col_fun
(
i
);
auto
cur_col
=
get_col_fun
(
i
);
sample_values
[
i
]
=
SampleFromOneColumn
(
cur_col
,
sample_indices
);
sample_values
[
i
]
=
SampleFromOneColumn
(
cur_col
,
sample_indices
);
}
}
DatasetLoader
loader
(
io_config
,
nullptr
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
ret
.
reset
(
loader
.
CostructFromSampleData
(
sample_values
,
sample_cnt
,
nrow
));
}
else
{
}
else
{
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
ret
.
reset
(
new
Dataset
(
nrow
,
io_config
.
num_class
));
...
@@ -413,6 +412,20 @@ DllExport int LGBM_DatasetGetSubset(
...
@@ -413,6 +412,20 @@ DllExport int LGBM_DatasetGetSubset(
API_END
();
API_END
();
}
}
/*!
 * \brief Set the feature names of a dataset.
 * \param handle Dataset handle; reinterpreted as a Dataset pointer
 * \param feature_names Array of C strings, one per feature name
 * \param num_feature_names Number of entries read from feature_names
 * \return Whatever API_END() yields
 *         (NOTE(review): presumably 0 on success, macro defined elsewhere - confirm)
 */
DllExport int LGBM_DatasetSetFeatureNames(DatesetHandle handle,
                                          const char** feature_names,
                                          int64_t num_feature_names) {
  API_BEGIN();
  // Copy every C string into an owned std::string before handing them over.
  std::vector<std::string> names;
  int64_t idx = 0;
  while (idx < num_feature_names) {
    names.emplace_back(feature_names[idx]);
    ++idx;
  }
  Dataset* target = reinterpret_cast<Dataset*>(handle);
  target->set_feature_names(names);
  API_END();
}
DllExport
int
LGBM_DatasetFree
(
DatesetHandle
handle
)
{
DllExport
int
LGBM_DatasetFree
(
DatesetHandle
handle
)
{
API_BEGIN
();
API_BEGIN
();
delete
reinterpret_cast
<
Dataset
*>
(
handle
);
delete
reinterpret_cast
<
Dataset
*>
(
handle
);
...
@@ -744,7 +757,7 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
...
@@ -744,7 +757,7 @@ DllExport int LGBM_BoosterDumpModel(BoosterHandle handle,
API_BEGIN
();
API_BEGIN
();
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
Booster
*
ref_booster
=
reinterpret_cast
<
Booster
*>
(
handle
);
std
::
string
model
=
ref_booster
->
DumpModel
();
std
::
string
model
=
ref_booster
->
DumpModel
();
*
out_len
=
static_cast
<
int64_t
>
(
model
.
size
());
*
out_len
=
static_cast
<
int64_t
>
(
model
.
size
())
+
1
;
if
(
*
out_len
<=
buffer_len
)
{
if
(
*
out_len
<=
buffer_len
)
{
std
::
strcpy
(
*
out_str
,
model
.
c_str
());
std
::
strcpy
(
*
out_str
,
model
.
c_str
());
}
}
...
...
src/io/bin.cpp
View file @
1466f907
...
@@ -23,10 +23,14 @@ BinMapper::BinMapper(const BinMapper& other) {
...
@@ -23,10 +23,14 @@ BinMapper::BinMapper(const BinMapper& other) {
num_bin_
=
other
.
num_bin_
;
num_bin_
=
other
.
num_bin_
;
is_trival_
=
other
.
is_trival_
;
is_trival_
=
other
.
is_trival_
;
sparse_rate_
=
other
.
sparse_rate_
;
sparse_rate_
=
other
.
sparse_rate_
;
bin_upper_bound_
=
std
::
vector
<
double
>
(
num_bin_
);
bin_type_
=
other
.
bin_type_
;
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
bin_upper_bound_
[
i
]
=
other
.
bin_upper_bound_
[
i
];
bin_upper_bound_
=
other
.
bin_upper_bound_
;
}
else
{
bin_2_categorical_
=
other
.
bin_2_categorical_
;
categorical_2_bin_
=
other
.
categorical_2_bin_
;
}
}
}
}
BinMapper
::
BinMapper
(
const
void
*
memory
)
{
BinMapper
::
BinMapper
(
const
void
*
memory
)
{
...
@@ -37,7 +41,8 @@ BinMapper::~BinMapper() {
...
@@ -37,7 +41,8 @@ BinMapper::~BinMapper() {
}
}
void
BinMapper
::
FindBin
(
std
::
vector
<
double
>*
values
,
size_t
total_sample_cnt
,
int
max_bin
)
{
void
BinMapper
::
FindBin
(
std
::
vector
<
double
>*
values
,
size_t
total_sample_cnt
,
int
max_bin
,
BinType
bin_type
)
{
bin_type_
=
bin_type
;
std
::
vector
<
double
>&
ref_values
=
(
*
values
);
std
::
vector
<
double
>&
ref_values
=
(
*
values
);
size_t
sample_size
=
total_sample_cnt
;
size_t
sample_size
=
total_sample_cnt
;
int
zero_cnt
=
static_cast
<
int
>
(
total_sample_cnt
-
ref_values
.
size
());
int
zero_cnt
=
static_cast
<
int
>
(
total_sample_cnt
-
ref_values
.
size
());
...
@@ -81,70 +86,105 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
...
@@ -81,70 +86,105 @@ void BinMapper::FindBin(std::vector<double>* values, size_t total_sample_cnt, in
int
num_values
=
static_cast
<
int
>
(
distinct_values
.
size
());
int
num_values
=
static_cast
<
int
>
(
distinct_values
.
size
());
int
cnt_in_bin0
=
0
;
int
cnt_in_bin0
=
0
;
if
(
num_values
<=
max_bin
)
{
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
std
::
sort
(
distinct_values
.
begin
(),
distinct_values
.
end
());
if
(
num_values
<=
max_bin
)
{
// use distinct value is enough
std
::
sort
(
distinct_values
.
begin
(),
distinct_values
.
end
());
num_bin_
=
num_values
;
// use distinct value is enough
bin_upper_bound_
=
std
::
vector
<
double
>
(
num_values
);
num_bin_
=
num_values
;
for
(
int
i
=
0
;
i
<
num_values
-
1
;
++
i
)
{
bin_upper_bound_
=
std
::
vector
<
double
>
(
num_values
);
bin_upper_bound_
[
i
]
=
(
distinct_values
[
i
]
+
distinct_values
[
i
+
1
])
/
2
;
for
(
int
i
=
0
;
i
<
num_values
-
1
;
++
i
)
{
}
bin_upper_bound_
[
i
]
=
(
distinct_values
[
i
]
+
distinct_values
[
i
+
1
])
/
2
;
cnt_in_bin0
=
counts
[
0
];
bin_upper_bound_
[
num_values
-
1
]
=
std
::
numeric_limits
<
double
>::
infinity
();
}
else
{
// mean size for one bin
double
mean_bin_size
=
sample_size
/
static_cast
<
double
>
(
max_bin
);
int
rest_bin_cnt
=
max_bin
;
int
rest_sample_cnt
=
static_cast
<
int
>
(
sample_size
);
std
::
vector
<
bool
>
is_big_count_value
(
num_values
,
false
);
for
(
int
i
=
0
;
i
<
num_values
;
++
i
)
{
if
(
counts
[
i
]
>=
mean_bin_size
)
{
is_big_count_value
[
i
]
=
true
;
--
rest_bin_cnt
;
rest_sample_cnt
-=
counts
[
i
];
}
}
}
cnt_in_bin0
=
counts
[
0
];
mean_bin_size
=
rest_sample_cnt
/
static_cast
<
double
>
(
rest_bin_cnt
);
bin_upper_bound_
[
num_values
-
1
]
=
std
::
numeric_limits
<
double
>::
infinity
();
}
else
{
// mean size for one bin
double
mean_bin_size
=
sample_size
/
static_cast
<
double
>
(
max_bin
);
int
rest_bin_cnt
=
max_bin
;
int
rest_sample_cnt
=
static_cast
<
int
>
(
sample_size
);
std
::
vector
<
bool
>
is_big_count_value
(
num_values
,
false
);
for
(
int
i
=
0
;
i
<
num_values
;
++
i
)
{
if
(
counts
[
i
]
>=
mean_bin_size
)
{
is_big_count_value
[
i
]
=
true
;
--
rest_bin_cnt
;
rest_sample_cnt
-=
counts
[
i
];
}
}
mean_bin_size
=
rest_sample_cnt
/
static_cast
<
double
>
(
rest_bin_cnt
);
std
::
vector
<
double
>
upper_bounds
(
max_bin
,
std
::
numeric_limits
<
double
>::
infinity
());
std
::
vector
<
double
>
upper_bounds
(
max_bin
,
std
::
numeric_limits
<
double
>::
infinity
());
std
::
vector
<
double
>
lower_bounds
(
max_bin
,
std
::
numeric_limits
<
double
>::
infinity
());
std
::
vector
<
double
>
lower_bounds
(
max_bin
,
std
::
numeric_limits
<
double
>::
infinity
());
int
bin_cnt
=
0
;
int
bin_cnt
=
0
;
lower_bounds
[
bin_cnt
]
=
distinct_values
[
0
];
lower_bounds
[
bin_cnt
]
=
distinct_values
[
0
];
int
cur_cnt_inbin
=
0
;
int
cur_cnt_inbin
=
0
;
for
(
int
i
=
0
;
i
<
num_values
-
1
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_values
-
1
;
++
i
)
{
if
(
!
is_big_count_value
[
i
])
{
rest_sample_cnt
-=
counts
[
i
];
}
cur_cnt_inbin
+=
counts
[
i
];
// need a new bin
if
(
is_big_count_value
[
i
]
||
cur_cnt_inbin
>=
mean_bin_size
||
(
is_big_count_value
[
i
+
1
]
&&
cur_cnt_inbin
>=
std
::
max
(
1.0
,
mean_bin_size
*
0.5
f
)))
{
upper_bounds
[
bin_cnt
]
=
distinct_values
[
i
];
if
(
bin_cnt
==
0
)
{
cnt_in_bin0
=
cur_cnt_inbin
;
}
++
bin_cnt
;
lower_bounds
[
bin_cnt
]
=
distinct_values
[
i
+
1
];
if
(
bin_cnt
>=
max_bin
-
1
)
{
break
;
}
cur_cnt_inbin
=
0
;
if
(
!
is_big_count_value
[
i
])
{
if
(
!
is_big_count_value
[
i
])
{
--
rest_bin_cnt
;
rest_sample_cnt
-=
counts
[
i
];
mean_bin_size
=
rest_sample_cnt
/
static_cast
<
double
>
(
rest_bin_cnt
);
}
cur_cnt_inbin
+=
counts
[
i
];
// need a new bin
if
(
is_big_count_value
[
i
]
||
cur_cnt_inbin
>=
mean_bin_size
||
(
is_big_count_value
[
i
+
1
]
&&
cur_cnt_inbin
>=
std
::
max
(
1.0
,
mean_bin_size
*
0.5
f
)))
{
upper_bounds
[
bin_cnt
]
=
distinct_values
[
i
];
if
(
bin_cnt
==
0
)
{
cnt_in_bin0
=
cur_cnt_inbin
;
}
++
bin_cnt
;
lower_bounds
[
bin_cnt
]
=
distinct_values
[
i
+
1
];
if
(
bin_cnt
>=
max_bin
-
1
)
{
break
;
}
cur_cnt_inbin
=
0
;
if
(
!
is_big_count_value
[
i
])
{
--
rest_bin_cnt
;
mean_bin_size
=
rest_sample_cnt
/
static_cast
<
double
>
(
rest_bin_cnt
);
}
}
}
}
}
//
++
bin_cnt
;
// update bin upper bound
bin_upper_bound_
=
std
::
vector
<
double
>
(
bin_cnt
);
num_bin_
=
bin_cnt
;
for
(
int
i
=
0
;
i
<
bin_cnt
-
1
;
++
i
)
{
bin_upper_bound_
[
i
]
=
(
upper_bounds
[
i
]
+
lower_bounds
[
i
+
1
])
/
2.0
f
;
}
// last bin upper bound
bin_upper_bound_
[
bin_cnt
-
1
]
=
std
::
numeric_limits
<
double
>::
infinity
();
}
}
else
{
// convert to int type first
std
::
vector
<
int
>
distinct_values_int
;
std
::
vector
<
int
>
counts_int
;
distinct_values_int
.
push_back
(
static_cast
<
int
>
(
distinct_values
[
0
]));
counts_int
.
push_back
(
counts
[
0
]);
for
(
size_t
i
=
1
;
i
<
distinct_values
.
size
();
++
i
)
{
if
(
static_cast
<
int
>
(
distinct_values
[
i
])
!=
distinct_values_int
.
back
())
{
distinct_values_int
.
push_back
(
static_cast
<
int
>
(
distinct_values
[
i
]));
counts_int
.
push_back
(
counts
[
i
]);
}
else
{
counts_int
.
back
()
+=
counts
[
i
];
}
}
}
//
// sort by counts
++
bin_cnt
;
Common
::
SortForPair
<
int
,
int
>
(
counts_int
,
distinct_values_int
,
0
,
true
);
// update bin upper bound
// will ingore the categorical of small counts
bin_upper_bound_
=
std
::
vector
<
double
>
(
bin_cnt
);
num_bin_
=
std
::
min
(
max_bin
,
static_cast
<
int
>
(
counts_int
.
size
()));
num_bin_
=
bin_cnt
;
categorical_2_bin_
.
clear
();
for
(
int
i
=
0
;
i
<
bin_cnt
-
1
;
++
i
)
{
bin_2_categorical_
=
std
::
vector
<
int
>
(
num_bin_
);
bin_upper_bound_
[
i
]
=
(
upper_bounds
[
i
]
+
lower_bounds
[
i
+
1
])
/
2.0
f
;
int
used_cnt
=
0
;
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
bin_2_categorical_
[
i
]
=
distinct_values_int
[
i
];
categorical_2_bin_
[
distinct_values_int
[
i
]]
=
static_cast
<
unsigned
int
>
(
i
);
used_cnt
+=
counts_int
[
i
];
}
}
// last bin upper bound
if
(
used_cnt
/
static_cast
<
double
>
(
sample_size
)
<
0.95
f
)
{
bin_upper_bound_
[
bin_cnt
-
1
]
=
std
::
numeric_limits
<
double
>::
infinity
();
Log
::
Warning
(
"Too many categoricals are ignored, \
please use bigger max_bin or partition this column "
);
}
cnt_in_bin0
=
static_cast
<
int
>
(
sample_size
)
-
used_cnt
+
counts_int
[
0
];
}
}
// check trival(num_bin_ == 1) feature
// check trival(num_bin_ == 1) feature
if
(
num_bin_
<=
1
)
{
if
(
num_bin_
<=
1
)
{
is_trival_
=
true
;
is_trival_
=
true
;
...
@@ -161,6 +201,7 @@ int BinMapper::SizeForSpecificBin(int bin) {
...
@@ -161,6 +201,7 @@ int BinMapper::SizeForSpecificBin(int bin) {
size
+=
sizeof
(
int
);
size
+=
sizeof
(
int
);
size
+=
sizeof
(
bool
);
size
+=
sizeof
(
bool
);
size
+=
sizeof
(
double
);
size
+=
sizeof
(
double
);
size
+=
sizeof
(
BinType
);
size
+=
bin
*
sizeof
(
double
);
size
+=
bin
*
sizeof
(
double
);
return
size
;
return
size
;
}
}
...
@@ -172,7 +213,13 @@ void BinMapper::CopyTo(char * buffer) {
...
@@ -172,7 +213,13 @@ void BinMapper::CopyTo(char * buffer) {
buffer
+=
sizeof
(
is_trival_
);
buffer
+=
sizeof
(
is_trival_
);
std
::
memcpy
(
buffer
,
&
sparse_rate_
,
sizeof
(
sparse_rate_
));
std
::
memcpy
(
buffer
,
&
sparse_rate_
,
sizeof
(
sparse_rate_
));
buffer
+=
sizeof
(
sparse_rate_
);
buffer
+=
sizeof
(
sparse_rate_
);
std
::
memcpy
(
buffer
,
bin_upper_bound_
.
data
(),
num_bin_
*
sizeof
(
double
));
std
::
memcpy
(
buffer
,
&
bin_type_
,
sizeof
(
bin_type_
));
buffer
+=
sizeof
(
bin_type_
);
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
std
::
memcpy
(
buffer
,
bin_upper_bound_
.
data
(),
num_bin_
*
sizeof
(
double
));
}
else
{
std
::
memcpy
(
buffer
,
bin_2_categorical_
.
data
(),
num_bin_
*
sizeof
(
int
));
}
}
}
void
BinMapper
::
CopyFrom
(
const
char
*
buffer
)
{
void
BinMapper
::
CopyFrom
(
const
char
*
buffer
)
{
...
@@ -182,63 +229,115 @@ void BinMapper::CopyFrom(const char * buffer) {
...
@@ -182,63 +229,115 @@ void BinMapper::CopyFrom(const char * buffer) {
buffer
+=
sizeof
(
is_trival_
);
buffer
+=
sizeof
(
is_trival_
);
std
::
memcpy
(
&
sparse_rate_
,
buffer
,
sizeof
(
sparse_rate_
));
std
::
memcpy
(
&
sparse_rate_
,
buffer
,
sizeof
(
sparse_rate_
));
buffer
+=
sizeof
(
sparse_rate_
);
buffer
+=
sizeof
(
sparse_rate_
);
bin_upper_bound_
=
std
::
vector
<
double
>
(
num_bin_
);
std
::
memcpy
(
&
bin_type_
,
buffer
,
sizeof
(
bin_type_
));
std
::
memcpy
(
bin_upper_bound_
.
data
(),
buffer
,
num_bin_
*
sizeof
(
double
));
buffer
+=
sizeof
(
bin_type_
);
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
bin_upper_bound_
=
std
::
vector
<
double
>
(
num_bin_
);
std
::
memcpy
(
bin_upper_bound_
.
data
(),
buffer
,
num_bin_
*
sizeof
(
double
));
}
else
{
bin_2_categorical_
=
std
::
vector
<
int
>
(
num_bin_
);
std
::
memcpy
(
bin_2_categorical_
.
data
(),
buffer
,
num_bin_
*
sizeof
(
int
));
categorical_2_bin_
.
clear
();
for
(
int
i
=
0
;
i
<
num_bin_
;
++
i
)
{
categorical_2_bin_
[
bin_2_categorical_
[
i
]]
=
static_cast
<
unsigned
int
>
(
i
);
}
}
}
}
void
BinMapper
::
SaveBinaryToFile
(
FILE
*
file
)
const
{
void
BinMapper
::
SaveBinaryToFile
(
FILE
*
file
)
const
{
fwrite
(
&
num_bin_
,
sizeof
(
num_bin_
),
1
,
file
);
fwrite
(
&
num_bin_
,
sizeof
(
num_bin_
),
1
,
file
);
fwrite
(
&
is_trival_
,
sizeof
(
is_trival_
),
1
,
file
);
fwrite
(
&
is_trival_
,
sizeof
(
is_trival_
),
1
,
file
);
fwrite
(
&
sparse_rate_
,
sizeof
(
sparse_rate_
),
1
,
file
);
fwrite
(
&
sparse_rate_
,
sizeof
(
sparse_rate_
),
1
,
file
);
fwrite
(
bin_upper_bound_
.
data
(),
sizeof
(
double
),
num_bin_
,
file
);
fwrite
(
&
bin_type_
,
sizeof
(
bin_type_
),
1
,
file
);
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
fwrite
(
bin_upper_bound_
.
data
(),
sizeof
(
double
),
num_bin_
,
file
);
}
else
{
fwrite
(
bin_2_categorical_
.
data
(),
sizeof
(
int
),
num_bin_
,
file
);
}
}
}
size_t
BinMapper
::
SizesInByte
()
const
{
size_t
BinMapper
::
SizesInByte
()
const
{
return
sizeof
(
num_bin_
)
+
sizeof
(
is_trival_
)
+
sizeof
(
sparse_rate_
)
+
sizeof
(
double
)
*
num_bin_
;
size_t
ret
=
sizeof
(
num_bin_
)
+
sizeof
(
is_trival_
)
+
sizeof
(
sparse_rate_
)
+
sizeof
(
bin_type_
);
if
(
bin_type_
==
BinType
::
NumericalBin
)
{
ret
+=
sizeof
(
double
)
*
num_bin_
;
}
else
{
ret
+=
sizeof
(
int
)
*
num_bin_
;
}
return
ret
;
}
}
template
class
DenseBin
<
uint8_t
>;
template
class
DenseBin
<
uint8_t
>;
template
class
DenseBin
<
uint16_t
>;
template
class
DenseBin
<
uint16_t
>;
template
class
DenseBin
<
uint32_t
>;
template
class
DenseBin
<
uint32_t
>;
template
class
DenseCategoricalBin
<
uint8_t
>;
template
class
DenseCategoricalBin
<
uint16_t
>;
template
class
DenseCategoricalBin
<
uint32_t
>;
template
class
SparseBin
<
uint8_t
>;
template
class
SparseBin
<
uint8_t
>;
template
class
SparseBin
<
uint16_t
>;
template
class
SparseBin
<
uint16_t
>;
template
class
SparseBin
<
uint32_t
>;
template
class
SparseBin
<
uint32_t
>;
template
class
SparseCategoricalBin
<
uint8_t
>;
template
class
SparseCategoricalBin
<
uint16_t
>;
template
class
SparseCategoricalBin
<
uint32_t
>;
template
class
OrderedSparseBin
<
uint8_t
>;
template
class
OrderedSparseBin
<
uint8_t
>;
template
class
OrderedSparseBin
<
uint16_t
>;
template
class
OrderedSparseBin
<
uint16_t
>;
template
class
OrderedSparseBin
<
uint32_t
>;
template
class
OrderedSparseBin
<
uint32_t
>;
Bin
*
Bin
::
CreateBin
(
data_size_t
num_data
,
int
num_bin
,
double
sparse_rate
,
bool
is_enable_sparse
,
bool
*
is_sparse
,
int
default_bin
)
{
Bin
*
Bin
::
CreateBin
(
data_size_t
num_data
,
int
num_bin
,
double
sparse_rate
,
bool
is_enable_sparse
,
bool
*
is_sparse
,
int
default_bin
,
BinType
bin_type
)
{
// sparse threshold
// sparse threshold
const
double
kSparseThreshold
=
0.8
f
;
const
double
kSparseThreshold
=
0.8
f
;
if
(
sparse_rate
>=
kSparseThreshold
&&
is_enable_sparse
)
{
if
(
sparse_rate
>=
kSparseThreshold
&&
is_enable_sparse
)
{
*
is_sparse
=
true
;
*
is_sparse
=
true
;
return
CreateSparseBin
(
num_data
,
num_bin
,
default_bin
);
return
CreateSparseBin
(
num_data
,
num_bin
,
default_bin
,
bin_type
);
}
else
{
}
else
{
*
is_sparse
=
false
;
*
is_sparse
=
false
;
return
CreateDenseBin
(
num_data
,
num_bin
,
default_bin
);
return
CreateDenseBin
(
num_data
,
num_bin
,
default_bin
,
bin_type
);
}
}
}
}
Bin
*
Bin
::
CreateDenseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
)
{
Bin
*
Bin
::
CreateDenseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
,
BinType
bin_type
)
{
if
(
num_bin
<=
256
)
{
if
(
bin_type
==
BinType
::
NumericalBin
)
{
return
new
DenseBin
<
uint8_t
>
(
num_data
,
default_bin
);
if
(
num_bin
<=
256
)
{
}
else
if
(
num_bin
<=
65536
)
{
return
new
DenseBin
<
uint8_t
>
(
num_data
,
default_bin
);
return
new
DenseBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
if
(
num_bin
<=
65536
)
{
return
new
DenseBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
{
return
new
DenseBin
<
uint32_t
>
(
num_data
,
default_bin
);
}
}
else
{
}
else
{
return
new
DenseBin
<
uint32_t
>
(
num_data
,
default_bin
);
if
(
num_bin
<=
256
)
{
return
new
DenseCategoricalBin
<
uint8_t
>
(
num_data
,
default_bin
);
}
else
if
(
num_bin
<=
65536
)
{
return
new
DenseCategoricalBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
{
return
new
DenseCategoricalBin
<
uint32_t
>
(
num_data
,
default_bin
);
}
}
}
}
}
Bin
*
Bin
::
CreateSparseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
)
{
Bin
*
Bin
::
CreateSparseBin
(
data_size_t
num_data
,
int
num_bin
,
int
default_bin
,
BinType
bin_type
)
{
if
(
num_bin
<=
256
)
{
if
(
bin_type
==
BinType
::
NumericalBin
)
{
return
new
SparseBin
<
uint8_t
>
(
num_data
,
default_bin
);
if
(
num_bin
<=
256
)
{
}
else
if
(
num_bin
<=
65536
)
{
return
new
SparseBin
<
uint8_t
>
(
num_data
,
default_bin
);
return
new
SparseBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
if
(
num_bin
<=
65536
)
{
return
new
SparseBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
{
return
new
SparseBin
<
uint32_t
>
(
num_data
,
default_bin
);
}
}
else
{
}
else
{
return
new
SparseBin
<
uint32_t
>
(
num_data
,
default_bin
);
if
(
num_bin
<=
256
)
{
return
new
SparseCategoricalBin
<
uint8_t
>
(
num_data
,
default_bin
);
}
else
if
(
num_bin
<=
65536
)
{
return
new
SparseCategoricalBin
<
uint16_t
>
(
num_data
,
default_bin
);
}
else
{
return
new
SparseCategoricalBin
<
uint32_t
>
(
num_data
,
default_bin
);
}
}
}
}
}
...
...
src/io/config.cpp
View file @
1466f907
...
@@ -203,6 +203,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...
@@ -203,6 +203,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetString
(
params
,
"weight_column"
,
&
weight_column
);
GetString
(
params
,
"weight_column"
,
&
weight_column
);
GetString
(
params
,
"group_column"
,
&
group_column
);
GetString
(
params
,
"group_column"
,
&
group_column
);
GetString
(
params
,
"ignore_column"
,
&
ignore_column
);
GetString
(
params
,
"ignore_column"
,
&
ignore_column
);
GetString
(
params
,
"categorical_column"
,
&
categorical_column
);
}
}
...
@@ -216,7 +217,7 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
...
@@ -216,7 +217,7 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
GetDouble
(
params
,
"scale_pos_weight"
,
&
scale_pos_weight
);
GetDouble
(
params
,
"scale_pos_weight"
,
&
scale_pos_weight
);
std
::
string
tmp_str
=
""
;
std
::
string
tmp_str
=
""
;
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
label_gain
=
Common
::
StringTo
DoubleArray
(
tmp_str
,
','
);
label_gain
=
Common
::
StringTo
Array
<
double
>
(
tmp_str
,
','
);
}
else
{
}
else
{
// label_gain = 2^i - 1, may overflow, so we use 31 here
// label_gain = 2^i - 1, may overflow, so we use 31 here
const
int
max_label
=
31
;
const
int
max_label
=
31
;
...
@@ -234,7 +235,7 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
...
@@ -234,7 +235,7 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
GetInt
(
params
,
"num_class"
,
&
num_class
);
GetInt
(
params
,
"num_class"
,
&
num_class
);
std
::
string
tmp_str
=
""
;
std
::
string
tmp_str
=
""
;
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
if
(
GetString
(
params
,
"label_gain"
,
&
tmp_str
))
{
label_gain
=
Common
::
StringTo
DoubleArray
(
tmp_str
,
','
);
label_gain
=
Common
::
StringTo
Array
<
double
>
(
tmp_str
,
','
);
}
else
{
}
else
{
// label_gain = 2^i - 1, may overflow, so we use 31 here
// label_gain = 2^i - 1, may overflow, so we use 31 here
const
int
max_label
=
31
;
const
int
max_label
=
31
;
...
@@ -245,7 +246,7 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
...
@@ -245,7 +246,7 @@ void MetricConfig::Set(const std::unordered_map<std::string, std::string>& param
}
}
label_gain
.
shrink_to_fit
();
label_gain
.
shrink_to_fit
();
if
(
GetString
(
params
,
"ndcg_eval_at"
,
&
tmp_str
))
{
if
(
GetString
(
params
,
"ndcg_eval_at"
,
&
tmp_str
))
{
eval_at
=
Common
::
StringTo
Int
Array
(
tmp_str
,
','
);
eval_at
=
Common
::
StringToArray
<
int
>
(
tmp_str
,
','
);
std
::
sort
(
eval_at
.
begin
(),
eval_at
.
end
());
std
::
sort
(
eval_at
.
begin
(),
eval_at
.
end
());
for
(
size_t
i
=
0
;
i
<
eval_at
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
eval_at
.
size
();
++
i
)
{
CHECK
(
eval_at
[
i
]
>
0
);
CHECK
(
eval_at
[
i
]
>
0
);
...
...
src/io/dataset_loader.cpp
View file @
1466f907
...
@@ -8,9 +8,12 @@
...
@@ -8,9 +8,12 @@
namespace
LightGBM
{
namespace
LightGBM
{
DatasetLoader
::
DatasetLoader
(
const
IOConfig
&
io_config
,
const
PredictFunction
&
predict_fun
)
DatasetLoader
::
DatasetLoader
(
const
IOConfig
&
io_config
,
const
PredictFunction
&
predict_fun
,
const
char
*
filename
)
:
io_config_
(
io_config
),
random_
(
io_config_
.
data_random_seed
),
predict_fun_
(
predict_fun
)
{
:
io_config_
(
io_config
),
random_
(
io_config_
.
data_random_seed
),
predict_fun_
(
predict_fun
)
{
label_idx_
=
0
;
weight_idx_
=
NO_SPECIFIC
;
group_idx_
=
NO_SPECIFIC
;
SetHeader
(
filename
);
}
}
DatasetLoader
::~
DatasetLoader
()
{
DatasetLoader
::~
DatasetLoader
()
{
...
@@ -18,119 +21,141 @@ DatasetLoader::~DatasetLoader() {
...
@@ -18,119 +21,141 @@ DatasetLoader::~DatasetLoader() {
}
}
void
DatasetLoader
::
SetHeader
(
const
char
*
filename
)
{
void
DatasetLoader
::
SetHeader
(
const
char
*
filename
)
{
TextReader
<
data_size_t
>
text_reader
(
filename
,
io_config_
.
has_header
);
std
::
unordered_map
<
std
::
string
,
int
>
name2idx
;
std
::
unordered_map
<
std
::
string
,
int
>
name2idx
;
// get column names
if
(
io_config_
.
has_header
)
{
std
::
string
first_line
=
text_reader
.
first_line
();
feature_names_
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
for
(
size_t
i
=
0
;
i
<
feature_names_
.
size
();
++
i
)
{
name2idx
[
feature_names_
[
i
]]
=
static_cast
<
int
>
(
i
);
}
}
std
::
string
name_prefix
(
"name:"
);
std
::
string
name_prefix
(
"name:"
);
if
(
filename
!=
nullptr
)
{
// load label idx
TextReader
<
data_size_t
>
text_reader
(
filename
,
io_config_
.
has_header
);
if
(
io_config_
.
label_column
.
size
()
>
0
)
{
if
(
Common
::
StartsWith
(
io_config_
.
label_column
,
name_prefix
))
{
// get column names
std
::
string
name
=
io_config_
.
label_column
.
substr
(
name_prefix
.
size
());
if
(
io_config_
.
has_header
)
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
std
::
string
first_line
=
text_reader
.
first_line
();
label_idx_
=
name2idx
[
name
];
feature_names_
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
Log
::
Info
(
"Using column %s as label"
,
name
.
c_str
());
}
// load label idx first
if
(
io_config_
.
label_column
.
size
()
>
0
)
{
if
(
Common
::
StartsWith
(
io_config_
.
label_column
,
name_prefix
))
{
std
::
string
name
=
io_config_
.
label_column
.
substr
(
name_prefix
.
size
());
label_idx_
=
-
1
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
feature_names_
.
size
());
++
i
)
{
if
(
name
==
feature_names_
[
i
])
{
label_idx_
=
i
;
break
;
}
}
if
(
label_idx_
>=
0
)
{
Log
::
Info
(
"Using column %s as label"
,
name
.
c_str
());
}
else
{
Log
::
Fatal
(
"Could not find label column %s in data file \
or data file doesn't contain header"
,
name
.
c_str
());
}
}
else
{
}
else
{
Log
::
Fatal
(
"Could not find label column %s in data file"
,
name
.
c_str
());
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
label_column
.
c_str
(),
&
label_idx_
))
{
}
Log
::
Fatal
(
"label_column is not a number, \
}
else
{
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
label_column
.
c_str
(),
&
label_idx_
))
{
Log
::
Fatal
(
"label_column is not a number, \
if you want to use a column name, \
if you want to use a column name, \
please add the prefix
\"
name:
\"
to the column name"
);
please add the prefix
\"
name:
\"
to the column name"
);
}
Log
::
Info
(
"Using column number %d as label"
,
label_idx_
);
}
}
Log
::
Info
(
"Using column number %d as label"
,
label_idx_
);
}
}
}
if
(
feature_names_
.
size
()
>
0
)
{
if
(
feature_names_
.
size
()
>
0
)
{
// erase label column name
// erase label column name
feature_names_
.
erase
(
feature_names_
.
begin
()
+
label_idx_
);
feature_names_
.
erase
(
feature_names_
.
begin
()
+
label_idx_
);
}
for
(
size_t
i
=
0
;
i
<
feature_names_
.
size
();
++
i
)
{
// load ignore columns
name2idx
[
feature_names_
[
i
]]
=
static_cast
<
int
>
(
i
);
if
(
io_config_
.
ignore_column
.
size
()
>
0
)
{
if
(
Common
::
StartsWith
(
io_config_
.
ignore_column
,
name_prefix
))
{
std
::
string
names
=
io_config_
.
ignore_column
.
substr
(
name_prefix
.
size
());
for
(
auto
name
:
Common
::
Split
(
names
.
c_str
(),
','
))
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
int
tmp
=
name2idx
[
name
];
// skip for label column
if
(
tmp
>
label_idx_
)
{
tmp
-=
1
;
}
ignore_features_
.
emplace
(
tmp
);
}
else
{
Log
::
Fatal
(
"Could not find ignore column %s in data file"
,
name
.
c_str
());
}
}
}
}
else
{
}
for
(
auto
token
:
Common
::
Split
(
io_config_
.
ignore_column
.
c_str
(),
','
))
{
int
tmp
=
0
;
// load ignore columns
if
(
!
Common
::
AtoiAndCheck
(
token
.
c_str
(),
&
tmp
))
{
if
(
io_config_
.
ignore_column
.
size
()
>
0
)
{
Log
::
Fatal
(
"ignore_column is not a number, \
if
(
Common
::
StartsWith
(
io_config_
.
ignore_column
,
name_prefix
))
{
std
::
string
names
=
io_config_
.
ignore_column
.
substr
(
name_prefix
.
size
());
for
(
auto
name
:
Common
::
Split
(
names
.
c_str
(),
','
))
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
int
tmp
=
name2idx
[
name
];
ignore_features_
.
emplace
(
tmp
);
}
else
{
Log
::
Fatal
(
"Could not find ignore column %s in data file"
,
name
.
c_str
());
}
}
}
else
{
for
(
auto
token
:
Common
::
Split
(
io_config_
.
ignore_column
.
c_str
(),
','
))
{
int
tmp
=
0
;
if
(
!
Common
::
AtoiAndCheck
(
token
.
c_str
(),
&
tmp
))
{
Log
::
Fatal
(
"ignore_column is not a number, \
if you want to use a column name, \
if you want to use a column name, \
please add the prefix
\"
name:
\"
to the column name"
);
please add the prefix
\"
name:
\"
to the column name"
);
}
ignore_features_
.
emplace
(
tmp
);
}
}
// skip for label column
if
(
tmp
>
label_idx_
)
{
tmp
-=
1
;
}
ignore_features_
.
emplace
(
tmp
);
}
}
}
}
// load weight idx
}
if
(
io_config_
.
weight_column
.
size
()
>
0
)
{
if
(
Common
::
StartsWith
(
io_config_
.
weight_column
,
name_prefix
))
{
// load weight idx
std
::
string
name
=
io_config_
.
weight_column
.
substr
(
name_prefix
.
size
());
if
(
io_config_
.
weight_column
.
size
(
)
>
0
)
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
if
(
Common
::
StartsWith
(
io_config_
.
weight_column
,
name_prefix
))
{
weight_idx_
=
name2idx
[
name
];
std
::
string
name
=
io_config_
.
weight_column
.
substr
(
name_prefix
.
size
());
Log
::
Info
(
"Using column %s as weight"
,
name
.
c_str
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
}
else
{
weight_idx_
=
name2idx
[
name
]
;
Log
::
Fatal
(
"Could not find weight column %s in data file"
,
name
.
c_str
())
;
Log
::
Info
(
"Using column %s as weight"
,
name
.
c_str
());
}
}
else
{
}
else
{
Log
::
Fatal
(
"Could not find weight column %s in data file"
,
name
.
c_str
());
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
weight_column
.
c_str
(),
&
weight_idx_
))
{
}
Log
::
Fatal
(
"weight_column is not a number, \
}
else
{
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
weight_column
.
c_str
(),
&
weight_idx_
))
{
Log
::
Fatal
(
"weight_column is not a number, \
if you want to use a column name, \
if you want to use a column name, \
please add the prefix
\"
name:
\"
to the column name"
);
please add the prefix
\"
name:
\"
to the column name"
);
}
Log
::
Info
(
"Using column number %d as weight"
,
weight_idx_
);
}
}
Log
::
Info
(
"Using column number %d as weight"
,
weight_idx_
);
ignore_features_
.
emplace
(
weight_idx_
);
}
}
// skip for label column
// load group idx
if
(
weight_idx_
>
label_idx_
)
{
if
(
io_config_
.
group_column
.
size
()
>
0
)
{
weight_idx_
-=
1
;
if
(
Common
::
StartsWith
(
io_config_
.
group_column
,
name_prefix
))
{
std
::
string
name
=
io_config_
.
group_column
.
substr
(
name_prefix
.
size
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
group_idx_
=
name2idx
[
name
];
Log
::
Info
(
"Using column %s as group/query id"
,
name
.
c_str
());
}
else
{
Log
::
Fatal
(
"Could not find group/query column %s in data file"
,
name
.
c_str
());
}
}
else
{
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
group_column
.
c_str
(),
&
group_idx_
))
{
Log
::
Fatal
(
"group_column is not a number, \
if you want to use a column name, \
please add the prefix
\"
name:
\"
to the column name"
);
}
Log
::
Info
(
"Using column number %d as group/query id"
,
group_idx_
);
}
ignore_features_
.
emplace
(
group_idx_
);
}
}
ignore_features_
.
emplace
(
weight_idx_
);
}
}
if
(
io_config_
.
group_column
.
size
()
>
0
)
{
// load categorical features
if
(
Common
::
StartsWith
(
io_config_
.
group_column
,
name_prefix
))
{
if
(
io_config_
.
categorical_column
.
size
()
>
0
)
{
std
::
string
name
=
io_config_
.
group_column
.
substr
(
name_prefix
.
size
());
if
(
Common
::
StartsWith
(
io_config_
.
categorical_column
,
name_prefix
))
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
std
::
string
names
=
io_config_
.
categorical_column
.
substr
(
name_prefix
.
size
());
group_idx_
=
name2idx
[
name
];
for
(
auto
name
:
Common
::
Split
(
names
.
c_str
(),
','
))
{
Log
::
Info
(
"Using column %s as group/query id"
,
name
.
c_str
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
}
else
{
int
tmp
=
name2idx
[
name
];
Log
::
Fatal
(
"Could not find group/query column %s in data file"
,
name
.
c_str
());
categorical_features_
.
emplace
(
tmp
);
}
else
{
Log
::
Fatal
(
"Could not find categorical_column %s in data file"
,
name
.
c_str
());
}
}
}
}
else
{
}
else
{
if
(
!
Common
::
AtoiAndCheck
(
io_config_
.
group_column
.
c_str
(),
&
group_idx_
))
{
for
(
auto
token
:
Common
::
Split
(
io_config_
.
categorical_column
.
c_str
(),
','
))
{
Log
::
Fatal
(
"group_column is not a number, \
int
tmp
=
0
;
if you want to use a column name, \
if
(
!
Common
::
AtoiAndCheck
(
token
.
c_str
(),
&
tmp
))
{
please add the prefix
\"
name:
\"
to the column name"
);
Log
::
Fatal
(
"categorical_column is not a number, \
if you want to use a column name, \
please add the prefix
\"
name:
\"
to the column name"
);
}
categorical_features_
.
emplace
(
tmp
);
}
}
Log
::
Info
(
"Using column number %d as group/query id"
,
group_idx_
);
}
}
// skip for label column
if
(
group_idx_
>
label_idx_
)
{
group_idx_
-=
1
;
}
ignore_features_
.
emplace
(
group_idx_
);
}
}
}
}
...
@@ -415,7 +440,11 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
...
@@ -415,7 +440,11 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
sample_values
.
size
());
++
i
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
sample_values
.
size
());
++
i
)
{
bin_mappers
[
i
].
reset
(
new
BinMapper
());
bin_mappers
[
i
].
reset
(
new
BinMapper
());
bin_mappers
[
i
]
->
FindBin
(
&
sample_values
[
i
],
total_sample_size
,
io_config_
.
max_bin
);
BinType
bin_type
=
BinType
::
NumericalBin
;
if
(
categorical_features_
.
count
(
i
))
{
bin_type
=
BinType
::
CategoricalBin
;
}
bin_mappers
[
i
]
->
FindBin
(
&
sample_values
[
i
],
total_sample_size
,
io_config_
.
max_bin
,
bin_type
);
}
}
auto
dataset
=
std
::
unique_ptr
<
Dataset
>
(
new
Dataset
());
auto
dataset
=
std
::
unique_ptr
<
Dataset
>
(
new
Dataset
());
...
@@ -631,7 +660,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
...
@@ -631,7 +660,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
continue
;
continue
;
}
}
bin_mappers
[
i
].
reset
(
new
BinMapper
());
bin_mappers
[
i
].
reset
(
new
BinMapper
());
bin_mappers
[
i
]
->
FindBin
(
&
sample_values
[
i
],
sample_data
.
size
(),
io_config_
.
max_bin
);
BinType
bin_type
=
BinType
::
NumericalBin
;
if
(
categorical_features_
.
count
(
i
))
{
bin_type
=
BinType
::
CategoricalBin
;
}
bin_mappers
[
i
]
->
FindBin
(
&
sample_values
[
i
],
sample_data
.
size
(),
io_config_
.
max_bin
,
bin_type
);
}
}
for
(
size_t
i
=
0
;
i
<
sample_values
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sample_values
.
size
();
++
i
)
{
...
@@ -681,7 +714,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
...
@@ -681,7 +714,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
len
[
rank
];
++
i
)
{
for
(
int
i
=
0
;
i
<
len
[
rank
];
++
i
)
{
BinMapper
bin_mapper
;
BinMapper
bin_mapper
;
bin_mapper
.
FindBin
(
&
sample_values
[
start
[
rank
]
+
i
],
sample_data
.
size
(),
io_config_
.
max_bin
);
BinType
bin_type
=
BinType
::
NumericalBin
;
if
(
categorical_features_
.
count
(
start
[
rank
]
+
i
))
{
bin_type
=
BinType
::
CategoricalBin
;
}
bin_mapper
.
FindBin
(
&
sample_values
[
start
[
rank
]
+
i
],
sample_data
.
size
(),
io_config_
.
max_bin
,
bin_type
);
bin_mapper
.
CopyTo
(
input_buffer
.
data
()
+
i
*
type_size
);
bin_mapper
.
CopyTo
(
input_buffer
.
data
()
+
i
*
type_size
);
}
}
// convert to binary size
// convert to binary size
...
...
src/io/dense_bin.hpp
View file @
1466f907
...
@@ -103,7 +103,7 @@ public:
...
@@ -103,7 +103,7 @@ public:
}
}
}
}
data_size_t
Split
(
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
virtual
data_size_t
Split
(
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
data_size_t
*
lte_indices
,
data_size_t
*
gt_indices
)
const
override
{
data_size_t
*
lte_indices
,
data_size_t
*
gt_indices
)
const
override
{
data_size_t
lte_count
=
0
;
data_size_t
lte_count
=
0
;
data_size_t
gt_count
=
0
;
data_size_t
gt_count
=
0
;
...
@@ -145,7 +145,7 @@ public:
...
@@ -145,7 +145,7 @@ public:
return
sizeof
(
VAL_T
)
*
num_data_
;
return
sizeof
(
VAL_T
)
*
num_data_
;
}
}
pr
ivate
:
pr
otected
:
data_size_t
num_data_
;
data_size_t
num_data_
;
std
::
vector
<
VAL_T
>
data_
;
std
::
vector
<
VAL_T
>
data_
;
};
};
...
@@ -168,5 +168,28 @@ BinIterator* DenseBin<VAL_T>::GetIterator(data_size_t) const {
...
@@ -168,5 +168,28 @@ BinIterator* DenseBin<VAL_T>::GetIterator(data_size_t) const {
return
new
DenseBinIterator
<
VAL_T
>
(
this
);
return
new
DenseBinIterator
<
VAL_T
>
(
this
);
}
}
template
<
typename
VAL_T
>
class
DenseCategoricalBin
:
public
DenseBin
<
VAL_T
>
{
public:
DenseCategoricalBin
(
data_size_t
num_data
,
int
default_bin
)
:
DenseBin
<
VAL_T
>
(
num_data
,
default_bin
)
{
}
virtual
data_size_t
Split
(
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
data_size_t
*
lte_indices
,
data_size_t
*
gt_indices
)
const
override
{
data_size_t
lte_count
=
0
;
data_size_t
gt_count
=
0
;
for
(
data_size_t
i
=
0
;
i
<
num_data
;
++
i
)
{
data_size_t
idx
=
data_indices
[
i
];
if
(
DenseBin
<
VAL_T
>::
data_
[
idx
]
!=
threshold
)
{
gt_indices
[
gt_count
++
]
=
idx
;
}
else
{
lte_indices
[
lte_count
++
]
=
idx
;
}
}
return
lte_count
;
}
};
}
// namespace LightGBM
}
// namespace LightGBM
#endif // LightGBM_IO_DENSE_BIN_HPP_
#endif // LightGBM_IO_DENSE_BIN_HPP_
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment