Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
1c1749db
Commit
1c1749db
authored
Mar 22, 2017
by
Guolin Ke
Browse files
fix bug in filter bin
parent
8a0f07ae
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
19 deletions
+18
-19
include/LightGBM/utils/common.h
include/LightGBM/utils/common.h
+2
-2
src/c_api.cpp
src/c_api.cpp
+6
-6
src/io/bin.cpp
src/io/bin.cpp
+2
-6
src/io/config.cpp
src/io/config.cpp
+2
-2
src/io/dataset_loader.cpp
src/io/dataset_loader.cpp
+6
-3
No files found.
include/LightGBM/utils/common.h
View file @
1c1749db
...
@@ -425,8 +425,8 @@ inline static double ApproximateHessianWithGaussian(const double y, const double
...
@@ -425,8 +425,8 @@ inline static double ApproximateHessianWithGaussian(const double y, const double
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
static
T
*
*
Vector2Ptr
(
std
::
vector
<
std
::
vector
<
T
>>&
data
)
{
inline
static
std
::
vector
<
T
*
>
Vector2Ptr
(
std
::
vector
<
std
::
vector
<
T
>>&
data
)
{
T
*
*
ptr
=
new
T
*
[
data
.
size
()
]
;
std
::
vector
<
T
*
>
ptr
(
data
.
size
()
)
;
for
(
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
data
.
size
();
++
i
)
{
ptr
[
i
]
=
data
[
i
].
data
();
ptr
[
i
]
=
data
[
i
].
data
();
}
}
...
...
src/c_api.cpp
View file @
1c1749db
...
@@ -423,8 +423,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
...
@@ -423,8 +423,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
}
}
}
}
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
),
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
)
.
data
()
,
Common
::
Vector2Ptr
<
int
>
(
sample_idx
),
Common
::
Vector2Ptr
<
int
>
(
sample_idx
)
.
data
()
,
static_cast
<
int
>
(
sample_values
.
size
()),
static_cast
<
int
>
(
sample_values
.
size
()),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
sample_cnt
,
nrow
));
sample_cnt
,
nrow
));
...
@@ -487,8 +487,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
...
@@ -487,8 +487,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
}
}
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
CHECK
(
num_col
>=
static_cast
<
int
>
(
sample_values
.
size
()));
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
),
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
)
.
data
()
,
Common
::
Vector2Ptr
<
int
>
(
sample_idx
),
Common
::
Vector2Ptr
<
int
>
(
sample_idx
)
.
data
()
,
static_cast
<
int
>
(
sample_values
.
size
()),
static_cast
<
int
>
(
sample_values
.
size
()),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
sample_cnt
,
nrow
));
sample_cnt
,
nrow
));
...
@@ -546,8 +546,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
...
@@ -546,8 +546,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
}
}
}
}
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
DatasetLoader
loader
(
io_config
,
nullptr
,
1
,
nullptr
);
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
),
ret
.
reset
(
loader
.
CostructFromSampleData
(
Common
::
Vector2Ptr
<
double
>
(
sample_values
)
.
data
()
,
Common
::
Vector2Ptr
<
int
>
(
sample_idx
),
Common
::
Vector2Ptr
<
int
>
(
sample_idx
)
.
data
()
,
static_cast
<
int
>
(
sample_values
.
size
()),
static_cast
<
int
>
(
sample_values
.
size
()),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
Common
::
VectorSize
<
double
>
(
sample_values
).
data
(),
sample_cnt
,
nrow
));
sample_cnt
,
nrow
));
...
...
src/io/bin.cpp
View file @
1c1749db
...
@@ -49,18 +49,14 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin
...
@@ -49,18 +49,14 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin
int
sum_left
=
0
;
int
sum_left
=
0
;
for
(
size_t
i
=
0
;
i
<
cnt_in_bin
.
size
()
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
cnt_in_bin
.
size
()
-
1
;
++
i
)
{
sum_left
+=
cnt_in_bin
[
i
];
sum_left
+=
cnt_in_bin
[
i
];
if
(
sum_left
>=
filter_cnt
)
{
if
(
sum_left
>=
filter_cnt
&&
total_cnt
-
sum_left
>=
filter_cnt
)
{
return
false
;
}
else
if
(
total_cnt
-
sum_left
>=
filter_cnt
)
{
return
false
;
return
false
;
}
}
}
}
}
else
{
}
else
{
for
(
size_t
i
=
0
;
i
<
cnt_in_bin
.
size
()
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
cnt_in_bin
.
size
()
-
1
;
++
i
)
{
int
sum_left
=
cnt_in_bin
[
i
];
int
sum_left
=
cnt_in_bin
[
i
];
if
(
sum_left
>=
filter_cnt
)
{
if
(
sum_left
>=
filter_cnt
&&
total_cnt
-
sum_left
>=
filter_cnt
)
{
return
false
;
}
else
if
(
total_cnt
-
sum_left
>=
filter_cnt
)
{
return
false
;
return
false
;
}
}
}
}
...
...
src/io/config.cpp
View file @
1c1749db
...
@@ -141,8 +141,8 @@ void OverallConfig::CheckParamConflict() {
...
@@ -141,8 +141,8 @@ void OverallConfig::CheckParamConflict() {
bool
objective_type_multiclass
=
(
objective_type
==
std
::
string
(
"multiclass"
));
bool
objective_type_multiclass
=
(
objective_type
==
std
::
string
(
"multiclass"
));
int
num_class_check
=
boosting_config
.
num_class
;
int
num_class_check
=
boosting_config
.
num_class
;
if
(
objective_type_multiclass
)
{
if
(
objective_type_multiclass
)
{
if
(
num_class_check
<=
2
)
{
if
(
num_class_check
<=
1
)
{
Log
::
Fatal
(
"Number of classes should be specified and greater than
2
for multiclass training"
);
Log
::
Fatal
(
"Number of classes should be specified and greater than
1
for multiclass training"
);
}
}
}
else
{
}
else
{
if
(
task_type
==
TaskType
::
kTrain
&&
num_class_check
!=
1
)
{
if
(
task_type
==
TaskType
::
kTrain
&&
num_class_check
!=
1
)
{
...
...
src/io/dataset_loader.cpp
View file @
1c1749db
...
@@ -487,7 +487,9 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
...
@@ -487,7 +487,9 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
feature_names_
.
push_back
(
str_buf
.
str
());
feature_names_
.
push_back
(
str_buf
.
str
());
}
}
}
}
const
data_size_t
filter_cnt
=
static_cast
<
data_size_t
>
(
static_cast
<
double
>
(
0.95
*
io_config_
.
min_data_in_leaf
)
/
num_data
*
num_col
);
const
data_size_t
filter_cnt
=
static_cast
<
data_size_t
>
(
static_cast
<
double
>
(
io_config_
.
min_data_in_leaf
*
total_sample_size
)
/
num_data
);
#pragma omp parallel for schedule(guided)
#pragma omp parallel for schedule(guided)
for
(
int
i
=
0
;
i
<
num_col
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_col
;
++
i
)
{
...
@@ -701,7 +703,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
...
@@ -701,7 +703,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
}
}
dataset
->
feature_names_
=
feature_names_
;
dataset
->
feature_names_
=
feature_names_
;
std
::
vector
<
std
::
unique_ptr
<
BinMapper
>>
bin_mappers
(
sample_values
.
size
());
std
::
vector
<
std
::
unique_ptr
<
BinMapper
>>
bin_mappers
(
sample_values
.
size
());
const
data_size_t
filter_cnt
=
static_cast
<
data_size_t
>
(
static_cast
<
double
>
(
0.95
*
io_config_
.
min_data_in_leaf
)
/
dataset
->
num_data_
*
sample_values
.
size
());
const
data_size_t
filter_cnt
=
static_cast
<
data_size_t
>
(
static_cast
<
double
>
(
io_config_
.
min_data_in_leaf
*
sample_values
.
size
())
/
dataset
->
num_data_
);
// start find bins
// start find bins
if
(
num_machines
==
1
)
{
if
(
num_machines
==
1
)
{
...
@@ -815,7 +818,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
...
@@ -815,7 +818,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
}
}
}
}
sample_values
.
clear
();
sample_values
.
clear
();
dataset
->
Construct
(
bin_mappers
,
Common
::
Vector2Ptr
<
int
>
(
sample_indices
),
dataset
->
Construct
(
bin_mappers
,
Common
::
Vector2Ptr
<
int
>
(
sample_indices
)
.
data
()
,
Common
::
VectorSize
<
int
>
(
sample_indices
).
data
(),
sample_data
.
size
(),
io_config_
);
Common
::
VectorSize
<
int
>
(
sample_indices
).
data
(),
sample_data
.
size
(),
io_config_
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment