Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
ee97ed3d
Commit
ee97ed3d
authored
Oct 21, 2016
by
Qiwei Ye
Browse files
using advanced logger for lightGBM
parent
2d0e8fc9
Changes
26
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
404 additions
and
121 deletions
+404
-121
include/LightGBM/utils/common.h
include/LightGBM/utils/common.h
+4
-4
include/LightGBM/utils/log.cpp
include/LightGBM/utils/log.cpp
+184
-0
include/LightGBM/utils/log.h
include/LightGBM/utils/log.h
+130
-31
include/LightGBM/utils/text_reader.h
include/LightGBM/utils/text_reader.h
+2
-2
src/application/application.cpp
src/application/application.cpp
+11
-11
src/application/predictor.hpp
src/application/predictor.hpp
+4
-4
src/boosting/gbdt.cpp
src/boosting/gbdt.cpp
+5
-5
src/io/config.cpp
src/io/config.cpp
+4
-4
src/io/dataset.cpp
src/io/dataset.cpp
+18
-18
src/io/metadata.cpp
src/io/metadata.cpp
+12
-12
src/io/parser.cpp
src/io/parser.cpp
+3
-3
src/io/parser.hpp
src/io/parser.hpp
+5
-5
src/io/sparse_bin.hpp
src/io/sparse_bin.hpp
+2
-2
src/io/tree.cpp
src/io/tree.cpp
+1
-1
src/metric/binary_metric.hpp
src/metric/binary_metric.hpp
+3
-3
src/metric/dcg_calculator.cpp
src/metric/dcg_calculator.cpp
+1
-1
src/metric/rank_metric.hpp
src/metric/rank_metric.hpp
+2
-2
src/metric/regression_metric.hpp
src/metric/regression_metric.hpp
+1
-1
src/network/linkers_socket.cpp
src/network/linkers_socket.cpp
+11
-11
src/network/network.cpp
src/network/network.cpp
+1
-1
No files found.
include/LightGBM/utils/common.h
View file @
ee97ed3d
...
...
@@ -157,7 +157,7 @@ inline static const char* Atof(const char* p, double* out) {
*
out
=
sign
*
1e308
;
}
else
{
Log
::
Stder
r
(
"Unknow token %s in data file"
,
tmp_str
.
c_str
());
Log
::
Erro
r
(
"Unknow token %s in data file"
,
tmp_str
.
c_str
());
}
p
+=
cnt
;
}
...
...
@@ -201,7 +201,7 @@ inline static std::string ArrayToString(const T* arr, int n, char delimiter) {
inline
static
void
StringToIntArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
,
int
*
out
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
if
(
strs
.
size
()
!=
n
)
{
Log
::
Stder
r
(
"StringToIntArray error, size don't equal."
);
Log
::
Erro
r
(
"StringToIntArray error, size don't equal."
);
}
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
strs
[
i
]
=
Trim
(
strs
[
i
]);
...
...
@@ -212,7 +212,7 @@ inline static void StringToIntArray(const std::string& str, char delimiter, size
inline
static
void
StringToDoubleArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
,
double
*
out
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
if
(
strs
.
size
()
!=
n
)
{
Log
::
Stder
r
(
"StringToDoubleArray error, size don't equal"
);
Log
::
Erro
r
(
"StringToDoubleArray error, size don't equal"
);
}
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
strs
[
i
]
=
Trim
(
strs
[
i
]);
...
...
@@ -223,7 +223,7 @@ inline static void StringToDoubleArray(const std::string& str, char delimiter, s
inline
static
void
StringToDoubleArray
(
const
std
::
string
&
str
,
char
delimiter
,
size_t
n
,
float
*
out
)
{
std
::
vector
<
std
::
string
>
strs
=
Split
(
str
.
c_str
(),
delimiter
);
if
(
strs
.
size
()
!=
n
)
{
Log
::
Stder
r
(
"StringToDoubleArray error, size don't equal"
);
Log
::
Erro
r
(
"StringToDoubleArray error, size don't equal"
);
}
double
tmp
;
for
(
size_t
i
=
0
;
i
<
strs
.
size
();
++
i
)
{
...
...
include/LightGBM/utils/log.cpp
0 → 100644
View file @
ee97ed3d
#include "LightGBM/utils/log.h"
#include <time.h>
#include <stdarg.h>
#include <string>
namespace
LightGBM
{
// Creates a Logger instance that writes messages to STDOUT only.
// No log file is attached, and a Fatal message terminates the process.
// \param level Minimal level a message must have to be written.
Logger::Logger(LogLevel level)
    : file_(nullptr), level_(level), is_kill_fatal_(true) {}
// Creates a Logger instance writing messages into both STDOUT and a log file.
// \param filename Name of the log file to create; an empty name means no file.
// \param level    Minimal level a message must have to be written.
Logger::Logger(std::string filename, LogLevel level)
    // is_kill_fatal_ was previously left uninitialized by this constructor.
    // It must be set before ResetLogFile() runs, because a failed open logs an
    // error and the write path reads the flag. The value matches the one used
    // by the Logger(LogLevel) constructor.
    : file_(nullptr), level_(level), is_kill_fatal_(true) {
  // file_ has to be null here: ResetLogFile() starts by closing the current
  // file, and closing an indeterminate pointer would be undefined behavior.
  ResetLogFile(filename);
}
// Closes the log file, if one is still open, when the logger goes away.
Logger::~Logger() { CloseLogFile(); }
int
Logger
::
ResetLogFile
(
std
::
string
filename
)
{
CloseLogFile
();
if
(
filename
.
size
()
>
0
)
{
// try to open the log file if it is specified
#ifdef _MSC_VER
fopen_s
(
&
file_
,
filename
.
c_str
(),
"w"
);
#else
file_
=
fopen
(
filename
.
c_str
(),
"w"
);
#endif
if
(
file_
==
nullptr
)
{
Error
(
"Cannot create log file %s
\n
"
,
filename
.
c_str
());
return
-
1
;
}
}
return
0
;
}
// Writes a printf-style message at the given level.
// \param level  Level of this message.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Logger::Write(LogLevel level, const char* format, ...) {
  va_list val;
  va_start(val, format);
  // Pass the ADDRESS of the argument list. The private overload takes a
  // va_list*; passing `val` itself does not convert to that type, so overload
  // resolution would select this variadic function again and recurse forever.
  Write(level, format, &val);
  va_end(val);
}
// Writes a printf-style message at Debug level.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Logger::Debug(const char* format, ...) {
  va_list val;
  va_start(val, format);
  // &val selects the private Write(LogLevel, const char*, va_list*) overload.
  // Passing `val` by value would bind to the variadic Write instead, which
  // would treat the va_list as an ordinary argument.
  Write(LogLevel::Debug, format, &val);
  va_end(val);
}
// Writes a printf-style message at Info level.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Logger::Info(const char* format, ...) {
  va_list val;
  va_start(val, format);
  // &val selects the private Write(LogLevel, const char*, va_list*) overload.
  // Passing `val` by value would bind to the variadic Write instead, which
  // would treat the va_list as an ordinary argument.
  Write(LogLevel::Info, format, &val);
  va_end(val);
}
void
Logger
::
Error
(
const
char
*
format
,
...)
{
va_list
val
;
va_start
(
val
,
format
);
Write
(
LogLevel
::
Error
,
format
,
val
);
va_end
(
val
);
}
// Writes a printf-style message at Fatal level.
// When kill-on-fatal is enabled, the private Write overload exits the process
// and this function does not return.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Logger::Fatal(const char* format, ...) {
  va_list val;
  va_start(val, format);
  // &val selects the private Write(LogLevel, const char*, va_list*) overload.
  // Passing `val` by value would bind to the variadic Write instead, which
  // would treat the va_list as an ordinary argument.
  Write(LogLevel::Fatal, format, &val);
  va_end(val);
}
// Core write routine shared by all public logging entry points.
// Emits "[LEVEL] [TIME] " followed by the formatted message to STDOUT and,
// when a log file is open, to that file as well. Both streams are flushed
// after each message.
// \param level  Level of this message; messages below level_ are dropped.
// \param format C format string.
// \param val    Pointer to the caller's started va_list; it is consumed here.
inline void Logger::Write(LogLevel level, const char* format, va_list* val) {
  // omit the message with low level
  if (level < level_) {
    return;
  }

  const std::string level_tag = GetLevelStr(level);
  const std::string time_tag = GetSystemTime();

  // Each v*printf call consumes its va_list, so the file output gets a copy.
  va_list file_args;
  va_copy(file_args, *val);

  // write to STDOUT
  printf("[%s] [%s] ", level_tag.c_str(), time_tag.c_str());
  vprintf(format, *val);
  fflush(stdout);

  // write to log file
  if (file_ != nullptr) {
    fprintf(file_, "[%s] [%s] ", level_tag.c_str(), time_tag.c_str());
    vfprintf(file_, format, file_args);
    fflush(file_);
  }
  va_end(file_args);

  // A fatal message ends the process when kill-on-fatal is enabled; the log
  // file is closed first.
  const bool terminate = is_kill_fatal_ && level == LogLevel::Fatal;
  if (terminate) {
    CloseLogFile();
    exit(1);
  }
}
// Closes the log file if one is open and resets the handle to null.
// Does nothing when no log file is open.
void Logger::CloseLogFile() {
  if (file_ == nullptr) {
    return;
  }
  fclose(file_);
  file_ = nullptr;
}
// Returns the current local time formatted as "YYYY-MM-DD HH:MM:SS".
std::string Logger::GetSystemTime() {
  time_t now = time(0);
  char buffer[64];
#ifdef _MSC_VER
  // MSVC: localtime_s fills a caller-provided tm structure.
  tm local_tm;
  localtime_s(&local_tm, &now);
  strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", &local_tm);
#else
  strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", localtime(&now));
#endif
  return std::string(buffer);
}
// Maps a log level to the tag printed in front of each message.
// \param level Level to translate.
// \return "DEBUG", "INFO", "ERROR" or "FATAL"; "UNKNOW" for any other value.
std::string Logger::GetLevelStr(LogLevel level) {
  const char* tag = "UNKNOW";
  switch (level) {
    case LogLevel::Debug:
      tag = "DEBUG";
      break;
    case LogLevel::Info:
      tag = "INFO";
      break;
    case LogLevel::Error:
      tag = "ERROR";
      break;
    case LogLevel::Fatal:
      tag = "FATAL";
      break;
    default:
      break;
  }
  return tag;
}
//-- End of Logger routine ----------------------------------------------/
// Global (per-process) Logger instance that every static Log method forwards
// to. It is default-constructed here.
Logger
Log
::
logger_
;
// Redirects the global logger to a new log file.
// \param filename Name of the log file; forwarded unchanged to the logger.
// \return The result of Logger::ResetLogFile on the global logger.
int Log::ResetLogFile(std::string filename) {
  const int status = logger_.ResetLogFile(filename);
  return status;
}
// Sets the minimal level of the global logger.
// \param level The new minimal log level.
void Log::ResetLogLevel(LogLevel level) { logger_.ResetLogLevel(level); }
// Sets whether the global logger exits the process on a Fatal message.
// \param is_kill_fatal true to exit on Fatal, false to keep running.
void Log::ResetKillFatal(bool is_kill_fatal) {
  logger_.ResetKillFatal(is_kill_fatal);
}
// Writes a printf-style message at the given level through the global logger.
// \param level  Level of this message.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Log::Write(LogLevel level, const char* format, ...) {
  va_list args;
  va_start(args, format);
  logger_.Write(level, format, &args);
  va_end(args);
}
// Writes a printf-style message at Debug level through the global logger.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Log::Debug(const char* format, ...) {
  va_list args;
  va_start(args, format);
  logger_.Write(LogLevel::Debug, format, &args);
  va_end(args);
}
// Writes a printf-style message at Info level through the global logger.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Log::Info(const char* format, ...) {
  va_list args;
  va_start(args, format);
  logger_.Write(LogLevel::Info, format, &args);
  va_end(args);
}
void
Log
::
Error
(
const
char
*
format
,
...)
{
va_list
val
;
va_start
(
val
,
format
);
logger_
.
Write
(
LogLevel
::
Error
,
format
,
&
val
);
va_end
(
val
);
}
// Writes a printf-style message at Fatal level through the global logger.
// When kill-on-fatal is enabled on the logger, the write exits the process and
// this function does not return.
// \param format C format string.
// \param ...    Values referenced by the format string.
void Log::Fatal(const char* format, ...) {
  va_list args;
  va_start(args, format);
  logger_.Write(LogLevel::Fatal, format, &args);
  va_end(args);
}
}
// namespace LightGBM
include/LightGBM/utils/log.h
View file @
ee97ed3d
...
...
@@ -5,45 +5,144 @@
#include <cstdlib>
#include <cstdarg>
#include <cstring>
#include <fstream>
namespace
LightGBM
{
class
Log
{
public:
inline
static
void
Stderr
(
const
char
*
format
,
...)
{
va_list
argptr
;
char
fixed
[
512
];
#ifdef _MSC_VER
sprintf_s
(
fixed
,
"[LightGBM Error] %s
\n
"
,
format
);
#else
sprintf
(
fixed
,
"[LightGBM Error] %s
\n
"
,
format
);
#ifndef CHECK
#define CHECK(condition) \
if (!(condition)) Log::Fatal("Check failed: " #condition \
" at %s, line %d .\n", __FILE__, __LINE__);
#endif
va_start
(
argptr
,
format
);
vfprintf
(
stderr
,
fixed
,
argptr
);
va_end
(
argptr
);
fflush
(
stderr
);
std
::
exit
(
1
);
}
inline
static
void
Stdout
(
const
char
*
format
,
...)
{
va_list
argptr
;
char
fixed
[
512
];
#ifdef _MSC_VER
sprintf_s
(
fixed
,
"[LightGBM] %s
\n
"
,
format
);
#else
sprintf
(
fixed
,
"[LightGBM] %s
\n
"
,
format
);
#ifndef CHECK_NOTNULL
#define CHECK_NOTNULL(pointer) \
if ((pointer) == nullptr) LightGBM::Log::Fatal(#pointer " Can't be NULL\n");
#endif
va_start
(
argptr
,
format
);
vfprintf
(
stdout
,
fixed
,
argptr
);
va_end
(
argptr
);
fflush
(
stdout
);
}
// Severity levels of log messages, ordered from least to most severe:
// Debug < Info < Error < Fatal.
enum class LogLevel : int {
  Debug = 0,  // 0
  Info,       // 1
  Error,      // 2
  Fatal       // 3
};
#define CHECK(condition) \
if (!(condition)) Log::Stderr("Check failed: " #condition \
" at %s, line %d .\n", __FILE__, __LINE__);
/*!
* \brief The Logger class is responsible for writing log messages into
*        standard output and, optionally, a log file.
*/
class Logger {
  // Enable the static Log class to call the private method.
  friend class Log;

 public:
  /*!
  * \brief Creates an instance of Logger class. The log messages will be
  *        written to standard output with a minimal level of INFO by
  *        default. Users are able to further set the log file or log level
  *        with the corresponding methods.
  * \param level Minimal log level, Info by default.
  */
  explicit Logger(LogLevel level = LogLevel::Info);

  /*!
  * \brief Creates an instance of Logger class by specifying log file
  *        and log level. The log message will be written to both STDOUT
  *        and file (if created successfully).
  * \param filename Log file name
  * \param level Minimal log level, Info by default.
  */
  explicit Logger(std::string filename, LogLevel level = LogLevel::Info);

  /*! \brief Closes the log file if one is open. */
  ~Logger();

  /*!
  * \brief Resets the log file.
  * \param filename The new log filename. If it is empty, the Logger
  *        will close current log file (if it exists).
  * \return Returns -1 if the filename is not empty but failed on
  *         creating the log file, or 0 will be returned otherwise.
  */
  int ResetLogFile(std::string filename);

  /*!
  * \brief Resets the log level.
  * \param level The new minimal log level.
  */
  void ResetLogLevel(LogLevel level) { level_ = level; }

  /*!
  * \brief Resets the option of whether to kill the process when a fatal
  *        error occurs. The Logger(LogLevel) constructor sets this option
  *        to true.
  * \param is_kill_fatal true to exit the process on a Fatal message.
  */
  void ResetKillFatal(bool is_kill_fatal) { is_kill_fatal_ = is_kill_fatal; }

  /*!
  * \brief C style formatted method for writing log messages. A message
  *        has the following format: [LEVEL] [TIME] message
  * \param level The log level of this message.
  * \param format The C format string.
  * \param ... Output items.
  */
  void Write(LogLevel level, const char* format, ...);
  /*! \brief Same as Write() with the level fixed to Debug. */
  void Debug(const char* format, ...);
  /*! \brief Same as Write() with the level fixed to Info. */
  void Info(const char* format, ...);
  /*! \brief Same as Write() with the level fixed to Error. */
  void Error(const char* format, ...);
  /*! \brief Same as Write() with the level fixed to Fatal. */
  void Fatal(const char* format, ...);

 private:
  // Shared implementation behind the public variadic methods; takes a pointer
  // to the caller's already-started va_list.
  void Write(LogLevel level, const char* format, va_list* val);
  // Closes the log file if it is open.
  void CloseLogFile();
  // Returns current system time as a string.
  std::string GetSystemTime();
  // Returns the string of a log level.
  std::string GetLevelStr(LogLevel level);

  std::FILE* file_;     // A file pointer to the log file.
  LogLevel level_;      // Only messages not less than level_ are written.
  bool is_kill_fatal_;  // Whether to kill the process on a fatal error.

  // No copying allowed: the instance owns the FILE handle. Deleting the
  // operations turns an accidental copy into a compile-time error instead of
  // a link-time error from an undefined private member.
  Logger(const Logger&) = delete;
  void operator=(const Logger&) = delete;
};
/*!
* \brief Static facade over one Logger instance shared by the whole
*        process, so messages can be logged without passing a Logger around.
*/
class Log {
 public:
  /*!
  * \brief Resets the log file. Messages are written to this file in
  *        addition to STDOUT.
  * \param filename The log filename. If it is empty, the logger closes
  *        the current log file (if it exists) and only outputs to STDOUT.
  * \return -1 if creating the log file fails, 0 otherwise.
  */
  static int ResetLogFile(std::string filename);

  /*!
  * \brief Resets the minimal log level. It is INFO by default.
  * \param level The new minimal log level.
  */
  static void ResetLogLevel(LogLevel level);

  /*!
  * \brief Resets the option of whether to kill the process when a fatal
  *        error occurs.
  * \param is_kill_fatal true to exit the process on a Fatal message.
  */
  static void ResetKillFatal(bool is_kill_fatal);

  /*! \brief The C formatted methods of writing the messages. */
  static void Write(LogLevel level, const char* format, ...);
  static void Debug(const char* format, ...);
  static void Info(const char* format, ...);
  static void Error(const char* format, ...);
  static void Fatal(const char* format, ...);

 private:
  // The shared Logger that all static methods above forward to.
  static Logger logger_;
};
}
// namespace LightGBM
#endif // LightGBM_UTILS_LOG_H_
include/LightGBM/utils/text_reader.h
View file @
ee97ed3d
...
...
@@ -87,7 +87,7 @@ public:
});
// if last line of file doesn't contain end of line
if
(
last_line_
.
size
()
>
0
)
{
Log
::
Stdout
(
"Warning: last line of file %s doesn't contain end of line, application will still use this line"
,
filename_
);
Log
::
Info
(
"Warning: last line of file %s doesn't contain end of line, application will still use this line"
,
filename_
);
process_fun
(
total_cnt
,
last_line_
.
c_str
(),
last_line_
.
size
());
++
total_cnt
;
last_line_
=
""
;
...
...
@@ -224,7 +224,7 @@ public:
});
// if last line of file doesn't contain end of line
if
(
last_line_
.
size
()
>
0
)
{
Log
::
Stdout
(
"Warning: last line of file %s doesn't contain end of line, application will still use this line"
,
filename_
);
Log
::
Info
(
"Warning: last line of file %s doesn't contain end of line, application will still use this line"
,
filename_
);
if
(
filter_fun
(
used_cnt
,
total_cnt
))
{
lines_
.
push_back
(
last_line_
);
process_fun
(
used_cnt
,
lines_
);
...
...
src/application/application.cpp
View file @
ee97ed3d
...
...
@@ -69,7 +69,7 @@ void Application::LoadParameters(int argc, char** argv) {
params
[
key
]
=
value
;
}
else
{
Log
::
Stdout
(
"Warning: unknown parameter in command line: %s"
,
argv
[
i
]);
Log
::
Info
(
"Warning: unknown parameter in command line: %s"
,
argv
[
i
]);
}
}
// check for alias
...
...
@@ -101,11 +101,11 @@ void Application::LoadParameters(int argc, char** argv) {
}
}
else
{
Log
::
Stdout
(
"Warning: unknown parameter in config file: %s"
,
line
.
c_str
());
Log
::
Info
(
"Warning: unknown parameter in config file: %s"
,
line
.
c_str
());
}
}
}
else
{
Log
::
Stdout
(
"config file: %s doesn't exist, will ignore"
,
Log
::
Info
(
"config file: %s doesn't exist, will ignore"
,
params
[
"config_file"
].
c_str
());
}
}
...
...
@@ -113,7 +113,7 @@ void Application::LoadParameters(int argc, char** argv) {
ParameterAlias
::
KeyAliasTransform
(
&
params
);
// load configs
config_
.
Set
(
params
);
Log
::
Stdout
(
"finished load parameters"
);
Log
::
Info
(
"finished load parameters"
);
}
void
Application
::
LoadData
()
{
...
...
@@ -201,7 +201,7 @@ void Application::LoadData() {
}
auto
end_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
// output used time on each iteration
Log
::
Stdout
(
"Finish loading data, use %f seconds "
,
Log
::
Info
(
"Finish loading data, use %f seconds "
,
std
::
chrono
::
duration
<
double
,
std
::
milli
>
(
end_time
-
start_time
)
*
1e-3
);
}
...
...
@@ -209,7 +209,7 @@ void Application::InitTrain() {
if
(
config_
.
is_parallel
)
{
// need init network
Network
::
Init
(
config_
.
network_config
);
Log
::
Stdout
(
"finish network initialization"
);
Log
::
Info
(
"finish network initialization"
);
// sync global random seed for feature patition
if
(
config_
.
boosting_type
==
BoostingType
::
kGBDT
)
{
GBDTConfig
*
gbdt_config
=
...
...
@@ -240,13 +240,13 @@ void Application::InitTrain() {
boosting_
->
AddDataset
(
valid_datas_
[
i
],
ConstPtrInVectorWarpper
<
Metric
>
(
valid_metrics_
[
i
]));
}
Log
::
Stdout
(
"finish training init"
);
Log
::
Info
(
"finish training init"
);
}
void
Application
::
Train
()
{
Log
::
Stdout
(
"start train"
);
Log
::
Info
(
"start train"
);
boosting_
->
Train
();
Log
::
Stdout
(
"finish train"
);
Log
::
Info
(
"finish train"
);
}
...
...
@@ -254,14 +254,14 @@ void Application::Predict() {
// create predictor
Predictor
predictor
(
boosting_
,
config_
.
io_config
.
is_sigmoid
);
predictor
.
Predict
(
config_
.
io_config
.
data_filename
.
c_str
(),
config_
.
io_config
.
output_result
.
c_str
());
Log
::
Stdout
(
"finish predict"
);
Log
::
Info
(
"finish predict"
);
}
void
Application
::
InitPredict
()
{
boosting_
=
Boosting
::
CreateBoosting
(
config_
.
boosting_type
,
config_
.
boosting_config
);
LoadModel
();
Log
::
Stdout
(
"finish predict init"
);
Log
::
Info
(
"finish predict init"
);
}
void
Application
::
LoadModel
()
{
...
...
src/application/predictor.hpp
View file @
ee97ed3d
...
...
@@ -106,13 +106,13 @@ public:
#endif
if
(
result_file
==
NULL
)
{
Log
::
Stder
r
(
"predition result file %s doesn't exists"
,
data_filename
);
Log
::
Erro
r
(
"predition result file %s doesn't exists"
,
data_filename
);
}
bool
has_label
=
false
;
Parser
*
parser
=
Parser
::
CreateParser
(
data_filename
,
num_features_
,
&
has_label
);
if
(
parser
==
nullptr
)
{
Log
::
Stder
r
(
"recongnizing input data format failed, filename %s"
,
data_filename
);
Log
::
Erro
r
(
"recongnizing input data format failed, filename %s"
,
data_filename
);
}
// function for parse data
...
...
@@ -124,14 +124,14 @@ public:
(
const
char
*
buffer
,
std
::
vector
<
std
::
pair
<
int
,
double
>>*
feature
)
{
parser
->
ParseOneLine
(
buffer
,
feature
,
&
tmp_label
);
};
Log
::
Stdout
(
"start prediction for data %s, and data has label"
,
data_filename
);
Log
::
Info
(
"start prediction for data %s, and data has label"
,
data_filename
);
}
else
{
// parse function without label
parser_fun
=
[
this
,
&
parser
]
(
const
char
*
buffer
,
std
::
vector
<
std
::
pair
<
int
,
double
>>*
feature
)
{
parser
->
ParseOneLine
(
buffer
,
feature
);
};
Log
::
Stdout
(
"start prediction for data %s, and data doesn't has label"
,
data_filename
);
Log
::
Info
(
"start prediction for data %s, and data doesn't has label"
,
data_filename
);
}
std
::
function
<
double
(
const
std
::
vector
<
std
::
pair
<
int
,
double
>>&
)
>
predict_fun
;
if
(
is_simgoid_
)
{
...
...
src/boosting/gbdt.cpp
View file @
ee97ed3d
...
...
@@ -150,7 +150,7 @@ void GBDT::Bagging(int iter) {
bag_data_cnt_
=
cur_left_cnt
;
out_of_bag_data_cnt_
=
num_data_
-
bag_data_cnt_
;
}
Log
::
Stdout
(
"re-bagging, using %d data to train"
,
bag_data_cnt_
);
Log
::
Info
(
"re-bagging, using %d data to train"
,
bag_data_cnt_
);
// set bagging data to tree learner
tree_learner_
->
SetBaggingData
(
bag_data_indices_
,
bag_data_cnt_
);
}
...
...
@@ -176,7 +176,7 @@ void GBDT::Train() {
Tree
*
new_tree
=
TrainOneTree
();
// if cannot learn a new tree, then stop
if
(
new_tree
->
num_leaves
()
<=
1
)
{
Log
::
Stdout
(
"Cannot do any boosting for tree cannot split"
);
Log
::
Info
(
"Cannot do any boosting for tree cannot split"
);
break
;
}
// shrinkage by learning rate
...
...
@@ -194,7 +194,7 @@ void GBDT::Train() {
fflush
(
output_model_file
);
auto
end_time
=
std
::
chrono
::
high_resolution_clock
::
now
();
// output used time per iteration
Log
::
Stdout
(
"%f seconds elapsed, finished %d iteration"
,
std
::
chrono
::
duration
<
double
,
Log
::
Info
(
"%f seconds elapsed, finished %d iteration"
,
std
::
chrono
::
duration
<
double
,
std
::
milli
>
(
end_time
-
start_time
)
*
1e-3
,
iter
+
1
);
}
// close file
...
...
@@ -284,7 +284,7 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
}
}
if
(
i
==
lines
.
size
())
{
Log
::
Stder
r
(
"The model doesn't contain max_feature_idx"
);
Log
::
Erro
r
(
"The model doesn't contain max_feature_idx"
);
return
;
}
// get sigmoid parameter
...
...
@@ -323,7 +323,7 @@ void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
}
}
Log
::
Stdout
(
"Loaded %d models
\n
"
,
models_
.
size
());
Log
::
Info
(
"Loaded %d models
\n
"
,
models_
.
size
());
}
double
GBDT
::
PredictRaw
(
const
double
*
value
)
const
{
...
...
src/io/config.cpp
View file @
ee97ed3d
...
...
@@ -43,7 +43,7 @@ void OverallConfig::GetBoostingType(const std::unordered_map<std::string, std::s
if
(
value
==
std
::
string
(
"gbdt"
)
||
value
==
std
::
string
(
"gbrt"
))
{
boosting_type
=
BoostingType
::
kGBDT
;
}
else
{
Log
::
Stder
r
(
"boosting type %s error"
,
value
.
c_str
());
Log
::
Erro
r
(
"boosting type %s error"
,
value
.
c_str
());
}
}
}
...
...
@@ -91,7 +91,7 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
||
value
==
std
::
string
(
"test"
))
{
task_type
=
TaskType
::
kPredict
;
}
else
{
Log
::
Stder
r
(
"task type error"
);
Log
::
Erro
r
(
"task type error"
);
}
}
}
...
...
@@ -128,7 +128,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt
(
params
,
"data_random_seed"
,
&
data_random_seed
);
if
(
!
GetString
(
params
,
"data"
,
&
data_filename
))
{
Log
::
Stder
r
(
"No training/prediction data, application quit"
);
Log
::
Erro
r
(
"No training/prediction data, application quit"
);
}
GetInt
(
params
,
"num_model_predict"
,
&
num_model_predict
);
GetBool
(
params
,
"is_pre_partition"
,
&
is_pre_partition
);
...
...
@@ -236,7 +236,7 @@ void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::s
tree_learner_type
=
TreeLearnerType
::
kDataParallelTreeLearner
;
}
else
{
Log
::
Stder
r
(
"tree learner type error"
);
Log
::
Erro
r
(
"tree learner type error"
);
}
}
}
...
...
src/io/dataset.cpp
View file @
ee97ed3d
...
...
@@ -21,7 +21,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
CheckCanLoadFromBin
();
if
(
is_loading_from_binfile_
&&
predict_fun
!=
nullptr
)
{
Log
::
Stdout
(
"cannot perform initial prediction for binary file, will use text file instead"
);
Log
::
Info
(
"cannot perform initial prediction for binary file, will use text file instead"
);
is_loading_from_binfile_
=
false
;
}
...
...
@@ -31,14 +31,14 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
// create text parser
parser_
=
Parser
::
CreateParser
(
data_filename_
,
0
,
nullptr
);
if
(
parser_
==
nullptr
)
{
Log
::
Stder
r
(
"cannot recognise input data format, filename: %s"
,
data_filename_
);
Log
::
Erro
r
(
"cannot recognise input data format, filename: %s"
,
data_filename_
);
}
// create text reader
text_reader_
=
new
TextReader
<
data_size_t
>
(
data_filename
);
}
else
{
// only need to load initilize score, other meta data will load from bin flie
metadata_
.
Init
(
init_score_filename
);
Log
::
Stdout
(
"will load data set from binary file"
);
Log
::
Info
(
"will load data set from binary file"
);
parser_
=
nullptr
;
text_reader_
=
nullptr
;
}
...
...
@@ -82,7 +82,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition
[
this
,
rank
,
num_machines
,
&
qid
,
&
query_boundaries
,
&
is_query_used
,
num_queries
]
(
data_size_t
line_idx
)
{
if
(
qid
>=
num_queries
)
{
Log
::
Stder
r
(
"current query is exceed the range of query file, please ensure your query file is correct"
);
Log
::
Erro
r
(
"current query is exceed the range of query file, please ensure your query file is correct"
);
}
if
(
line_idx
>=
query_boundaries
[
qid
+
1
])
{
// if is new query
...
...
@@ -139,7 +139,7 @@ void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partiti
[
this
,
rank
,
num_machines
,
&
qid
,
&
query_boundaries
,
&
is_query_used
,
num_queries
]
(
data_size_t
line_idx
)
{
if
(
qid
>=
num_queries
)
{
Log
::
Stder
r
(
"current query is exceed the range of query file, \
Log
::
Erro
r
(
"current query is exceed the range of query file, \
please ensure your query file is correct"
);
}
if
(
line_idx
>=
query_boundaries
[
qid
+
1
])
{
...
...
@@ -209,7 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_
,
is_enable_sparse_
));
}
else
{
// if feature is trival(only 1 bin), free spaces
Log
::
Stdout
(
"Warning: feature %d only contains one value, will ignore it"
,
i
);
Log
::
Info
(
"Warning: feature %d only contains one value, will ignore it"
,
i
);
delete
bin_mappers
[
i
];
}
}
...
...
@@ -486,10 +486,10 @@ void Dataset::SaveBinaryFile() {
file
=
fopen
(
bin_filename
.
c_str
(),
"wb"
);
#endif
if
(
file
==
NULL
)
{
Log
::
Stder
r
(
"cannot write binary data to %s "
,
bin_filename
.
c_str
());
Log
::
Erro
r
(
"cannot write binary data to %s "
,
bin_filename
.
c_str
());
}
Log
::
Stdout
(
"start save binary file for data %s"
,
data_filename_
);
Log
::
Info
(
"start save binary file for data %s"
,
data_filename_
);
// get size of header
size_t
size_of_header
=
sizeof
(
global_num_data_
)
+
sizeof
(
is_enable_sparse_
)
...
...
@@ -556,7 +556,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
#endif
if
(
file
==
NULL
)
{
Log
::
Stder
r
(
"cannot read binary data from %s"
,
bin_filename
.
c_str
());
Log
::
Erro
r
(
"cannot read binary data from %s"
,
bin_filename
.
c_str
());
}
// buffer to read binary file
...
...
@@ -567,7 +567,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
size_t
read_cnt
=
fread
(
buffer
,
sizeof
(
size_t
),
1
,
file
);
if
(
read_cnt
!=
1
)
{
Log
::
Stder
r
(
"binary file format error at header size"
);
Log
::
Erro
r
(
"binary file format error at header size"
);
}
size_t
size_of_head
=
*
(
reinterpret_cast
<
size_t
*>
(
buffer
));
...
...
@@ -582,7 +582,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt
=
fread
(
buffer
,
1
,
size_of_head
,
file
);
if
(
read_cnt
!=
size_of_head
)
{
Log
::
Stder
r
(
"binary file format error at header"
);
Log
::
Erro
r
(
"binary file format error at header"
);
}
// get header
const
char
*
mem_ptr
=
buffer
;
...
...
@@ -608,7 +608,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt
=
fread
(
buffer
,
sizeof
(
size_t
),
1
,
file
);
if
(
read_cnt
!=
1
)
{
Log
::
Stder
r
(
"binary file format error at size of meta data"
);
Log
::
Erro
r
(
"binary file format error at size of meta data"
);
}
size_t
size_of_metadata
=
*
(
reinterpret_cast
<
size_t
*>
(
buffer
));
...
...
@@ -623,7 +623,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt
=
fread
(
buffer
,
1
,
size_of_metadata
,
file
);
if
(
read_cnt
!=
size_of_metadata
)
{
Log
::
Stder
r
(
"binary file format error at meta data"
);
Log
::
Erro
r
(
"binary file format error at meta data"
);
}
// load meta data
metadata_
.
LoadFromMemory
(
buffer
);
...
...
@@ -647,7 +647,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
bool
is_query_used
=
false
;
for
(
data_size_t
i
=
0
;
i
<
num_data_
;
i
++
)
{
if
(
qid
>=
num_queries
)
{
Log
::
Stder
r
(
"current query is exceed the range of query file, please ensure your query file is correct"
);
Log
::
Erro
r
(
"current query is exceed the range of query file, please ensure your query file is correct"
);
}
if
(
i
>=
query_boundaries
[
qid
+
1
])
{
// if is new query
...
...
@@ -670,7 +670,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
// read feature size
read_cnt
=
fread
(
buffer
,
sizeof
(
size_t
),
1
,
file
);
if
(
read_cnt
!=
1
)
{
Log
::
Stder
r
(
"binary file format error at feature %d's size"
,
i
);
Log
::
Erro
r
(
"binary file format error at feature %d's size"
,
i
);
}
size_t
size_of_feature
=
*
(
reinterpret_cast
<
size_t
*>
(
buffer
));
// re-allocmate space if not enough
...
...
@@ -683,7 +683,7 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
read_cnt
=
fread
(
buffer
,
1
,
size_of_feature
,
file
);
if
(
read_cnt
!=
size_of_feature
)
{
Log
::
Stder
r
(
"binary file format error at feature %d loading , read count %d"
,
i
,
read_cnt
);
Log
::
Erro
r
(
"binary file format error at feature %d loading , read count %d"
,
i
,
read_cnt
);
}
features_
.
push_back
(
new
Feature
(
buffer
,
static_cast
<
data_size_t
>
(
global_num_data_
),
used_data_indices_
));
}
...
...
@@ -693,10 +693,10 @@ void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partit
void
Dataset
::
CheckDataset
()
{
if
(
num_data_
<=
0
)
{
Log
::
Stder
r
(
"data size of %s is zero"
,
data_filename_
);
Log
::
Erro
r
(
"data size of %s is zero"
,
data_filename_
);
}
if
(
features_
.
size
()
<=
0
)
{
Log
::
Stder
r
(
"not useful feature of data %s"
,
data_filename_
);
Log
::
Erro
r
(
"not useful feature of data %s"
,
data_filename_
);
}
}
...
...
src/io/metadata.cpp
View file @
ee97ed3d
...
...
@@ -61,7 +61,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
if
(
used_data_indices
.
size
()
==
0
)
{
// check weights
if
(
weights_
!=
nullptr
&&
num_weights_
!=
num_data_
)
{
Log
::
Stdout
(
"init weight size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init weight size doesn't equal with data file, will ignore"
);
delete
[]
weights_
;
num_weights_
=
0
;
weights_
=
nullptr
;
...
...
@@ -69,7 +69,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// check query boundries
if
(
query_boundaries_
!=
nullptr
&&
query_boundaries_
[
num_queries_
]
!=
num_data_
)
{
Log
::
Stdout
(
"init query size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init query size doesn't equal with data file, will ignore"
);
delete
[]
query_boundaries_
;
num_queries_
=
0
;
query_boundaries_
=
nullptr
;
...
...
@@ -78,21 +78,21 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file
if
(
init_score_
!=
nullptr
&&
num_init_score_
!=
num_data_
)
{
delete
[]
init_score_
;
Log
::
Stdout
(
"init score size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init score size doesn't equal with data file, will ignore"
);
num_init_score_
=
0
;
}
}
else
{
data_size_t
num_used_data
=
static_cast
<
data_size_t
>
(
used_data_indices
.
size
());
// check weights
if
(
weights_
!=
nullptr
&&
num_weights_
!=
num_all_data
)
{
Log
::
Stdout
(
"init weight size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init weight size doesn't equal with data file, will ignore"
);
delete
[]
weights_
;
num_weights_
=
0
;
weights_
=
nullptr
;
}
// check query boundries
if
(
query_boundaries_
!=
nullptr
&&
query_boundaries_
[
num_queries_
]
!=
num_all_data
)
{
Log
::
Stdout
(
"init query size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init query size doesn't equal with data file, will ignore"
);
delete
[]
query_boundaries_
;
num_queries_
=
0
;
query_boundaries_
=
nullptr
;
...
...
@@ -100,7 +100,7 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
// contain initial score file
if
(
init_score_
!=
nullptr
&&
num_init_score_
!=
num_all_data
)
{
Log
::
Stdout
(
"init score size doesn't equal with data file, will ignore"
);
Log
::
Info
(
"init score size doesn't equal with data file, will ignore"
);
delete
[]
init_score_
;
num_init_score_
=
0
;
}
...
...
@@ -131,10 +131,10 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
used_query
.
push_back
(
qid
);
data_idx
+=
len
;
}
else
{
Log
::
Stder
r
(
"data partition error, not according to query"
);
Log
::
Erro
r
(
"data partition error, not according to query"
);
}
}
else
{
Log
::
Stder
r
(
"data partition error, not according to query"
);
Log
::
Erro
r
(
"data partition error, not according to query"
);
}
}
data_size_t
*
old_query_boundaries
=
query_boundaries_
;
...
...
@@ -182,7 +182,7 @@ void Metadata::LoadWeights() {
if
(
reader
.
Lines
().
size
()
<=
0
)
{
return
;
}
Log
::
Stdout
(
"Start to load weights"
);
Log
::
Info
(
"Start to load weights"
);
num_weights_
=
static_cast
<
data_size_t
>
(
reader
.
Lines
().
size
());
weights_
=
new
float
[
num_weights_
];
for
(
data_size_t
i
=
0
;
i
<
num_weights_
;
++
i
)
{
...
...
@@ -198,7 +198,7 @@ void Metadata::LoadInitialScore() {
TextReader
<
size_t
>
reader
(
init_score_filename_
);
reader
.
ReadAllLines
();
Log
::
Stdout
(
"Start to load initial score"
);
Log
::
Info
(
"Start to load initial score"
);
num_init_score_
=
static_cast
<
data_size_t
>
(
reader
.
Lines
().
size
());
init_score_
=
new
score_t
[
num_init_score_
];
double
tmp
=
0.0
f
;
...
...
@@ -218,7 +218,7 @@ void Metadata::LoadQueryBoundaries() {
if
(
reader
.
Lines
().
size
()
<=
0
)
{
return
;
}
Log
::
Stdout
(
"Start to load query boundries"
);
Log
::
Info
(
"Start to load query boundries"
);
query_boundaries_
=
new
data_size_t
[
reader
.
Lines
().
size
()
+
1
];
num_queries_
=
static_cast
<
data_size_t
>
(
reader
.
Lines
().
size
());
query_boundaries_
[
0
]
=
0
;
...
...
@@ -233,7 +233,7 @@ void Metadata::LoadQueryWeights() {
if
(
weights_
==
nullptr
||
query_boundaries_
==
nullptr
)
{
return
;
}
Log
::
Stdout
(
"Start to load query weights"
);
Log
::
Info
(
"Start to load query weights"
);
query_weights_
=
new
float
[
num_queries_
];
for
(
data_size_t
i
=
0
;
i
<
num_queries_
;
++
i
)
{
query_weights_
[
i
]
=
0.0
f
;
...
...
src/io/parser.cpp
View file @
ee97ed3d
...
...
@@ -55,18 +55,18 @@ Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_l
std
::
ifstream
tmp_file
;
tmp_file
.
open
(
filename
);
if
(
!
tmp_file
.
is_open
())
{
Log
::
Stder
r
(
"Data file: %s doesn't exist"
,
filename
);
Log
::
Erro
r
(
"Data file: %s doesn't exist"
,
filename
);
}
std
::
string
line1
,
line2
;
if
(
!
tmp_file
.
eof
())
{
std
::
getline
(
tmp_file
,
line1
);
}
else
{
Log
::
Stder
r
(
"Data file: %s at least should have one line"
,
filename
);
Log
::
Erro
r
(
"Data file: %s at least should have one line"
,
filename
);
}
if
(
!
tmp_file
.
eof
())
{
std
::
getline
(
tmp_file
,
line2
);
}
else
{
Log
::
Stdout
(
"Data file: %s only have one line"
,
filename
);
Log
::
Info
(
"Data file: %s only have one line"
,
filename
);
}
tmp_file
.
close
();
int
comma_cnt
=
0
,
comma_cnt2
=
0
;
...
...
src/io/parser.hpp
View file @
ee97ed3d
...
...
@@ -27,7 +27,7 @@ public:
if
(
*
str
==
','
)
{
++
str
;
}
else
if
(
*
str
!=
'\0'
)
{
Log
::
Stder
r
(
"input format error, should be CSV"
);
Log
::
Erro
r
(
"input format error, should be CSV"
);
}
}
}
...
...
@@ -38,7 +38,7 @@ public:
if
(
*
str
==
','
)
{
++
str
;
}
else
if
(
*
str
!=
'\0'
)
{
Log
::
Stder
r
(
"input format error, should be CSV"
);
Log
::
Erro
r
(
"input format error, should be CSV"
);
}
return
ParseOneLine
(
str
,
out_features
);
}
...
...
@@ -58,7 +58,7 @@ public:
if
(
*
str
==
'\t'
)
{
++
str
;
}
else
if
(
*
str
!=
'\0'
)
{
Log
::
Stder
r
(
"input format error, should be TSV"
);
Log
::
Erro
r
(
"input format error, should be TSV"
);
}
}
}
...
...
@@ -69,7 +69,7 @@ public:
if
(
*
str
==
'\t'
)
{
++
str
;
}
else
if
(
*
str
!=
'\0'
)
{
Log
::
Stder
r
(
"input format error, should be TSV"
);
Log
::
Erro
r
(
"input format error, should be TSV"
);
}
return
ParseOneLine
(
str
,
out_features
);
}
...
...
@@ -88,7 +88,7 @@ public:
str
=
Common
::
Atof
(
str
,
&
val
);
out_features
->
emplace_back
(
idx
,
val
);
}
else
{
Log
::
Stder
r
(
"input format error, should be LibSVM"
);
Log
::
Erro
r
(
"input format error, should be LibSVM"
);
}
str
=
Common
::
SkipSpaceAndTab
(
str
);
}
...
...
src/io/sparse_bin.hpp
View file @
ee97ed3d
...
...
@@ -28,7 +28,7 @@ public:
:
num_data_
(
num_data
)
{
default_bin_
=
static_cast
<
VAL_T
>
(
default_bin
);
if
(
default_bin_
!=
0
)
{
Log
::
Stdout
(
"Warning: Having sparse feature with negative values. Will let negative values equal zero as well"
);
Log
::
Info
(
"Warning: Having sparse feature with negative values. Will let negative values equal zero as well
\n
"
);
}
#pragma omp parallel
#pragma omp master
...
...
@@ -54,7 +54,7 @@ public:
void
ConstructHistogram
(
data_size_t
*
,
data_size_t
,
const
score_t
*
,
const
score_t
*
,
HistogramBinEntry
*
)
const
override
{
// Will use OrderedSparseBin->ConstructHistogram() instead
Log
::
Stderr
(
"Should use OrderedSparseBin->ConstructHistogram() instead"
);
Log
::
Info
(
"Should use OrderedSparseBin->ConstructHistogram() instead"
);
}
data_size_t
Split
(
unsigned
int
threshold
,
data_size_t
*
data_indices
,
data_size_t
num_data
,
...
...
src/io/tree.cpp
View file @
ee97ed3d
...
...
@@ -140,7 +140,7 @@ Tree::Tree(const std::string& str) {
||
key_vals
.
count
(
"split_gain"
)
<=
0
||
key_vals
.
count
(
"threshold"
)
<=
0
||
key_vals
.
count
(
"left_child"
)
<=
0
||
key_vals
.
count
(
"right_child"
)
<=
0
||
key_vals
.
count
(
"leaf_parent"
)
<=
0
||
key_vals
.
count
(
"leaf_value"
)
<=
0
)
{
Log
::
Stder
r
(
"tree model string format error"
);
Log
::
Erro
r
(
"tree model string format error"
);
}
Common
::
Atoi
(
key_vals
[
"num_leaves"
].
c_str
(),
&
num_leaves_
);
...
...
src/metric/binary_metric.hpp
View file @
ee97ed3d
...
...
@@ -23,7 +23,7 @@ public:
the_bigger_the_better
=
false
;
sigmoid_
=
static_cast
<
score_t
>
(
config
.
sigmoid
);
if
(
sigmoid_
<=
0.0
f
)
{
Log
::
Stder
r
(
"sigmoid param %f should greater than zero"
,
sigmoid_
);
Log
::
Erro
r
(
"sigmoid param %f should greater than zero"
,
sigmoid_
);
}
}
...
...
@@ -72,7 +72,7 @@ public:
}
loss
=
sum_loss
/
sum_weights_
;
if
(
output_freq_
>
0
&&
iter
%
output_freq_
==
0
){
Log
::
Stdout
(
"Iteration:%d, %s's %s: %f"
,
iter
,
name
,
PointWiseLossCalculator
::
Name
(),
loss
);
Log
::
Info
(
"Iteration:%d, %s's %s: %f"
,
iter
,
name
,
PointWiseLossCalculator
::
Name
(),
loss
);
}
}
}
...
...
@@ -229,7 +229,7 @@ public:
}
loss
=
auc
;
if
(
output_freq_
>
0
&&
iter
%
output_freq_
==
0
){
Log
::
Stdout
(
"iteration:%d, %s's %s: %f"
,
iter
,
name
,
"auc"
,
loss
);
Log
::
Info
(
"iteration:%d, %s's %s: %f"
,
iter
,
name
,
"auc"
,
loss
);
}
}
}
...
...
src/metric/dcg_calculator.cpp
View file @
ee97ed3d
...
...
@@ -57,7 +57,7 @@ void DCGCalculator::CalMaxDCG(const std::vector<data_size_t>& ks,
std
::
vector
<
data_size_t
>
label_cnt
(
label_gain_
.
size
(),
0
);
// counts for all labels
for
(
data_size_t
i
=
0
;
i
<
num_data
;
++
i
)
{
if
(
static_cast
<
size_t
>
(
label
[
i
])
>=
label_cnt
.
size
())
{
Log
::
Stder
r
(
"label excel %d
\n
"
,
label
[
i
]);
}
if
(
static_cast
<
size_t
>
(
label
[
i
])
>=
label_cnt
.
size
())
{
Log
::
Erro
r
(
"label excel %d
\n
"
,
label
[
i
]);
}
++
label_cnt
[
static_cast
<
int
>
(
label
[
i
])];
}
double
cur_result
=
0.0
;
...
...
src/metric/rank_metric.hpp
View file @
ee97ed3d
...
...
@@ -43,7 +43,7 @@ public:
// get query boundaries
query_boundaries_
=
metadata
.
query_boundaries
();
if
(
query_boundaries_
==
nullptr
)
{
Log
::
Stder
r
(
"For NDCG metric, should have query information"
);
Log
::
Erro
r
(
"For NDCG metric, should have query information"
);
}
num_queries_
=
metadata
.
num_queries
();
// get query weights
...
...
@@ -136,7 +136,7 @@ public:
}
loss
=
result
[
0
];
if
(
output_freq_
>
0
&&
iter
%
output_freq_
==
0
){
Log
::
Stdout
(
"Iteration:%d, Test:%s, %s "
,
iter
,
name
,
result_ss
.
str
().
c_str
());
Log
::
Info
(
"Iteration:%d, Test:%s, %s "
,
iter
,
name
,
result_ss
.
str
().
c_str
());
}
}
}
...
...
src/metric/regression_metric.hpp
View file @
ee97ed3d
...
...
@@ -60,7 +60,7 @@ public:
}
loss
=
PointWiseLossCalculator
::
AverageLoss
(
sum_loss
,
sum_weights_
);
if
(
output_freq_
>
0
&&
iter
%
output_freq_
==
0
){
Log
::
Stdout
(
"Iteration:%d, %s's %s : %f"
,
iter
,
name
,
PointWiseLossCalculator
::
Name
(),
loss
);
Log
::
Info
(
"Iteration:%d, %s's %s : %f"
,
iter
,
name
,
PointWiseLossCalculator
::
Name
(),
loss
);
}
}
}
...
...
src/network/linkers_socket.cpp
View file @
ee97ed3d
...
...
@@ -44,7 +44,7 @@ Linkers::Linkers(NetworkConfig config) {
}
}
if
(
rank_
==
-
1
)
{
Log
::
Stder
r
(
"machine list file doesn't contain local machine, app quit"
);
Log
::
Erro
r
(
"machine list file doesn't contain local machine, app quit"
);
}
// construct listener
listener_
=
new
TcpSocket
();
...
...
@@ -73,14 +73,14 @@ Linkers::~Linkers() {
}
}
TcpSocket
::
Finalize
();
Log
::
Stdout
(
"network used %f seconds"
,
network_time_
*
1e-3
);
Log
::
Info
(
"network used %f seconds"
,
network_time_
*
1e-3
);
}
void
Linkers
::
ParseMachineList
(
const
char
*
filename
)
{
TextReader
<
size_t
>
machine_list_reader
(
filename
);
machine_list_reader
.
ReadAllLines
();
if
(
machine_list_reader
.
Lines
().
size
()
<=
0
)
{
Log
::
Stder
r
(
"machine list file:%s doesn't exist"
,
filename
);
Log
::
Erro
r
(
"machine list file:%s doesn't exist"
,
filename
);
}
for
(
auto
&
line
:
machine_list_reader
.
Lines
())
{
...
...
@@ -95,7 +95,7 @@ void Linkers::ParseMachineList(const char * filename) {
continue
;
}
if
(
client_ips_
.
size
()
>=
static_cast
<
size_t
>
(
num_machines_
))
{
Log
::
Stdout
(
"The #machine in machine list is larger than parameter num_machines, will ignore rest"
);
Log
::
Info
(
"The #machine in machine list is larger than parameter num_machines, will ignore rest"
);
break
;
}
str_after_split
[
0
]
=
Common
::
Trim
(
str_after_split
[
0
]);
...
...
@@ -104,17 +104,17 @@ void Linkers::ParseMachineList(const char * filename) {
client_ports_
.
push_back
(
atoi
(
str_after_split
[
1
].
c_str
()));
}
if
(
client_ips_
.
size
()
!=
static_cast
<
size_t
>
(
num_machines_
))
{
Log
::
Stdout
(
"The world size is bigger the #machine in machine list, change world size to %d ."
,
client_ips_
.
size
());
Log
::
Info
(
"The world size is bigger the #machine in machine list, change world size to %d ."
,
client_ips_
.
size
());
num_machines_
=
static_cast
<
int
>
(
client_ips_
.
size
());
}
}
void
Linkers
::
TryBind
(
int
port
)
{
Log
::
Stdout
(
"try to bind port %d."
,
port
);
Log
::
Info
(
"try to bind port %d."
,
port
);
if
(
listener_
->
Bind
(
port
))
{
Log
::
Stdout
(
"bind port %d success."
,
port
);
Log
::
Info
(
"bind port %d success."
,
port
);
}
else
{
Log
::
Stder
r
(
"bind port %d failed."
,
port
);
Log
::
Erro
r
(
"bind port %d failed."
,
port
);
}
}
...
...
@@ -125,7 +125,7 @@ void Linkers::SetLinker(int rank, const TcpSocket& socket) {
}
void
Linkers
::
ListenThread
(
int
incoming_cnt
)
{
Log
::
Stdout
(
"Listening..."
);
Log
::
Info
(
"Listening..."
);
char
buffer
[
100
];
int
connected_cnt
=
0
;
while
(
connected_cnt
<
incoming_cnt
)
{
...
...
@@ -192,7 +192,7 @@ void Linkers::Construct() {
if
(
cur_socket
.
Connect
(
client_ips_
[
out_rank
].
c_str
(),
client_ports_
[
out_rank
]))
{
break
;
}
else
{
Log
::
Stdout
(
"connect to rank %d failed, wait for %d milliseconds"
,
out_rank
,
connect_fail_delay_time
);
Log
::
Info
(
"connect to rank %d failed, wait for %d milliseconds"
,
out_rank
,
connect_fail_delay_time
);
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
connect_fail_delay_time
));
}
}
...
...
@@ -217,7 +217,7 @@ bool Linkers::CheckLinker(int rank) {
void
Linkers
::
PrintLinkers
()
{
for
(
int
i
=
0
;
i
<
num_machines_
;
++
i
)
{
if
(
CheckLinker
(
i
))
{
Log
::
Stdout
(
"Connected to rank %d."
,
i
);
Log
::
Info
(
"Connected to rank %d."
,
i
);
}
}
}
...
...
src/network/network.cpp
View file @
ee97ed3d
...
...
@@ -30,7 +30,7 @@ void Network::Init(NetworkConfig config) {
block_len_
=
new
int
[
num_machines_
];
buffer_size_
=
1024
*
1024
;
buffer_
=
new
char
[
buffer_size_
];
Log
::
Stdout
(
"local rank %d, total number of machines %d"
,
rank_
,
num_machines_
);
Log
::
Info
(
"local rank %d, total number of machines %d"
,
rank_
,
num_machines_
);
}
void
Network
::
Dispose
()
{
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment