tianlh / LightGBM-DCU · Commits

Commit a36eb7e7 authored Nov 01, 2019 by Guolin Ke

remove many vector.at()

parent 8f7199a4

Showing 5 changed files with 52 additions and 41 deletions (+52 -41)
include/LightGBM/utils/common.h                        +16 -12
src/io/dataset.cpp                                      +24 -21
src/io/dataset_loader.cpp                                +4  -3
src/metric/multiclass_metric.hpp                         +5  -3
src/treelearner/cost_effective_gradient_boosting.hpp     +3  -2
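The change is the same in every file: instead of calling vector->at(i), which bounds-checks every access and can throw std::out_of_range, the code first caches a reference to the pointed-to vector (auto& ref = *vec;) and then indexes it with the unchecked operator[]. In the hot loops below this removes a per-access check and avoids repeating the (*vec)[i] dereference. A minimal, self-contained sketch of the before/after pattern; the function names here are illustrative, not from the LightGBM sources:

#include <cstdio>
#include <vector>

// Before: bounds-checked access on every iteration.
double SumWithAt(const std::vector<double>* values) {
  double sum = 0.0;
  for (size_t i = 0; i < values->size(); ++i) {
    sum += values->at(i);  // checks the index and may throw std::out_of_range
  }
  return sum;
}

// After: cache the dereferenced vector once, then use unchecked operator[].
double SumWithIndex(const std::vector<double>* values) {
  const auto& ref_values = *values;
  double sum = 0.0;
  for (size_t i = 0; i < ref_values.size(); ++i) {
    sum += ref_values[i];  // no bounds check; the caller guarantees valid indices
  }
  return sum;
}

int main() {
  std::vector<double> v{1.0, 2.0, 3.0};
  std::printf("%f %f\n", SumWithAt(&v), SumWithIndex(&v));
  return 0;
}

The trade-off is that an out-of-range index becomes undefined behaviour instead of an exception, so the callers are relied on to pass valid indices.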
include/LightGBM/utils/common.h  (view file @ a36eb7e7)
@@ -632,8 +632,8 @@ inline static void Softmax(const double* input, double* output, int len) {
 template<typename T>
 std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
   std::vector<const T*> ret;
-  for (size_t i = 0; i < input.size(); ++i) {
-    ret.push_back(input.at(i).get());
+  for (auto t = input.begin(); t != input.end(); ++t) {
+    ret.push_back(t->get());
   }
   return ret;
 }
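ConstPtrInVectorWrapper turns a vector of owning unique_ptr<T> into a vector of non-owning const T* observers; the new version walks the input with iterators instead of indexing with at(). A small usage sketch, assuming a caller that only needs read-only access (the main() below is illustrative; the template body matches the post-commit code above):

#include <cstdio>
#include <memory>
#include <vector>

// Post-commit helper from include/LightGBM/utils/common.h: expose the raw
// pointers held by a vector of unique_ptr as non-owning const observers.
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> ret;
  for (auto t = input.begin(); t != input.end(); ++t) {
    ret.push_back(t->get());
  }
  return ret;
}

int main() {
  std::vector<std::unique_ptr<int>> owners;
  owners.emplace_back(new int(7));
  owners.emplace_back(new int(42));
  // Illustrative read-only consumer: it only needs const pointers, not ownership.
  std::vector<const int*> views = ConstPtrInVectorWrapper(owners);
  for (const int* p : views) std::printf("%d\n", *p);
  return 0;
}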
@@ -641,8 +641,10 @@ std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<
 template<typename T1, typename T2>
 inline static void SortForPair(std::vector<T1>* keys, std::vector<T2>* values, size_t start, bool is_reverse = false) {
   std::vector<std::pair<T1, T2>> arr;
+  auto& ref_key = *keys;
+  auto& ref_value = *values;
   for (size_t i = start; i < keys->size(); ++i) {
-    arr.emplace_back(keys->at(i), values->at(i));
+    arr.emplace_back(ref_key[i], ref_value[i]);
   }
   if (!is_reverse) {
     std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
@@ -654,16 +656,17 @@ inline static void SortForPair(std::vector<T1>* keys, std::vector<T2>* values, s
     });
   }
   for (size_t i = start; i < arr.size(); ++i) {
-    keys->at(i) = arr[i].first;
-    values->at(i) = arr[i].second;
+    ref_key[i] = arr[i].first;
+    ref_value[i] = arr[i].second;
   }
 }

 template<typename T>
 inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>* data) {
   std::vector<T*> ptr(data->size());
+  auto& ref_data = *data;
   for (size_t i = 0; i < data->size(); ++i) {
-    ptr[i] = data->at(i).data();
+    ptr[i] = ref_data[i].data();
   }
   return ptr;
 }
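SortForPair sorts two parallel vectors in tandem by zipping them into a vector of pairs, sorting that, and writing the pairs back; Vector2Ptr collects the .data() pointer of each inner vector. The comparator bodies are elided in this hunk, so the sketch below re-implements the same parallel-array sort idea with concrete types and illustrative names rather than reproducing the LightGBM template:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Sort two parallel vectors by key: the zip-into-pairs idea used by SortForPair().
static void SortParallel(std::vector<int>* keys, std::vector<double>* values) {
  auto& ref_key = *keys;
  auto& ref_value = *values;
  std::vector<std::pair<int, double>> arr;
  for (size_t i = 0; i < ref_key.size(); ++i) {
    arr.emplace_back(ref_key[i], ref_value[i]);
  }
  std::stable_sort(arr.begin(), arr.end(),
                   [](const std::pair<int, double>& a, const std::pair<int, double>& b) {
                     return a.first < b.first;
                   });
  for (size_t i = 0; i < arr.size(); ++i) {
    ref_key[i] = arr[i].first;
    ref_value[i] = arr[i].second;
  }
}

int main() {
  std::vector<int> k{3, 1, 2};
  std::vector<double> v{0.3, 0.1, 0.2};
  SortParallel(&k, &v);
  for (size_t i = 0; i < k.size(); ++i) std::printf("%d -> %.1f\n", k[i], v[i]);
  return 0;
}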
@@ -841,12 +844,13 @@ inline static std::vector<uint32_t> EmptyBitset(int n) {
 template<typename T>
 inline static void InsertBitset(std::vector<uint32_t>* vec, const T val) {
+  auto& ref_v = *vec;
   int i1 = val / 32;
   int i2 = val % 32;
   if (static_cast<int>(vec->size()) < i1 + 1) {
     vec->resize(i1 + 1, 0);
   }
-  vec->at(i1) |= (1 << i2);
+  ref_v[i1] |= (1 << i2);
 }

 template<typename T>
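InsertBitset grows a uint32_t-backed bitset on demand and sets bit val, addressing word val / 32 and bit val % 32; the cached reference ref_v stays valid across the resize because it refers to the vector object, not to its storage. A standalone sketch pairing it with a membership test; TestBitset is an illustrative helper, not part of this diff:

#include <cstdint>
#include <cstdio>
#include <vector>

// Post-commit InsertBitset from common.h: sets bit `val` in a growable
// uint32_t-backed bitset (word = val / 32, bit = val % 32).
template<typename T>
inline static void InsertBitset(std::vector<uint32_t>* vec, const T val) {
  auto& ref_v = *vec;
  int i1 = val / 32;
  int i2 = val % 32;
  if (static_cast<int>(vec->size()) < i1 + 1) {
    vec->resize(i1 + 1, 0);
  }
  ref_v[i1] |= (1 << i2);
}

// Illustrative membership test (not shown in the diff above).
template<typename T>
inline static bool TestBitset(const std::vector<uint32_t>& vec, const T val) {
  int i1 = val / 32;
  int i2 = val % 32;
  if (i1 >= static_cast<int>(vec.size())) return false;
  return (vec[i1] >> i2) & 1;
}

int main() {
  std::vector<uint32_t> bits;
  InsertBitset(&bits, 3);
  InsertBitset(&bits, 40);  // forces a resize to two 32-bit words
  std::printf("%d %d %d\n", (int)TestBitset(bits, 3), (int)TestBitset(bits, 40), (int)TestBitset(bits, 5));
  return 0;
}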
src/io/dataset.cpp  (view file @ a36eb7e7)
@@ -61,8 +61,9 @@ int GetConfilctCount(const std::vector<bool>& mark, const int* indices, int num_
   return ret;
 }
 void MarkUsed(std::vector<bool>* mark, const int* indices, int num_indices) {
+  auto& ref_mark = *mark;
   for (int i = 0; i < num_indices; ++i) {
-    mark->at(indices[i]) = true;
+    ref_mark[indices[i]] = true;
   }
 }
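MarkUsed flags the given row indices in a shared bool mask; together with GetConfilctCount from the hunk header (whose body is not shown here) it lets the loader check how strongly two sparse features overlap before bundling them. A minimal sketch of that mark-and-count idea with illustrative names; CountConflicts below is an assumption about the counting side, not the actual LightGBM function:

#include <cstdio>
#include <vector>

// Flag the given row indices in a shared mask (mirrors the post-commit MarkUsed).
void MarkUsed(std::vector<bool>* mark, const int* indices, int num_indices) {
  auto& ref_mark = *mark;
  for (int i = 0; i < num_indices; ++i) {
    ref_mark[indices[i]] = true;
  }
}

// Illustrative counterpart: count how many candidate indices collide with
// rows that are already marked.
int CountConflicts(const std::vector<bool>& mark, const int* indices, int num_indices) {
  int ret = 0;
  for (int i = 0; i < num_indices; ++i) {
    if (mark[indices[i]]) ++ret;
  }
  return ret;
}

int main() {
  std::vector<bool> mark(8, false);
  const int feature_a[] = {0, 2, 5};
  const int feature_b[] = {2, 3, 5};
  MarkUsed(&mark, feature_a, 3);
  std::printf("conflicts: %d\n", CountConflicts(mark, feature_b, 3));  // prints 2
  return 0;
}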
@@ -238,8 +239,9 @@ void Dataset::Construct(
   sparse_threshold_ = io_config.sparse_threshold;
   // get num_features
   std::vector<int> used_features;
+  auto& ref_bin_mappers = *bin_mappers;
   for (int i = 0; i < static_cast<int>(bin_mappers->size()); ++i) {
-    if (bin_mappers->at(i) != nullptr && !bin_mappers->at(i)->is_trivial()) {
+    if (ref_bin_mappers[i] != nullptr && !ref_bin_mappers[i]->is_trivial()) {
       used_features.emplace_back(i);
     }
   }
@@ -277,7 +279,7 @@ void Dataset::Construct(
       real_feature_idx_[cur_fidx] = real_fidx;
       feature2group_[cur_fidx] = i;
       feature2subfeature_[cur_fidx] = j;
-      cur_bin_mappers.emplace_back(bin_mappers->at(real_fidx).release());
+      cur_bin_mappers.emplace_back(ref_bin_mappers[real_fidx].release());
       ++cur_fidx;
     }
     feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
@@ -848,6 +850,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
...
@@ -848,6 +850,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
int
num_used_group
=
static_cast
<
int
>
(
used_group
.
size
());
int
num_used_group
=
static_cast
<
int
>
(
used_group
.
size
());
auto
ptr_ordered_grad
=
gradients
;
auto
ptr_ordered_grad
=
gradients
;
auto
ptr_ordered_hess
=
hessians
;
auto
ptr_ordered_hess
=
hessians
;
auto
&
ref_ordered_bins
=
*
ordered_bins
;
if
(
data_indices
!=
nullptr
&&
num_data
<
num_data_
)
{
if
(
data_indices
!=
nullptr
&&
num_data
<
num_data_
)
{
if
(
!
is_constant_hessian
)
{
if
(
!
is_constant_hessian
)
{
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
@@ -874,7 +877,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
         const int num_bin = feature_groups_[group]->num_total_bin_;
         std::memset(reinterpret_cast<void*>(data_ptr + 1), 0, (num_bin - 1) * sizeof(HistogramBinEntry));
         // construct histograms for smaller leaf
-        if (ordered_bins->at(group) == nullptr) {
+        if (ref_ordered_bins[group] == nullptr) {
           // if not use ordered bin
           feature_groups_[group]->bin_data_->ConstructHistogram(
             data_indices,

@@ -884,10 +887,10 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
             data_ptr);
         } else {
           // used ordered bin
-          ordered_bins->at(group)->ConstructHistogram(leaf_idx,
+          ref_ordered_bins[group]->ConstructHistogram(leaf_idx,
             gradients,
             hessians,
             data_ptr);
         }
         OMP_LOOP_EX_END();
       }
@@ -903,7 +906,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
         const int num_bin = feature_groups_[group]->num_total_bin_;
         std::memset(reinterpret_cast<void*>(data_ptr + 1), 0, (num_bin - 1) * sizeof(HistogramBinEntry));
         // construct histograms for smaller leaf
-        if (ordered_bins->at(group) == nullptr) {
+        if (ref_ordered_bins[group] == nullptr) {
           // if not use ordered bin
           feature_groups_[group]->bin_data_->ConstructHistogram(
             data_indices,

@@ -912,9 +915,9 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
             data_ptr);
         } else {
           // used ordered bin
-          ordered_bins->at(group)->ConstructHistogram(leaf_idx,
+          ref_ordered_bins[group]->ConstructHistogram(leaf_idx,
             gradients,
             data_ptr);
         }
         // fixed hessian.
         for (int i = 0; i < num_bin; ++i) {
@@ -936,7 +939,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
         const int num_bin = feature_groups_[group]->num_total_bin_;
         std::memset(reinterpret_cast<void*>(data_ptr + 1), 0, (num_bin - 1) * sizeof(HistogramBinEntry));
         // construct histograms for smaller leaf
-        if (ordered_bins->at(group) == nullptr) {
+        if (ref_ordered_bins[group] == nullptr) {
           // if not use ordered bin
           feature_groups_[group]->bin_data_->ConstructHistogram(
             num_data,

@@ -945,10 +948,10 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
             data_ptr);
         } else {
           // used ordered bin
-          ordered_bins->at(group)->ConstructHistogram(leaf_idx,
+          ref_ordered_bins[group]->ConstructHistogram(leaf_idx,
             gradients,
             hessians,
             data_ptr);
         }
         OMP_LOOP_EX_END();
       }
@@ -964,7 +967,7 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
         const int num_bin = feature_groups_[group]->num_total_bin_;
         std::memset(reinterpret_cast<void*>(data_ptr + 1), 0, (num_bin - 1) * sizeof(HistogramBinEntry));
         // construct histograms for smaller leaf
-        if (ordered_bins->at(group) == nullptr) {
+        if (ref_ordered_bins[group] == nullptr) {
           // if not use ordered bin
           feature_groups_[group]->bin_data_->ConstructHistogram(
             num_data,

@@ -972,9 +975,9 @@ void Dataset::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
             data_ptr);
         } else {
           // used ordered bin
-          ordered_bins->at(group)->ConstructHistogram(leaf_idx,
+          ref_ordered_bins[group]->ConstructHistogram(leaf_idx,
             gradients,
             data_ptr);
         }
         // fixed hessian.
         for (int i = 0; i < num_bin; ++i) {
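Throughout ConstructHistograms the dereferenced ordered_bins is cached once as ref_ordered_bins before the #pragma omp parallel for loops, so every worker indexes the vector directly instead of calling ->at(group) per group. A small self-contained sketch of caching the reference outside an OpenMP loop; the accumulators and names are illustrative, not the LightGBM histogram code:

#include <cstdio>
#include <memory>
#include <vector>

int main() {
  // Hypothetical per-group accumulators standing in for the ordered bins.
  std::vector<std::unique_ptr<double>> groups;
  for (int g = 0; g < 8; ++g) groups.emplace_back(new double(0.0));

  std::vector<std::unique_ptr<double>>* ordered = &groups;
  auto& ref_ordered = *ordered;  // cache the dereference once, outside the parallel loop

#pragma omp parallel for schedule(static)
  for (int g = 0; g < static_cast<int>(ref_ordered.size()); ++g) {
    if (ref_ordered[g] != nullptr) {  // unchecked operator[] instead of ->at(g)
      *ref_ordered[g] += g;           // each iteration touches a distinct element
    }
  }

  for (const auto& p : ref_ordered) std::printf("%.1f ", *p);
  std::printf("\n");
  return 0;
}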
src/io/dataset_loader.cpp  (view file @ a36eb7e7)
@@ -1048,6 +1048,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
 void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
   std::vector<std::pair<int, double>> oneline_features;
   double tmp_label = 0.0f;
+  auto& ref_text_data = *text_data;
   if (predict_fun_ == nullptr) {
     OMP_INIT_EX();
     // if doesn't need to prediction with initial model
@@ -1057,11 +1058,11 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
       const int tid = omp_get_thread_num();
       oneline_features.clear();
       // parser
-      parser->ParseOneLine(text_data->at(i).c_str(), &oneline_features, &tmp_label);
+      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
       // set label
       dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
       // free processed line:
-      text_data->at(i).clear();
+      ref_text_data[i].clear();
       // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
       // text_reader_->Lines()[i].shrink_to_fit();
       // push data
@@ -1094,7 +1095,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_dat
       const int tid = omp_get_thread_num();
       oneline_features.clear();
       // parser
-      parser->ParseOneLine(text_data->at(i).c_str(), &oneline_features, &tmp_label);
+      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
       // set initial score
       std::vector<double> oneline_init_score(num_class_);
       predict_fun_(oneline_features, oneline_init_score.data());
src/metric/multiclass_metric.hpp  (view file @ a36eb7e7)
@@ -140,9 +140,10 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
   inline static double LossOnPoint(label_t label, std::vector<double>* score, const Config& config) {
     size_t k = static_cast<size_t>(label);
+    auto& ref_score = *score;
     int num_larger = 0;
     for (size_t i = 0; i < score->size(); ++i) {
-      if (score->at(i) >= score->at(k)) ++num_larger;
+      if (ref_score[i] >= ref_score[k]) ++num_larger;
       if (num_larger > config.multi_error_top_k) return 1.0f;
     }
     return 0.0f;
@@ -164,8 +165,9 @@ class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetr
   inline static double LossOnPoint(label_t label, std::vector<double>* score, const Config&) {
     size_t k = static_cast<size_t>(label);
-    if (score->at(k) > kEpsilon) {
-      return static_cast<double>(-std::log(score->at(k)));
+    auto& ref_score = *score;
+    if (ref_score[k] > kEpsilon) {
+      return static_cast<double>(-std::log(ref_score[k]));
     } else {
       return -std::log(kEpsilon);
     }
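For context on the two metrics touched here: MultiErrorMetric::LossOnPoint counts how many class scores are at least as large as the true class's score and returns 1 once that count exceeds multi_error_top_k, while MultiSoftmaxLoglossMetric::LossOnPoint clamps the predicted probability at kEpsilon before taking the negative log. A standalone sketch of both per-point losses; the free-function form and the kEpsilon value are illustrative:

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative constant; LightGBM defines its own kEpsilon elsewhere.
const double kEpsilon = 1e-15;

// Top-k multi-class error: 0 if the true class is within the top_k highest
// scores (ties counted against it), 1 otherwise.
double MultiError(size_t label, const std::vector<double>& score, int top_k) {
  int num_larger = 0;
  for (size_t i = 0; i < score.size(); ++i) {
    if (score[i] >= score[label]) ++num_larger;
    if (num_larger > top_k) return 1.0;
  }
  return 0.0;
}

// Multi-class softmax logloss for one point, clamped away from log(0).
double MultiLogloss(size_t label, const std::vector<double>& score) {
  return score[label] > kEpsilon ? -std::log(score[label]) : -std::log(kEpsilon);
}

int main() {
  std::vector<double> probs{0.2, 0.5, 0.3};
  std::printf("top-1 error: %.1f\n", MultiError(1, probs, 1));  // 0.0: class 1 has the top score
  std::printf("top-1 error: %.1f\n", MultiError(2, probs, 1));  // 1.0: class 2 is not in the top 1
  std::printf("logloss:     %.4f\n", MultiLogloss(1, probs));   // -log(0.5)
  return 0;
}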
src/treelearner/cost_effective_gradient_boosting.hpp  (view file @ a36eb7e7)
@@ -63,14 +63,15 @@ class CostEfficientGradientBoosting {
     auto config = tree_learner_->config_;
     auto train_data = tree_learner_->train_data_;
     const int inner_feature_index = train_data->InnerFeatureIndex(best_split_info->feature);
+    auto& ref_best_split_per_leaf = *best_split_per_leaf;
     if (!config->cegb_penalty_feature_coupled.empty() && !is_feature_used_in_split_[inner_feature_index]) {
       is_feature_used_in_split_[inner_feature_index] = true;
       for (int i = 0; i < tree->num_leaves(); ++i) {
         if (i == best_leaf) continue;
         auto split = &splits_per_leaf_[static_cast<size_t>(i) * train_data->num_features() + inner_feature_index];
         split->gain += config->cegb_tradeoff * config->cegb_penalty_feature_coupled[best_split_info->feature];
-        if (*split > best_split_per_leaf->at(i))
-          best_split_per_leaf->at(i) = *split;
+        if (*split > ref_best_split_per_leaf[i])
+          ref_best_split_per_leaf[i] = *split;
       }
     }
     if (!config->cegb_penalty_feature_lazy.empty()) {
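In this CEGB hunk, once a feature is used in a split for the first time its coupled penalty is considered paid, so the cached candidate splits for that feature on every other leaf get cegb_tradeoff * cegb_penalty_feature_coupled[feature] added back to their gain and, where the adjusted candidate now beats the stored best split, it replaces it. A toy sketch of that re-ranking step under those assumptions; the Split struct and the numbers are illustrative, not LightGBM's SplitInfo:

#include <cstdio>
#include <vector>

// Toy split record standing in for LightGBM's SplitInfo (illustrative only).
struct Split {
  double gain;
  bool operator>(const Split& other) const { return gain > other.gain; }
};

int main() {
  const double cegb_tradeoff = 1.0;
  const double feature_penalty = 0.5;  // coupled penalty that was charged to this feature

  // Cached candidate splits of the newly used feature on the other leaves,
  // and the current best split per leaf.
  std::vector<Split> candidate_per_leaf{{0.9}, {0.4}};
  std::vector<Split> best_split_per_leaf{{1.2}, {0.6}};

  for (size_t i = 0; i < candidate_per_leaf.size(); ++i) {
    // The feature is now "paid for", so add its coupled penalty back to the cached gain.
    candidate_per_leaf[i].gain += cegb_tradeoff * feature_penalty;
    if (candidate_per_leaf[i] > best_split_per_leaf[i])
      best_split_per_leaf[i] = candidate_per_leaf[i];
  }

  for (const Split& s : best_split_per_leaf) std::printf("%.2f ", s.gain);
  std::printf("\n");  // prints "1.40 0.90": both leaves now prefer the newly cheap feature
  return 0;
}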