Unverified commit dee72159, authored by Guolin Ke, committed by GitHub
Browse files

check the shape for mat, csr and csc in prediction (#2464)

* check the shape for mat, csr and csc

* guess from csr

* support file checking

* better error msg

* grammar

* clean code

* code clean

* check range for CSR

* Update test_.py

* Update test_.py

* added tests
parent dc65e0ac
...@@ -683,7 +683,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle, ...@@ -683,7 +683,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param nindptr Number of rows in the matrix + 1 * \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix * \param nelem Number of nonzero elements in the matrix
* \param num_col Number of columns; when it's set to 0, then guess from data * \param num_col Number of columns
* \param predict_type What should be predicted * \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed); * - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
...@@ -726,7 +726,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -726,7 +726,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64`` * \param data_type Type of ``data`` pointer, can be ``C_API_DTYPE_FLOAT32`` or ``C_API_DTYPE_FLOAT64``
* \param nindptr Number of rows in the matrix + 1 * \param nindptr Number of rows in the matrix + 1
* \param nelem Number of nonzero elements in the matrix * \param nelem Number of nonzero elements in the matrix
* \param num_col Number of columns; when it's set to 0, then guess from data * \param num_col Number of columns
* \param predict_type What should be predicted * \param predict_type What should be predicted
* - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed); * - ``C_API_PREDICT_NORMAL``: normal prediction, with transform (if needed);
* - ``C_API_PREDICT_RAW_SCORE``: raw score; * - ``C_API_PREDICT_RAW_SCORE``: raw score;
......
...@@ -265,7 +265,7 @@ class Parser { ...@@ -265,7 +265,7 @@ class Parser {
virtual void ParseOneLine(const char* str, virtual void ParseOneLine(const char* str,
std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0; std::vector<std::pair<int, double>>* out_features, double* out_label) const = 0;
virtual int TotalColumns() const = 0; virtual int NumFeatures() const = 0;
/*! /*!
* \brief Create a parser object; the format is chosen automatically depending on the file * \brief Create a parser object; the format is chosen automatically depending on the file
...@@ -290,6 +290,7 @@ class Dataset { ...@@ -290,6 +290,7 @@ class Dataset {
void Construct( void Construct(
std::vector<std::unique_ptr<BinMapper>>* bin_mappers, std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
int num_total_features,
const std::vector<std::vector<double>>& forced_bins, const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices, int** sample_non_zero_indices,
const int* num_per_col, const int* num_per_col,
......
...@@ -140,7 +140,9 @@ class Predictor { ...@@ -140,7 +140,9 @@ class Predictor {
if (parser == nullptr) { if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename); Log::Fatal("Could not recognize the data format of data file %s", data_filename);
} }
if (parser->NumFeatures() != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", parser->NumFeatures(), boosting_->MaxFeatureIdx() + 1);
}
TextReader<data_size_t> predict_data_reader(data_filename, header); TextReader<data_size_t> predict_data_reader(data_filename, header);
std::unordered_map<int, int> feature_names_map_; std::unordered_map<int, int> feature_names_map_;
bool need_adjust = false; bool need_adjust = false;
......
...@@ -249,17 +249,19 @@ class Booster { ...@@ -249,17 +249,19 @@ class Booster {
boosting_->RollbackOneIter(); boosting_->RollbackOneIter();
} }
void PredictSingleRow(int num_iteration, int predict_type, void PredictSingleRow(int num_iteration, int predict_type, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (single_row_predictor_[predict_type].get() == nullptr || if (single_row_predictor_[predict_type].get() == nullptr ||
!single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) { !single_row_predictor_[predict_type]->IsPredictorEqual(config, num_iteration, boosting_.get())) {
single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(), single_row_predictor_[predict_type].reset(new SingleRowPredictor(predict_type, boosting_.get(),
config, num_iteration)); config, num_iteration));
} }
auto one_row = get_row_fun(0); auto one_row = get_row_fun(0);
auto pred_wrt_ptr = out_result; auto pred_wrt_ptr = out_result;
single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr); single_row_predictor_[predict_type]->predict_function(one_row, pred_wrt_ptr);
...@@ -268,10 +270,13 @@ class Booster { ...@@ -268,10 +270,13 @@ class Booster {
} }
void Predict(int num_iteration, int predict_type, int nrow, void Predict(int num_iteration, int predict_type, int nrow, int ncol,
std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun, std::function<std::vector<std::pair<int, double>>(int row_idx)> get_row_fun,
const Config& config, const Config& config,
double* out_result, int64_t* out_len) { double* out_result, int64_t* out_len) {
if (ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).", ncol, boosting_->MaxFeatureIdx() + 1);
}
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false; bool is_predict_leaf = false;
bool is_raw_score = false; bool is_raw_score = false;
...@@ -647,7 +652,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, ...@@ -647,7 +652,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat,
DatasetLoader loader(config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(), Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()), ncol,
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, total_nrow)); sample_cnt, total_nrow));
} else { } else {
...@@ -687,6 +692,11 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -687,6 +692,11 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameters); auto param = Config::Str2Map(parameters);
Config config; Config config;
config.Set(param); config.Set(param);
...@@ -718,7 +728,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -718,7 +728,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr,
DatasetLoader loader(config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(), Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(num_col),
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt, nrow));
} else { } else {
...@@ -748,9 +758,12 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ...@@ -748,9 +758,12 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
const DatasetHandle reference, const DatasetHandle reference,
DatasetHandle* out) { DatasetHandle* out) {
API_BEGIN(); API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr); auto get_row_fun = *static_cast<std::function<void(int idx, std::vector<std::pair<int, double>>&)>*>(get_row_funptr);
auto param = Config::Str2Map(parameters); auto param = Config::Str2Map(parameters);
Config config; Config config;
config.Set(param); config.Set(param);
...@@ -783,7 +796,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, ...@@ -783,7 +796,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr,
DatasetLoader loader(config, nullptr, 1, nullptr); DatasetLoader loader(config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(), ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(&sample_values).data(),
Common::Vector2Ptr<int>(&sample_idx).data(), Common::Vector2Ptr<int>(&sample_idx).data(),
static_cast<int>(sample_values.size()), static_cast<int>(num_col),
Common::VectorSize<double>(sample_values).data(), Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow)); sample_cnt, nrow));
} else { } else {
...@@ -1299,13 +1312,18 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1299,13 +1312,18 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameter); auto param = Config::Str2Map(parameter);
Config config; Config config;
config.Set(param); config.Set(param);
...@@ -1315,7 +1333,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle, ...@@ -1315,7 +1333,7 @@ int LGBM_BoosterPredictForCSR(BoosterHandle handle,
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
int nrow = static_cast<int>(nindptr - 1); int nrow = static_cast<int>(nindptr - 1);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, static_cast<int>(num_col), get_row_fun,
config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1328,13 +1346,18 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1328,13 +1346,18 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
int data_type, int data_type,
int64_t nindptr, int64_t nindptr,
int64_t nelem, int64_t nelem,
int64_t, int64_t num_col,
int predict_type, int predict_type,
int num_iteration, int num_iteration,
const char* parameter, const char* parameter,
int64_t* out_len, int64_t* out_len,
double* out_result) { double* out_result) {
API_BEGIN(); API_BEGIN();
if (num_col <= 0) {
Log::Fatal("The number of columns should be greater than zero.");
} else if (num_col >= INT32_MAX) {
Log::Fatal("The number of columns should be smaller than INT32_MAX.");
}
auto param = Config::Str2Map(parameter); auto param = Config::Str2Map(parameter);
Config config; Config config;
config.Set(param); config.Set(param);
...@@ -1343,7 +1366,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle, ...@@ -1343,7 +1366,7 @@ int LGBM_BoosterPredictForCSRSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem); auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, nelem);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len); ref_booster->PredictSingleRow(num_iteration, predict_type, static_cast<int32_t>(num_col), get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1395,7 +1418,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, ...@@ -1395,7 +1418,7 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
} }
return one_row; return one_row;
}; };
ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), get_row_fun, config, ref_booster->Predict(num_iteration, predict_type, static_cast<int>(num_row), ncol, get_row_fun, config,
out_result, out_len); out_result, out_len);
API_END(); API_END();
} }
...@@ -1420,7 +1443,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, ...@@ -1420,7 +1443,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, nrow, ncol, data_type, is_row_major);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun,
config, out_result, out_len); config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1444,7 +1467,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, ...@@ -1444,7 +1467,7 @@ int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major); auto get_row_fun = RowPairFunctionFromDenseMatric(data, 1, ncol, data_type, is_row_major);
ref_booster->PredictSingleRow(num_iteration, predict_type, get_row_fun, config, out_result, out_len); ref_booster->PredictSingleRow(num_iteration, predict_type, ncol, get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
...@@ -1468,7 +1491,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle, ...@@ -1468,7 +1491,7 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
} }
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type); auto get_row_fun = RowPairFunctionFromDenseRows(data, ncol, data_type);
ref_booster->Predict(num_iteration, predict_type, nrow, get_row_fun, config, out_result, out_len); ref_booster->Predict(num_iteration, predict_type, nrow, ncol, get_row_fun, config, out_result, out_len);
API_END(); API_END();
} }
......
...@@ -215,12 +215,14 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_ ...@@ -215,12 +215,14 @@ std::vector<std::vector<int>> FastFeatureBundling(const std::vector<std::unique_
void Dataset::Construct( void Dataset::Construct(
std::vector<std::unique_ptr<BinMapper>>* bin_mappers, std::vector<std::unique_ptr<BinMapper>>* bin_mappers,
int num_total_features,
const std::vector<std::vector<double>>& forced_bins, const std::vector<std::vector<double>>& forced_bins,
int** sample_non_zero_indices, int** sample_non_zero_indices,
const int* num_per_col, const int* num_per_col,
size_t total_sample_cnt, size_t total_sample_cnt,
const Config& io_config) { const Config& io_config) {
num_total_features_ = static_cast<int>(bin_mappers->size()); num_total_features_ = num_total_features;
CHECK(num_total_features_ == static_cast<int>(bin_mappers->size()));
sparse_threshold_ = io_config.sparse_threshold; sparse_threshold_ = io_config.sparse_threshold;
// get num_features // get num_features
std::vector<int> used_features; std::vector<int> used_features;
......
...@@ -721,7 +721,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, ...@@ -721,7 +721,7 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
} }
} }
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->Construct(&bin_mappers, forced_bin_bounds, sample_indices, num_per_col, total_sample_size, config_); dataset->Construct(&bin_mappers, num_col, forced_bin_bounds, sample_indices, num_per_col, total_sample_size, config_);
dataset->set_feature_names(feature_names_); dataset->set_feature_names(feature_names_);
return dataset.release(); return dataset.release();
} }
...@@ -897,7 +897,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -897,7 +897,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
if (feature_names_.empty()) { if (feature_names_.empty()) {
// -1 means doesn't use this feature // -1 means doesn't use this feature
dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->TotalColumns() - 1); dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
dataset->used_feature_map_ = std::vector<int>(dataset->num_total_features_, -1); dataset->used_feature_map_ = std::vector<int>(dataset->num_total_features_, -1);
} else { } else {
dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1); dataset->used_feature_map_ = std::vector<int>(feature_names_.size(), -1);
...@@ -1059,7 +1059,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -1059,7 +1059,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
} }
sample_values.clear(); sample_values.clear();
dataset->Construct(&bin_mappers, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(), dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), config_); Common::VectorSize<int>(sample_indices).data(), sample_data.size(), config_);
} }
......
...@@ -89,47 +89,52 @@ void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader* ...@@ -89,47 +89,52 @@ void GetLine(std::stringstream* ss, std::string* line, const VirtualFileReader*
} }
} }
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) { std::vector<std::string> ReadKLineFromFile(const char* filename, bool header, int k) {
auto reader = VirtualFileReader::Make(filename); auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) { if (!reader->Init()) {
Log::Fatal("Data file %s doesn't exist", filename); Log::Fatal("Data file %s doesn't exist.", filename);
} }
std::string line1, line2; std::vector<std::string> ret;
size_t buffer_size = 64 * 1024; std::string cur_line;
const size_t buffer_size = 1024 * 1024;
auto buffer = std::vector<char>(buffer_size); auto buffer = std::vector<char>(buffer_size);
size_t read_len = reader->Read(buffer.data(), buffer_size); size_t read_len = reader->Read(buffer.data(), buffer_size);
if (read_len <= 0) { if (read_len <= 0) {
Log::Fatal("Data file %s couldn't be read", filename); Log::Fatal("Data file %s couldn't be read.", filename);
} }
std::string read_str = std::string(buffer.data(), read_len);
std::stringstream tmp_file(std::string(buffer.data(), read_len)); std::stringstream tmp_file(read_str);
if (header) { if (header) {
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size); GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
} }
} }
if (!tmp_file.eof()) { for (int i = 0; i < k; ++i) {
GetLine(&tmp_file, &line1, reader.get(), &buffer, buffer_size); if (!tmp_file.eof()) {
} else { GetLine(&tmp_file, &cur_line, reader.get(), &buffer, buffer_size);
Log::Fatal("Data file %s should have at least one line", filename); ret.push_back(cur_line);
} else {
break;
}
} }
if (!tmp_file.eof()) { if (ret.empty()) {
GetLine(&tmp_file, &line2, reader.get(), &buffer, buffer_size); Log::Fatal("Data file %s should have at least one line.", filename);
} else { } else if (ret.size() == 1) {
Log::Warning("Data file %s only has one line", filename); Log::Warning("Data file %s only has one line.", filename);
} }
int comma_cnt = 0, comma_cnt2 = 0; return ret;
int tab_cnt = 0, tab_cnt2 = 0; }
int colon_cnt = 0, colon_cnt2 = 0;
// Get some statistic from 2 line
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
DataType GetDataType(const std::vector<std::string>& lines, int* num_col) {
DataType type = DataType::INVALID; DataType type = DataType::INVALID;
if (line2.size() == 0) { if (lines.empty()) {
// if only have one line on file return type;
}
int comma_cnt = 0;
int tab_cnt = 0;
int colon_cnt = 0;
GetStatistic(lines[0].c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
if (lines.size() == 1) {
if (colon_cnt > 0) { if (colon_cnt > 0) {
type = DataType::LIBSVM; type = DataType::LIBSVM;
} else if (tab_cnt > 0) { } else if (tab_cnt > 0) {
...@@ -137,34 +142,74 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features ...@@ -137,34 +142,74 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
} else if (comma_cnt > 0) { } else if (comma_cnt > 0) {
type = DataType::CSV; type = DataType::CSV;
} }
} else { }
if (colon_cnt > 0 || colon_cnt2 > 0) { int comma_cnt2 = 0;
type = DataType::LIBSVM; int tab_cnt2 = 0;
} else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { int colon_cnt2 = 0;
type = DataType::TSV; GetStatistic(lines[1].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
CHECK(tab_cnt == tab_cnt2); if (colon_cnt > 0 || colon_cnt2 > 0) {
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { type = DataType::LIBSVM;
type = DataType::CSV; } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
CHECK(comma_cnt == comma_cnt2); type = DataType::TSV;
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
type = DataType::CSV;
}
if (type == DataType::TSV || type == DataType::CSV) {
// validate the type
for (size_t i = 2; i < lines.size(); ++i) {
GetStatistic(lines[i].c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
if (type == DataType::TSV && tab_cnt2 != tab_cnt) {
type = DataType::INVALID;
break;
} else if (type == DataType::CSV && comma_cnt != comma_cnt2) {
type = DataType::INVALID;
break;
}
}
}
if (type == DataType::LIBSVM) {
int max_col_idx = 0;
for (size_t i = 0; i < lines.size(); ++i) {
auto str = Common::Trim(lines[i]);
auto colon_pos = str.find_last_of(":");
auto space_pos = str.find_last_of(" \f\t\v");
auto sub_str = str.substr(space_pos + 1, space_pos - colon_pos - 1);
int cur_idx = 0;
Common::Atoi(sub_str.c_str(), &cur_idx);
max_col_idx = std::max(cur_idx, max_col_idx);
} }
*num_col = max_col_idx + 1;
} else if (type == DataType::CSV) {
*num_col = comma_cnt + 1;
} else if (type == DataType::TSV) {
*num_col = tab_cnt + 1;
} }
return type;
}
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx) {
const int n_read_line = 20;
auto lines = ReadKLineFromFile(filename, header, n_read_line);
int num_col = 0;
DataType type = GetDataType(lines, &num_col);
if (type == DataType::INVALID) { if (type == DataType::INVALID) {
Log::Fatal("Unknown format of training data"); Log::Fatal("Unknown format of training data.");
} }
std::unique_ptr<Parser> ret; std::unique_ptr<Parser> ret;
if (type == DataType::LIBSVM) { if (type == DataType::LIBSVM) {
label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx); label_idx = GetLabelIdxForLibsvm(lines[0], num_features, label_idx);
ret.reset(new LibSVMParser(label_idx)); ret.reset(new LibSVMParser(label_idx, num_col));
} else if (type == DataType::TSV) { } else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(line1, num_features, label_idx); label_idx = GetLabelIdxForTSV(lines[0], num_features, label_idx);
ret.reset(new TSVParser(label_idx, tab_cnt + 1)); ret.reset(new TSVParser(label_idx, num_col));
} else if (type == DataType::CSV) { } else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(line1, num_features, label_idx); label_idx = GetLabelIdxForCSV(lines[0], num_features, label_idx);
ret.reset(new CSVParser(label_idx, comma_cnt + 1)); ret.reset(new CSVParser(label_idx, num_col));
} }
if (label_idx < 0) { if (label_idx < 0) {
Log::Info("Data file %s doesn't contain a label column", filename); Log::Info("Data file %s doesn't contain a label column.", filename);
} }
return ret.release(); return ret.release();
} }
......
...@@ -43,8 +43,8 @@ class CSVParser: public Parser { ...@@ -43,8 +43,8 @@ class CSVParser: public Parser {
} }
} }
inline int TotalColumns() const override { inline int NumFeatures() const override {
return total_columns_; return total_columns_ - (label_idx_ >= 0);
} }
private: private:
...@@ -79,8 +79,8 @@ class TSVParser: public Parser { ...@@ -79,8 +79,8 @@ class TSVParser: public Parser {
} }
} }
inline int TotalColumns() const override { inline int NumFeatures() const override {
return total_columns_; return total_columns_ - (label_idx_ >= 0);
} }
private: private:
...@@ -90,8 +90,8 @@ class TSVParser: public Parser { ...@@ -90,8 +90,8 @@ class TSVParser: public Parser {
class LibSVMParser: public Parser { class LibSVMParser: public Parser {
public: public:
explicit LibSVMParser(int label_idx) explicit LibSVMParser(int label_idx, int total_columns)
:label_idx_(label_idx) { :label_idx_(label_idx), total_columns_(total_columns) {
if (label_idx > 0) { if (label_idx > 0) {
Log::Fatal("Label should be the first column in a LibSVM file"); Log::Fatal("Label should be the first column in a LibSVM file");
} }
...@@ -119,12 +119,13 @@ class LibSVMParser: public Parser { ...@@ -119,12 +119,13 @@ class LibSVMParser: public Parser {
} }
} }
inline int TotalColumns() const override { inline int NumFeatures() const override {
return -1; return total_columns_;
} }
private: private:
int label_idx_ = 0; int label_idx_ = 0;
int total_columns_ = -1;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -105,9 +105,9 @@ def load_from_csr(filename, reference): ...@@ -105,9 +105,9 @@ def load_from_csr(filename, reference):
c_array(ctypes.c_int, csr.indices), c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
dtype_float64, dtype_float64,
len(csr.indptr), ctypes.c_int64(len(csr.indptr)),
len(csr.data), ctypes.c_int64(len(csr.data)),
csr.shape[1], ctypes.c_int64(csr.shape[1]),
c_str('max_bin=15'), c_str('max_bin=15'),
ref, ref,
ctypes.byref(handle)) ctypes.byref(handle))
...@@ -141,9 +141,9 @@ def load_from_csc(filename, reference): ...@@ -141,9 +141,9 @@ def load_from_csc(filename, reference):
c_array(ctypes.c_int, csr.indices), c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
dtype_float64, dtype_float64,
len(csr.indptr), ctypes.c_int64(len(csr.indptr)),
len(csr.data), ctypes.c_int64(len(csr.data)),
csr.shape[0], ctypes.c_int64(csr.shape[0]),
c_str('max_bin=15'), c_str('max_bin=15'),
ref, ref,
ctypes.byref(handle)) ctypes.byref(handle))
......
...@@ -6,6 +6,8 @@ import unittest ...@@ -6,6 +6,8 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
from scipy import sparse
from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
...@@ -53,6 +55,7 @@ class TestBasic(unittest.TestCase): ...@@ -53,6 +55,7 @@ class TestBasic(unittest.TestCase):
# check saved model persistence # check saved model persistence
bst = lgb.Booster(params, model_file="model.txt") bst = lgb.Booster(params, model_file="model.txt")
os.remove("model.txt")
pred_from_model_file = bst.predict(X_test) pred_from_model_file = bst.predict(X_test)
self.assertEqual(len(pred_from_matr), len(pred_from_model_file)) self.assertEqual(len(pred_from_matr), len(pred_from_model_file))
for preds in zip(pred_from_matr, pred_from_model_file): for preds in zip(pred_from_matr, pred_from_model_file):
...@@ -67,6 +70,25 @@ class TestBasic(unittest.TestCase): ...@@ -67,6 +70,25 @@ class TestBasic(unittest.TestCase):
# scores likely to be different, but prediction should still be the same # scores likely to be different, but prediction should still be the same
self.assertEqual(preds[0] > 0, preds[1] > 0) self.assertEqual(preds[0] > 0, preds[1] > 0)
# test that shape is checked during prediction
bad_X_test = X_test[:, 1:]
bad_shape_error_msg = "The number of features in data*"
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, bad_X_test)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, sparse.csr_matrix(bad_X_test))
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, sparse.csc_matrix(bad_X_test))
with open(tname, "w+b") as f:
dump_svmlight_file(bad_X_test, y_test, f)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, tname)
with open(tname, "w+b") as f:
dump_svmlight_file(X_test, y_test, f, zero_based=False)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
bst.predict, tname)
os.remove(tname)
def test_chunked_dataset(self): def test_chunked_dataset(self):
X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=2) X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(True), test_size=0.1, random_state=2)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment