Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
3a06ce35
Commit
3a06ce35
authored
Oct 29, 2016
by
Guolin Ke
Browse files
clean code
parent
a057afec
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
133 additions
and
67 deletions
+133
-67
include/LightGBM/dataset.h
include/LightGBM/dataset.h
+2
-2
include/LightGBM/utils/pipeline_reader.h
include/LightGBM/utils/pipeline_reader.h
+3
-2
include/LightGBM/utils/text_reader.h
include/LightGBM/utils/text_reader.h
+3
-3
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+8
-8
src/io/dataset.cpp
src/io/dataset.cpp
+61
-13
src/io/parser.cpp
src/io/parser.cpp
+56
-39
No files found.
include/LightGBM/dataset.h
View file @
3a06ce35
...
@@ -192,8 +192,6 @@ private:
...
@@ -192,8 +192,6 @@ private:
int16_t
*
label_int_
;
int16_t
*
label_int_
;
/*! \brief Weights data */
/*! \brief Weights data */
float
*
weights_
;
float
*
weights_
;
/*! \brief Queries data */
data_size_t
*
queries_
;
/*! \brief Query boundaries */
/*! \brief Query boundaries */
data_size_t
*
query_boundaries_
;
data_size_t
*
query_boundaries_
;
/*! \brief Query weights */
/*! \brief Query weights */
...
@@ -204,6 +202,8 @@ private:
...
@@ -204,6 +202,8 @@ private:
data_size_t
num_init_score_
;
data_size_t
num_init_score_
;
/*! \brief Initial score */
/*! \brief Initial score */
score_t
*
init_score_
;
score_t
*
init_score_
;
/*! \brief Queries data */
data_size_t
*
queries_
;
};
};
...
...
include/LightGBM/utils/pipeline_reader.h
View file @
3a06ce35
...
@@ -38,12 +38,13 @@ public:
...
@@ -38,12 +38,13 @@ public:
char
*
buffer_process
=
new
char
[
buffer_size
];
char
*
buffer_process
=
new
char
[
buffer_size
];
// buffer used for the file reading
// buffer used for the file reading
char
*
buffer_read
=
new
char
[
buffer_size
];
char
*
buffer_read
=
new
char
[
buffer_size
];
size_t
read_cnt
=
0
;
if
(
skip_bytes
>
0
)
{
if
(
skip_bytes
>
0
)
{
// skip first k bytes
// skip first k bytes
fread
(
buffer_process
,
1
,
skip_bytes
,
file
);
read_cnt
=
fread
(
buffer_process
,
1
,
skip_bytes
,
file
);
}
}
// read first block
// read first block
size_t
read_cnt
=
fread
(
buffer_process
,
1
,
buffer_size
,
file
);
read_cnt
=
fread
(
buffer_process
,
1
,
buffer_size
,
file
);
size_t
last_read_cnt
=
0
;
size_t
last_read_cnt
=
0
;
while
(
read_cnt
>
0
)
{
while
(
read_cnt
>
0
)
{
// strat read thread
// strat read thread
...
...
include/LightGBM/utils/text_reader.h
View file @
3a06ce35
...
@@ -34,7 +34,7 @@ public:
...
@@ -34,7 +34,7 @@ public:
#else
#else
file
=
fopen
(
filename
,
"r"
);
file
=
fopen
(
filename
,
"r"
);
#endif
#endif
std
::
stringstream
s
s
;
std
::
stringstream
s
tr_buf
;
int
read_c
=
-
1
;
int
read_c
=
-
1
;
read_c
=
fgetc
(
file
);
read_c
=
fgetc
(
file
);
while
(
read_c
!=
EOF
)
{
while
(
read_c
!=
EOF
)
{
...
@@ -42,7 +42,7 @@ public:
...
@@ -42,7 +42,7 @@ public:
if
(
tmp_ch
==
'\n'
||
tmp_ch
==
'\r'
)
{
if
(
tmp_ch
==
'\n'
||
tmp_ch
==
'\r'
)
{
break
;
break
;
}
}
s
s
<<
tmp_ch
;
s
tr_buf
<<
tmp_ch
;
++
skip_bytes_
;
++
skip_bytes_
;
read_c
=
fgetc
(
file
);
read_c
=
fgetc
(
file
);
}
}
...
@@ -55,7 +55,7 @@ public:
...
@@ -55,7 +55,7 @@ public:
++
skip_bytes_
;
++
skip_bytes_
;
}
}
fclose
(
file
);
fclose
(
file
);
first_line_
=
s
s
.
str
();
first_line_
=
s
tr_buf
.
str
();
Log
::
Info
(
"skip header:
\"
%s
\"
in file %s"
,
first_line_
.
c_str
(),
filename_
);
Log
::
Info
(
"skip header:
\"
%s
\"
in file %s"
,
first_line_
.
c_str
(),
filename_
);
}
}
}
}
...
...
src/boosting/gbdt.cpp
View file @
3a06ce35
...
@@ -275,21 +275,21 @@ void GBDT::Boosting() {
...
@@ -275,21 +275,21 @@ void GBDT::Boosting() {
std
::
string
GBDT
::
ModelsToString
()
const
{
std
::
string
GBDT
::
ModelsToString
()
const
{
// serialize this object to string
// serialize this object to string
std
::
stringstream
s
s
;
std
::
stringstream
s
tr_buf
;
// output label index
// output label index
s
s
<<
"label_index="
<<
label_idx_
<<
std
::
endl
;
s
tr_buf
<<
"label_index="
<<
label_idx_
<<
std
::
endl
;
// output max_feature_idx
// output max_feature_idx
s
s
<<
"max_feature_idx="
<<
max_feature_idx_
<<
std
::
endl
;
s
tr_buf
<<
"max_feature_idx="
<<
max_feature_idx_
<<
std
::
endl
;
// output sigmoid parameter
// output sigmoid parameter
s
s
<<
"sigmoid="
<<
object_function_
->
GetSigmoid
()
<<
std
::
endl
;
s
tr_buf
<<
"sigmoid="
<<
object_function_
->
GetSigmoid
()
<<
std
::
endl
;
s
s
<<
std
::
endl
;
s
tr_buf
<<
std
::
endl
;
// output tree models
// output tree models
for
(
size_t
i
=
0
;
i
<
models_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
models_
.
size
();
++
i
)
{
s
s
<<
"Tree="
<<
i
<<
std
::
endl
;
s
tr_buf
<<
"Tree="
<<
i
<<
std
::
endl
;
s
s
<<
models_
[
i
]
->
ToString
()
<<
std
::
endl
;
s
tr_buf
<<
models_
[
i
]
->
ToString
()
<<
std
::
endl
;
}
}
return
s
s
.
str
();
return
s
tr_buf
.
str
();
}
}
void
GBDT
::
ModelsFromString
(
const
std
::
string
&
model_str
,
int
num_used_model
)
{
void
GBDT
::
ModelsFromString
(
const
std
::
string
&
model_str
,
int
num_used_model
)
{
...
...
src/io/dataset.cpp
View file @
3a06ce35
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include <vector>
#include <vector>
#include <utility>
#include <utility>
#include <string>
#include <string>
#include <sstream>
namespace
LightGBM
{
namespace
LightGBM
{
...
@@ -36,8 +37,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
...
@@ -36,8 +37,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
if
(
io_config
.
has_header
)
{
if
(
io_config
.
has_header
)
{
std
::
string
first_line
=
text_reader_
->
first_line
();
std
::
string
first_line
=
text_reader_
->
first_line
();
feature_names_
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
feature_names_
=
Common
::
Split
(
first_line
.
c_str
(),
"
\t
,"
);
for
(
in
t
i
=
0
;
i
<
feature_names_
.
size
();
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
feature_names_
.
size
();
++
i
)
{
name2idx
[
feature_names_
[
i
]]
=
i
;
name2idx
[
feature_names_
[
i
]]
=
static_cast
<
int
>
(
i
)
;
}
}
}
}
std
::
string
name_prefix
(
"name:"
);
std
::
string
name_prefix
(
"name:"
);
...
@@ -48,14 +49,25 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
...
@@ -48,14 +49,25 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std
::
string
name
=
io_config
.
label_column
.
substr
(
name_prefix
.
size
());
std
::
string
name
=
io_config
.
label_column
.
substr
(
name_prefix
.
size
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
label_idx_
=
name2idx
[
name
];
label_idx_
=
name2idx
[
name
];
Log
::
Info
(
"use %s column as label"
,
name
.
c_str
());
}
else
{
}
else
{
Log
::
Fatal
(
"cannot find label column: %s in data file"
,
name
.
c_str
());
Log
::
Fatal
(
"cannot find label column: %s in data file"
,
name
.
c_str
());
}
}
}
else
{
}
else
{
Common
::
Atoi
(
io_config
.
label_column
.
c_str
(),
&
label_idx_
);
size_t
pos
=
0
;
label_idx_
=
std
::
stoi
(
io_config
.
label_column
,
&
pos
);
if
(
pos
!=
io_config
.
label_column
.
size
())
{
Log
::
Fatal
(
"label_column is not a number, \
if you want to use column name, \
please add prefix
\"
name:
\"
before column name"
);
}
Log
::
Info
(
"use %d-th column as label"
,
label_idx_
);
}
}
}
}
if
(
feature_names_
.
size
()
>
0
)
{
// erase label column name
feature_names_
.
erase
(
feature_names_
.
begin
()
+
label_idx_
);
}
// load ignore columns
// load ignore columns
if
(
io_config
.
ignore_column
.
size
()
>
0
)
{
if
(
io_config
.
ignore_column
.
size
()
>
0
)
{
if
(
Common
::
StartsWith
(
io_config
.
ignore_column
,
name_prefix
))
{
if
(
Common
::
StartsWith
(
io_config
.
ignore_column
,
name_prefix
))
{
...
@@ -72,8 +84,13 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
...
@@ -72,8 +84,13 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
}
}
}
else
{
}
else
{
for
(
auto
token
:
Common
::
Split
(
io_config
.
ignore_column
.
c_str
(),
','
))
{
for
(
auto
token
:
Common
::
Split
(
io_config
.
ignore_column
.
c_str
(),
','
))
{
int
tmp
=
0
;
size_t
pos
=
0
;
Common
::
Atoi
(
token
.
c_str
(),
&
tmp
);
int
tmp
=
std
::
stoi
(
token
,
&
pos
);
if
(
pos
!=
token
.
size
())
{
Log
::
Fatal
(
"ignore_column is not a number, \
if you want to use column name, \
please add prefix
\"
name:
\"
before column name"
);
}
// skip for label column
// skip for label column
if
(
tmp
>
label_idx_
)
{
tmp
-=
1
;
}
if
(
tmp
>
label_idx_
)
{
tmp
-=
1
;
}
ignore_features_
.
emplace
(
tmp
);
ignore_features_
.
emplace
(
tmp
);
...
@@ -88,11 +105,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
...
@@ -88,11 +105,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std
::
string
name
=
io_config
.
weight_column
.
substr
(
name_prefix
.
size
());
std
::
string
name
=
io_config
.
weight_column
.
substr
(
name_prefix
.
size
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
weight_idx_
=
name2idx
[
name
];
weight_idx_
=
name2idx
[
name
];
Log
::
Info
(
"use %s column as weight"
,
name
.
c_str
());
}
else
{
}
else
{
Log
::
Fatal
(
"cannot find weight column: %s in data file"
,
name
.
c_str
());
Log
::
Fatal
(
"cannot find weight column: %s in data file"
,
name
.
c_str
());
}
}
}
else
{
}
else
{
Common
::
Atoi
(
io_config
.
weight_column
.
c_str
(),
&
weight_idx_
);
size_t
pos
=
0
;
weight_idx_
=
std
::
stoi
(
io_config
.
weight_column
,
&
pos
);
if
(
pos
!=
io_config
.
weight_column
.
size
())
{
Log
::
Fatal
(
"weight_column is not a number, \
if you want to use column name, \
please add prefix
\"
name:
\"
before column name"
);
}
Log
::
Info
(
"use %d-th column as weight"
,
weight_idx_
);
}
}
// skip for label column
// skip for label column
if
(
weight_idx_
>
label_idx_
)
{
if
(
weight_idx_
>
label_idx_
)
{
...
@@ -106,11 +131,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
...
@@ -106,11 +131,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std
::
string
name
=
io_config
.
group_column
.
substr
(
name_prefix
.
size
());
std
::
string
name
=
io_config
.
group_column
.
substr
(
name_prefix
.
size
());
if
(
name2idx
.
count
(
name
)
>
0
)
{
if
(
name2idx
.
count
(
name
)
>
0
)
{
group_idx_
=
name2idx
[
name
];
group_idx_
=
name2idx
[
name
];
Log
::
Info
(
"use %s column as group/query id"
,
name
.
c_str
());
}
else
{
}
else
{
Log
::
Fatal
(
"cannot find group/query column: %s in data file"
,
name
.
c_str
());
Log
::
Fatal
(
"cannot find group/query column: %s in data file"
,
name
.
c_str
());
}
}
}
else
{
}
else
{
Common
::
Atoi
(
io_config
.
group_column
.
c_str
(),
&
group_idx_
);
size_t
pos
=
0
;
group_idx_
=
std
::
stoi
(
io_config
.
group_column
,
&
pos
);
if
(
pos
!=
io_config
.
group_column
.
size
())
{
Log
::
Fatal
(
"group_column is not a number, \
if you want to use column name, \
please add prefix
\"
name:
\"
before column name"
);
}
Log
::
Info
(
"use %d-th column as group/query id"
,
group_idx_
);
}
}
// skip for label column
// skip for label column
if
(
group_idx_
>
label_idx_
)
{
if
(
group_idx_
>
label_idx_
)
{
...
@@ -279,6 +312,21 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -279,6 +312,21 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature
// -1 means doesn't use this feature
used_feature_map_
=
std
::
vector
<
int
>
(
sample_values
.
size
(),
-
1
);
used_feature_map_
=
std
::
vector
<
int
>
(
sample_values
.
size
(),
-
1
);
num_total_features_
=
static_cast
<
int
>
(
sample_values
.
size
());
num_total_features_
=
static_cast
<
int
>
(
sample_values
.
size
());
// check the range of label_idx, weight_idx and group_idx
CHECK
(
label_idx_
>=
0
&&
label_idx_
<=
num_total_features_
);
CHECK
(
weight_idx_
<
0
||
weight_idx_
<
num_total_features_
);
CHECK
(
group_idx_
<
0
||
group_idx_
<
num_total_features_
);
// fill feature_names_ if not header
if
(
feature_names_
.
size
()
<=
0
)
{
for
(
int
i
=
0
;
i
<
num_total_features_
;
++
i
)
{
std
::
stringstream
str_buf
;
str_buf
<<
"Column_"
<<
i
;
feature_names_
.
push_back
(
str_buf
.
str
());
}
}
// start find bins
// start find bins
if
(
num_machines
==
1
)
{
if
(
num_machines
==
1
)
{
std
::
vector
<
BinMapper
*>
bin_mappers
(
sample_values
.
size
());
std
::
vector
<
BinMapper
*>
bin_mappers
(
sample_values
.
size
());
...
@@ -295,7 +343,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -295,7 +343,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
for
(
size_t
i
=
0
;
i
<
sample_values
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sample_values
.
size
();
++
i
)
{
if
(
bin_mappers
[
i
]
==
nullptr
)
{
if
(
bin_mappers
[
i
]
==
nullptr
)
{
Log
::
Error
(
"Ignore Feature %
d
"
,
i
);
Log
::
Error
(
"Ignore Feature %
s
"
,
feature_names_
[
i
].
c_str
()
);
}
}
else
if
(
!
bin_mappers
[
i
]
->
is_trival
())
{
else
if
(
!
bin_mappers
[
i
]
->
is_trival
())
{
// map real feature index to used feature index
// map real feature index to used feature index
...
@@ -305,7 +353,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -305,7 +353,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_
,
is_enable_sparse_
));
num_data_
,
is_enable_sparse_
));
}
else
{
}
else
{
// if feature is trival(only 1 bin), free spaces
// if feature is trival(only 1 bin), free spaces
Log
::
Error
(
"Feature %
d
only contains one value, will be ignored"
,
i
);
Log
::
Error
(
"Feature %
s
only contains one value, will be ignored"
,
feature_names_
[
i
].
c_str
()
);
delete
bin_mappers
[
i
];
delete
bin_mappers
[
i
];
}
}
}
}
...
@@ -353,7 +401,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -353,7 +401,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// restore features bins from buffer
// restore features bins from buffer
for
(
int
i
=
0
;
i
<
total_num_feature
;
++
i
)
{
for
(
int
i
=
0
;
i
<
total_num_feature
;
++
i
)
{
if
(
ignore_features_
.
count
(
i
)
>
0
)
{
if
(
ignore_features_
.
count
(
i
)
>
0
)
{
Log
::
Error
(
"Ignore Feature %
d
"
,
i
);
Log
::
Error
(
"Ignore Feature %
s
"
,
feature_names_
[
i
].
c_str
()
);
continue
;
continue
;
}
}
BinMapper
*
bin_mapper
=
new
BinMapper
();
BinMapper
*
bin_mapper
=
new
BinMapper
();
...
@@ -362,7 +410,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -362,7 +410,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
used_feature_map_
[
i
]
=
static_cast
<
int
>
(
features_
.
size
());
used_feature_map_
[
i
]
=
static_cast
<
int
>
(
features_
.
size
());
features_
.
push_back
(
new
Feature
(
static_cast
<
int
>
(
i
),
bin_mapper
,
num_data_
,
is_enable_sparse_
));
features_
.
push_back
(
new
Feature
(
static_cast
<
int
>
(
i
),
bin_mapper
,
num_data_
,
is_enable_sparse_
));
}
else
{
}
else
{
Log
::
Error
(
"Feature %
d
only contains one value, will be ignored"
,
i
);
Log
::
Error
(
"Feature %
s
only contains one value, will be ignored"
,
feature_names_
[
i
].
c_str
()
);
delete
bin_mapper
;
delete
bin_mapper
;
}
}
}
}
...
@@ -377,7 +425,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
...
@@ -377,7 +425,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
void
Dataset
::
LoadTrainData
(
int
rank
,
int
num_machines
,
bool
is_pre_partition
,
bool
use_two_round_loading
)
{
void
Dataset
::
LoadTrainData
(
int
rank
,
int
num_machines
,
bool
is_pre_partition
,
bool
use_two_round_loading
)
{
// don't support query id in data file when training parallel
// don't support query id in data file when training
in
parallel
if
(
num_machines
>
1
&&
!
is_pre_partition
)
{
if
(
num_machines
>
1
&&
!
is_pre_partition
)
{
if
(
group_idx_
>
0
)
{
if
(
group_idx_
>
0
)
{
Log
::
Fatal
(
"Don't support query id in data file when training parallel without pre-partition. \
Log
::
Fatal
(
"Don't support query id in data file when training parallel without pre-partition. \
...
...
src/io/parser.cpp
View file @
3a06ce35
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <iostream>
#include <iostream>
#include <fstream>
#include <fstream>
#include <functional>
namespace
LightGBM
{
namespace
LightGBM
{
...
@@ -20,37 +21,53 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
...
@@ -20,37 +21,53 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
}
}
}
}
bool
CheckHasLabelForLibsvm
(
std
::
string
&
str
)
{
int
GetLabelIdxForLibsvm
(
std
::
string
&
str
,
int
num_features
,
int
label_idx
)
{
if
(
num_features
<=
0
)
{
return
label_idx
;
}
str
=
Common
::
Trim
(
str
);
str
=
Common
::
Trim
(
str
);
auto
pos_space
=
str
.
find_first_of
(
"
\f\n\r\t\v
"
);
auto
pos_space
=
str
.
find_first_of
(
"
\f\n\r\t\v
"
);
auto
pos_colon
=
str
.
find_first_of
(
":"
);
auto
pos_colon
=
str
.
find_first_of
(
":"
);
if
(
pos_colon
==
std
::
string
::
npos
||
pos_colon
>
pos_space
)
{
if
(
pos_colon
==
std
::
string
::
npos
||
pos_colon
>
pos_space
)
{
return
true
;
return
-
1
;
}
else
{
}
else
{
return
false
;
return
label_idx
;
}
}
}
}
bool
CheckHasLabelForTSV
(
std
::
string
&
str
,
int
num_features
)
{
int
GetLabelIdxForTSV
(
std
::
string
&
str
,
int
num_features
,
int
label_idx
)
{
if
(
num_features
<=
0
)
{
return
label_idx
;
}
str
=
Common
::
Trim
(
str
);
str
=
Common
::
Trim
(
str
);
auto
tokens
=
Common
::
Split
(
str
.
c_str
(),
'\t'
);
auto
tokens
=
Common
::
Split
(
str
.
c_str
(),
'\t'
);
if
(
static_cast
<
int
>
(
tokens
.
size
())
==
num_features
)
{
if
(
static_cast
<
int
>
(
tokens
.
size
())
==
num_features
)
{
return
false
;
return
-
1
;
}
else
{
}
else
{
return
true
;
return
label_idx
;
}
}
}
}
bool
CheckHasLabelForCSV
(
std
::
string
&
str
,
int
num_features
)
{
int
GetLabelIdxForCSV
(
std
::
string
&
str
,
int
num_features
,
int
label_idx
)
{
if
(
num_features
<=
0
)
{
return
label_idx
;
}
str
=
Common
::
Trim
(
str
);
str
=
Common
::
Trim
(
str
);
auto
tokens
=
Common
::
Split
(
str
.
c_str
(),
','
);
auto
tokens
=
Common
::
Split
(
str
.
c_str
(),
','
);
if
(
static_cast
<
int
>
(
tokens
.
size
())
==
num_features
)
{
if
(
static_cast
<
int
>
(
tokens
.
size
())
==
num_features
)
{
return
false
;
return
-
1
;
}
else
{
}
else
{
return
true
;
return
label_idx
;
}
}
}
}
enum
DataType
{
INVALID
,
CSV
,
TSV
,
LIBSVM
};
Parser
*
Parser
::
CreateParser
(
const
char
*
filename
,
bool
has_header
,
int
num_features
,
int
label_idx
)
{
Parser
*
Parser
::
CreateParser
(
const
char
*
filename
,
bool
has_header
,
int
num_features
,
int
label_idx
)
{
std
::
ifstream
tmp_file
;
std
::
ifstream
tmp_file
;
tmp_file
.
open
(
filename
);
tmp_file
.
open
(
filename
);
...
@@ -80,46 +97,46 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
...
@@ -80,46 +97,46 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
// Get some statistic from 2 line
// Get some statistic from 2 line
GetStatistic
(
line1
.
c_str
(),
&
comma_cnt
,
&
tab_cnt
,
&
colon_cnt
);
GetStatistic
(
line1
.
c_str
(),
&
comma_cnt
,
&
tab_cnt
,
&
colon_cnt
);
GetStatistic
(
line2
.
c_str
(),
&
comma_cnt2
,
&
tab_cnt2
,
&
colon_cnt2
);
GetStatistic
(
line2
.
c_str
(),
&
comma_cnt2
,
&
tab_cnt2
,
&
colon_cnt2
);
Parser
*
ret
=
nullptr
;
bool
has_label
=
true
;
DataType
type
=
DataType
::
INVALID
;
if
(
line2
.
size
()
==
0
)
{
if
(
line2
.
size
()
==
0
)
{
// if only have one line on file
// if only have one line on file
if
(
colon_cnt
>
0
)
{
if
(
colon_cnt
>
0
)
{
if
(
num_features
>
0
)
{
type
=
DataType
::
LIBSVM
;
has_label
=
CheckHasLabelForLibsvm
(
line1
);
}
ret
=
new
LibSVMParser
(
has_label
?
label_idx
:
-
1
);
}
else
if
(
tab_cnt
>
0
)
{
}
else
if
(
tab_cnt
>
0
)
{
if
(
num_features
>
0
)
{
type
=
DataType
::
TSV
;
has_label
=
CheckHasLabelForTSV
(
line1
,
num_features
);
}
ret
=
new
TSVParser
(
has_label
?
label_idx
:
-
1
);
}
else
if
(
comma_cnt
>
0
)
{
}
else
if
(
comma_cnt
>
0
)
{
if
(
num_features
>
0
)
{
type
=
DataType
::
CSV
;
has_label
=
CheckHasLabelForCSV
(
line1
,
num_features
);
}
}
ret
=
new
CSVParser
(
has_label
?
label_idx
:
-
1
);
}
}
else
{
}
else
{
if
(
colon_cnt
>
0
||
colon_cnt2
>
0
)
{
if
(
colon_cnt
>
0
||
colon_cnt2
>
0
)
{
if
(
num_features
>
0
)
{
type
=
DataType
::
LIBSVM
;
has_label
=
CheckHasLabelForLibsvm
(
line1
);
}
else
if
(
tab_cnt
==
tab_cnt2
&&
tab_cnt
>
0
)
{
}
type
=
DataType
::
TSV
;
ret
=
new
LibSVMParser
(
has_label
?
label_idx
:
-
1
);
}
else
if
(
tab_cnt
==
tab_cnt2
&&
tab_cnt
>
0
)
{
if
(
num_features
>
0
)
{
has_label
=
CheckHasLabelForTSV
(
line1
,
num_features
);
}
ret
=
new
TSVParser
(
has_label
?
label_idx
:
-
1
);
}
else
if
(
comma_cnt
==
comma_cnt2
&&
comma_cnt
>
0
)
{
}
else
if
(
comma_cnt
==
comma_cnt2
&&
comma_cnt
>
0
)
{
if
(
num_features
>
0
)
{
type
=
DataType
::
CSV
;
has_label
=
CheckHasLabelForCSV
(
line1
,
num_features
);
}
ret
=
new
CSVParser
(
has_label
?
label_idx
:
-
1
);
}
}
}
}
if
(
!
has_label
)
{
if
(
type
==
DataType
::
INVALID
)
{
Log
::
Fatal
(
"Unkown format of training data"
);
}
Parser
*
ret
=
nullptr
;
if
(
type
==
DataType
::
LIBSVM
)
{
label_idx
=
GetLabelIdxForLibsvm
(
line1
,
num_features
,
label_idx
);
ret
=
new
LibSVMParser
(
label_idx
);
}
else
if
(
type
==
DataType
::
TSV
)
{
label_idx
=
GetLabelIdxForTSV
(
line1
,
num_features
,
label_idx
);
ret
=
new
TSVParser
(
label_idx
);
}
else
if
(
type
==
DataType
::
CSV
)
{
label_idx
=
GetLabelIdxForCSV
(
line1
,
num_features
,
label_idx
);
ret
=
new
CSVParser
(
label_idx
);
}
if
(
label_idx
<
0
)
{
Log
::
Info
(
"Data file: %s doesn't contain label column"
,
filename
);
Log
::
Info
(
"Data file: %s doesn't contain label column"
,
filename
);
}
}
return
ret
;
return
ret
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment