Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ae320e59
Unverified
Commit
ae320e59
authored
Dec 20, 2019
by
Guolin Ke
Committed by
GitHub
Dec 20, 2019
Browse files
fix predict with header (#2643)
* fix predict with header * avoid duplicated feature names
parent
d1002776
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
20 deletions
+33
-20
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+5
-0
src/application/predictor.hpp
src/application/predictor.hpp
+20
-13
src/io/parser.cpp
src/io/parser.cpp
+8
-7
No files found.
include/LightGBM/dataset.h
View file @
ae320e59
...
@@ -556,6 +556,7 @@ class Dataset {
...
@@ -556,6 +556,7 @@ class Dataset {
Log
::
Fatal
(
"Size of feature_names error, should equal with total number of features"
);
Log
::
Fatal
(
"Size of feature_names error, should equal with total number of features"
);
}
}
feature_names_
=
std
::
vector
<
std
::
string
>
(
feature_names
);
feature_names_
=
std
::
vector
<
std
::
string
>
(
feature_names
);
std
::
unordered_set
<
std
::
string
>
feature_name_set
;
// replace ' ' in feature_names with '_'
// replace ' ' in feature_names with '_'
bool
spaceInFeatureName
=
false
;
bool
spaceInFeatureName
=
false
;
for
(
auto
&
feature_name
:
feature_names_
)
{
for
(
auto
&
feature_name
:
feature_names_
)
{
...
@@ -571,6 +572,10 @@ class Dataset {
...
@@ -571,6 +572,10 @@ class Dataset {
spaceInFeatureName
=
true
;
spaceInFeatureName
=
true
;
std
::
replace
(
feature_name
.
begin
(),
feature_name
.
end
(),
' '
,
'_'
);
std
::
replace
(
feature_name
.
begin
(),
feature_name
.
end
(),
' '
,
'_'
);
}
}
if
(
feature_name_set
.
count
(
feature_name
)
>
0
)
{
Log
::
Fatal
(
"Feature (%s) appears more than one time."
,
feature_name
.
c_str
());
}
feature_name_set
.
insert
(
feature_name
);
}
}
if
(
spaceInFeatureName
)
{
if
(
spaceInFeatureName
)
{
Log
::
Warning
(
"Find whitespaces in feature_names, replace with underlines"
);
Log
::
Warning
(
"Find whitespaces in feature_names, replace with underlines"
);
...
...
src/application/predictor.hpp
View file @
ae320e59
...
@@ -135,31 +135,38 @@ class Predictor {
...
@@ -135,31 +135,38 @@ class Predictor {
if
(
!
writer
->
Init
())
{
if
(
!
writer
->
Init
())
{
Log
::
Fatal
(
"Prediction results file %s cannot be found"
,
result_filename
);
Log
::
Fatal
(
"Prediction results file %s cannot be found"
,
result_filename
);
}
}
auto
parser
=
std
::
unique_ptr
<
Parser
>
(
Parser
::
CreateParser
(
data_filename
,
header
,
boosting_
->
MaxFeatureIdx
()
+
1
,
boosting_
->
LabelIdx
()));
auto
label_idx
=
header
?
-
1
:
boosting_
->
LabelIdx
();
auto
parser
=
std
::
unique_ptr
<
Parser
>
(
Parser
::
CreateParser
(
data_filename
,
header
,
boosting_
->
MaxFeatureIdx
()
+
1
,
label_idx
));
if
(
parser
==
nullptr
)
{
if
(
parser
==
nullptr
)
{
Log
::
Fatal
(
"Could not recognize the data format of data file %s"
,
data_filename
);
Log
::
Fatal
(
"Could not recognize the data format of data file %s"
,
data_filename
);
}
}
if
(
parser
->
NumFeatures
()
!=
boosting_
->
MaxFeatureIdx
()
+
1
)
{
if
(
!
header
&&
parser
->
NumFeatures
()
!=
boosting_
->
MaxFeatureIdx
()
+
1
)
{
Log
::
Fatal
(
"The number of features in data (%d) is not the same as it was in training data (%d)."
,
parser
->
NumFeatures
(),
boosting_
->
MaxFeatureIdx
()
+
1
);
Log
::
Fatal
(
"The number of features in data (%d) is not the same as it was in training data (%d)."
,
parser
->
NumFeatures
(),
boosting_
->
MaxFeatureIdx
()
+
1
);
}
}
TextReader
<
data_size_t
>
predict_data_reader
(
data_filename
,
header
);
TextReader
<
data_size_t
>
predict_data_reader
(
data_filename
,
header
);
std
::
unordered_map
<
int
,
int
>
feature_names_map_
;
std
::
vector
<
int
>
feature_remapper
(
parser
->
NumFeatures
(),
-
1
)
;
bool
need_adjust
=
false
;
bool
need_adjust
=
false
;
if
(
header
)
{
if
(
header
)
{
std
::
string
first_line
=
predict_data_reader
.
first_line
();
std
::
string
first_line
=
predict_data_reader
.
first_line
();
std
::
vector
<
std
::
string
>
header_words
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
std
::
vector
<
std
::
string
>
header_words
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
header_words
.
erase
(
header_words
.
begin
()
+
boosting_
->
LabelIdx
())
;
std
::
unordered_map
<
std
::
string
,
int
>
header_mapper
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
header_words
.
size
());
++
i
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
header_words
.
size
());
++
i
)
{
for
(
int
j
=
0
;
j
<
static_cast
<
int
>
(
boosting_
->
FeatureNames
().
size
());
++
j
)
{
if
(
header_mapper
.
count
(
header_words
[
i
])
>
0
)
{
if
(
header_words
[
i
]
==
boosting_
->
FeatureNames
()[
j
])
{
Log
::
Fatal
(
"Feature (%s) appears more than one time."
,
header_words
[
i
].
c_str
());
feature_names_map_
[
i
]
=
j
;
}
break
;
header_mapper
[
header_words
[
i
]]
=
i
;
}
}
const
auto
&
fnames
=
boosting_
->
FeatureNames
();
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
fnames
.
size
());
++
i
)
{
if
(
header_mapper
.
count
(
fnames
[
i
])
<=
0
)
{
Log
::
Warning
(
"Feature (%s) is missed in data file. If it is weight/query/group/ignore_column, you can ignore this warning."
,
fnames
[
i
].
c_str
());
}
else
{
feature_remapper
[
header_mapper
.
at
(
fnames
[
i
])]
=
i
;
}
}
}
}
for
(
auto
s
:
feature_names_map_
)
{
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
feature_remapper
.
size
());
++
i
)
{
if
(
s
.
first
!=
s
.
second
)
{
if
(
feature_remapper
[
i
]
>=
0
&&
i
!=
feature_remapper
[
i
]
)
{
need_adjust
=
true
;
need_adjust
=
true
;
break
;
break
;
}
}
...
@@ -174,8 +181,8 @@ class Predictor {
...
@@ -174,8 +181,8 @@ class Predictor {
if
(
need_adjust
)
{
if
(
need_adjust
)
{
int
i
=
0
,
j
=
static_cast
<
int
>
(
feature
->
size
());
int
i
=
0
,
j
=
static_cast
<
int
>
(
feature
->
size
());
while
(
i
<
j
)
{
while
(
i
<
j
)
{
if
(
feature_
names_map_
.
find
(
(
*
feature
)[
i
].
first
)
!=
feature_names_map_
.
end
()
)
{
if
(
feature_
remapper
[
(
*
feature
)[
i
].
first
]
>=
0
)
{
(
*
feature
)[
i
].
first
=
feature_
names_map_
[(
*
feature
)[
i
].
first
];
(
*
feature
)[
i
].
first
=
feature_
remapper
[(
*
feature
)[
i
].
first
];
++
i
;
++
i
;
}
else
{
}
else
{
// move the non-used features to the end of the feature vector
// move the non-used features to the end of the feature vector
...
...
src/io/parser.cpp
View file @
ae320e59
...
@@ -201,18 +201,19 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
...
@@ -201,18 +201,19 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
Log
::
Fatal
(
"Unknown format of training data."
);
Log
::
Fatal
(
"Unknown format of training data."
);
}
}
std
::
unique_ptr
<
Parser
>
ret
;
std
::
unique_ptr
<
Parser
>
ret
;
int
output_label_index
=
-
1
;
if
(
type
==
DataType
::
LIBSVM
)
{
if
(
type
==
DataType
::
LIBSVM
)
{
label_i
d
x
=
GetLabelIdxForLibsvm
(
lines
[
0
],
num_features
,
label_idx
);
output_
label_i
nde
x
=
GetLabelIdxForLibsvm
(
lines
[
0
],
num_features
,
label_idx
);
ret
.
reset
(
new
LibSVMParser
(
label_i
d
x
,
num_col
));
ret
.
reset
(
new
LibSVMParser
(
output_
label_i
nde
x
,
num_col
));
}
else
if
(
type
==
DataType
::
TSV
)
{
}
else
if
(
type
==
DataType
::
TSV
)
{
label_i
d
x
=
GetLabelIdxForTSV
(
lines
[
0
],
num_features
,
label_idx
);
output_
label_i
nde
x
=
GetLabelIdxForTSV
(
lines
[
0
],
num_features
,
label_idx
);
ret
.
reset
(
new
TSVParser
(
label_i
d
x
,
num_col
));
ret
.
reset
(
new
TSVParser
(
output_
label_i
nde
x
,
num_col
));
}
else
if
(
type
==
DataType
::
CSV
)
{
}
else
if
(
type
==
DataType
::
CSV
)
{
label_i
d
x
=
GetLabelIdxForCSV
(
lines
[
0
],
num_features
,
label_idx
);
output_
label_i
nde
x
=
GetLabelIdxForCSV
(
lines
[
0
],
num_features
,
label_idx
);
ret
.
reset
(
new
CSVParser
(
label_i
d
x
,
num_col
));
ret
.
reset
(
new
CSVParser
(
output_
label_i
nde
x
,
num_col
));
}
}
if
(
label_idx
<
0
)
{
if
(
output_label_index
<
0
&&
label_idx
>=
0
)
{
Log
::
Info
(
"Data file %s doesn't contain a label column."
,
filename
);
Log
::
Info
(
"Data file %s doesn't contain a label column."
,
filename
);
}
}
return
ret
.
release
();
return
ret
.
release
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment