Commit c060ca75 authored by Guolin Ke's avatar Guolin Ke
Browse files

refine api of constructing from sampling data.

parent 4c7f11aa
...@@ -124,13 +124,14 @@ public: ...@@ -124,13 +124,14 @@ public:
/*! /*!
* \brief Construct feature value to bin mapper according feature values * \brief Construct feature value to bin mapper according feature values
* \param values (Sampled) values of this feature, Note: not include zero. * \param values (Sampled) values of this feature, Note: not include zero.
* \param num_values number of values.
* \param total_sample_cnt number of total sample count, equal with values.size() + num_zeros * \param total_sample_cnt number of total sample count, equal with values.size() + num_zeros
* \param max_bin The maximal number of bin * \param max_bin The maximal number of bin
* \param min_data_in_bin min number of data in one bin * \param min_data_in_bin min number of data in one bin
* \param min_split_data * \param min_split_data
* \param bin_type Type of this bin * \param bin_type Type of this bin
*/ */
void FindBin(std::vector<double>& values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type); void FindBin(double* values, int num_values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type);
/*! /*!
* \brief Use specific number of bin to calculate the size of this class * \brief Use specific number of bin to calculate the size of this class
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <LightGBM/export.h> #include <LightGBM/export.h>
typedef void* ArrayHandle;
typedef void* DatasetHandle; typedef void* DatasetHandle;
typedef void* BoosterHandle; typedef void* BoosterHandle;
...@@ -53,52 +52,25 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -53,52 +52,25 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
DatasetHandle* out); DatasetHandle* out);
/*! /*!
* \brief create a empty dataset by sampling matrix, if num_sample_row == num_total_row, will construct this dataset. * \brief create a empty dataset by sampling data.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function. * \param sample_data sampled data, grouped by the column.
* \param data pointer to the data space * \param sample_indices indices of sampled data.
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param num_sample_row number of rows
* \param ncol number columns * \param ncol number columns
* \param num_per_col Size of each sampling column
* \param num_sample_row Number of sampled rows
* \param num_total_row number of total rows * \param num_total_row number of total rows
* \param parameters additional parameters * \param parameters additional parameters
* \param out created dataset * \param out created dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int data_type, int** sample_indices,
int32_t num_sample_row, int32_t ncol,
int32_t ncol, const int* num_per_col,
int32_t num_total_row, int32_t num_sample_row,
const char* parameters, int32_t num_total_row,
DatasetHandle* out); const char* parameters,
DatasetHandle* out);
/*!
* \brief create a empty dataset by sampling CSR data, if num_sample_row == num_total_row, will construct this dataset.
* Need call LGBM_DatasetPushRows/LGBM_DatasetPushRowsByCSR after calling this function.
* \param indptr pointer to row headers
* \param indptr_type type of indptr, can be C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param indices findex
* \param data fvalue
* \param data_type type of data pointer, can be C_API_DTYPE_FLOAT32 or C_API_DTYPE_FLOAT64
* \param nindptr number of rows in the matrix + 1
* \param n_sample_elem number of nonzero elements in the matrix
* \param num_col number of columns
* \param num_total_row number of total rows
* \param parameters additional parameters
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t nindptr,
int64_t n_sample_elem,
int64_t num_col,
int64_t num_total_row,
const char* parameters,
DatasetHandle* out);
/*! /*!
* \brief create a empty dataset by reference Dataset * \brief create a empty dataset by reference Dataset
...@@ -769,10 +741,4 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \ ...@@ -769,10 +741,4 @@ catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
catch(...) { return LGBM_APIHandleException("unknown exception"); } \ catch(...) { return LGBM_APIHandleException("unknown exception"); } \
return 0; return 0;
LIGHTGBM_C_EXPORT int LGBM_AllocateArray(int64_t len, int type, ArrayHandle* out);
LIGHTGBM_C_EXPORT int LGBM_CopyToArray(ArrayHandle arr, int type, int64_t start_idx, const void* src, int64_t len);
LIGHTGBM_C_EXPORT int LGBM_FreeArray(ArrayHandle arr, int type);
#endif // LIGHTGBM_C_API_H_ #endif // LIGHTGBM_C_API_H_
...@@ -114,19 +114,19 @@ public: ...@@ -114,19 +114,19 @@ public:
std::string label_column = ""; std::string label_column = "";
/*! \brief Index or column name of weight, < 0 means not used /*! \brief Index or column name of weight, < 0 means not used
* And add an prefix "name:" while using column name * And add an prefix "name:" while using column name
* Note: when using Index, it dosen't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string weight_column = ""; std::string weight_column = "";
/*! \brief Index or column name of group/query id, < 0 means not used /*! \brief Index or column name of group/query id, < 0 means not used
* And add an prefix "name:" while using column name * And add an prefix "name:" while using column name
* Note: when using Index, it dosen't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string group_column = ""; std::string group_column = "";
/*! \brief ignored features, separate by ',' /*! \brief ignored features, separate by ','
* And add an prefix "name:" while using column name * And add an prefix "name:" while using column name
* Note: when using Index, it dosen't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string ignore_column = ""; std::string ignore_column = "";
/*! \brief specific categorical columns, Note:only support for integer type categorical /*! \brief specific categorical columns, Note:only support for integer type categorical
* And add an prefix "name:" while using column name * And add an prefix "name:" while using column name
* Note: when using Index, it dosen't count the label index */ * Note: when using Index, it doesn't count the label index */
std::string categorical_column = ""; std::string categorical_column = "";
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override; LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
...@@ -398,7 +398,8 @@ struct ParameterAlias { ...@@ -398,7 +398,8 @@ struct ParameterAlias {
{ "topk", "top_k" }, { "topk", "top_k" },
{ "reg_alpha", "lambda_l1" }, { "reg_alpha", "lambda_l1" },
{ "reg_lambda", "lambda_l2" }, { "reg_lambda", "lambda_l2" },
{ "num_classes", "num_class" } { "num_classes", "num_class" },
{ "unbalanced_sets", "is_unbalance" }
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -286,7 +286,8 @@ public: ...@@ -286,7 +286,8 @@ public:
void Construct( void Construct(
std::vector<std::unique_ptr<BinMapper>>& bin_mappers, std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
const std::vector<std::vector<int>>& sample_non_zero_indices, int** sample_non_zero_indices,
const int* num_per_col,
size_t total_sample_cnt, size_t total_sample_cnt,
const IOConfig& io_config); const IOConfig& io_config);
......
...@@ -20,8 +20,8 @@ public: ...@@ -20,8 +20,8 @@ public:
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data); LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* CostructFromSampleData(std::vector<std::vector<double>>& sample_values, LIGHTGBM_EXPORT Dataset* CostructFromSampleData(double** sample_values,
std::vector<std::vector<int>>& sample_indices, int** sample_indices, int num_col, const int* num_per_col,
size_t total_sample_size, data_size_t num_data); size_t total_sample_size, data_size_t num_data);
/*! \brief Disable copy */ /*! \brief Disable copy */
......
...@@ -424,6 +424,24 @@ inline static double ApproximateHessianWithGaussian(const double y, const double ...@@ -424,6 +424,24 @@ inline static double ApproximateHessianWithGaussian(const double y, const double
return w * std::exp(-(x - b) * (x - b) / (2.0 * c * c)) * a / (c * std::sqrt(2 * pi)); return w * std::exp(-(x - b) * (x - b) / (2.0 * c * c)) * a / (c * std::sqrt(2 * pi));
} }
/*!
* \brief Build an array of raw pointers, one per inner vector
* \param data Vector of vectors to expose; entry i of the result is data[i].data()
* \return Array of data.size() pointers allocated with new[]. This helper never releases it,
*         so the caller has to delete[] the returned array. The stored pointers alias the
*         storage owned by data.
*/
template <typename T>
inline static T** Vector2Ptr(std::vector<std::vector<T>>& data) {
  const size_t num_rows = data.size();
  T** row_ptrs = new T*[num_rows];
  T** cursor = row_ptrs;
  for (auto& row : data) {
    *cursor++ = row.data();
  }
  return row_ptrs;
}
/*!
* \brief Collect the length of every inner vector
* \param data Vector of vectors to measure
* \return Vector whose i-th entry is data[i].size() converted to int
*/
template <typename T>
inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& data) {
  std::vector<int> lengths;
  lengths.reserve(data.size());
  for (const auto& row : data) {
    lengths.push_back(static_cast<int>(row.size()));
  }
  return lengths;
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
...@@ -310,85 +310,27 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -310,85 +310,27 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
API_END(); API_END();
} }
/*!
* \brief Create a dataset from sampled rows of a dense matrix.
*        When num_sample_row == num_total_row the call is forwarded to LGBM_DatasetCreateFromMat;
*        otherwise the sample is regrouped per column and passed to DatasetLoader::CostructFromSampleData.
* \param data pointer to the sampled matrix, read through RowFunctionFromDenseMatric
* \param data_type element type tag forwarded to the row reader
* \param num_sample_row number of rows stored in data
* \param ncol number of columns
* \param num_total_row row count given to the created dataset
* \param parameters parameter string parsed by ConfigBase::Str2Map
* \param out receives the created dataset
* \return result of LGBM_DatasetCreateFromMat on the forwarding path; otherwise whatever API_END() returns
*/
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledMat(const void* data,
int data_type,
int32_t num_sample_row,
int32_t ncol,
int32_t num_total_row,
const char* parameters,
DatasetHandle* out) {
// the sample has as many rows as the full data: build directly from the matrix
if (num_sample_row == num_total_row) {
return LGBM_DatasetCreateFromMat(data, data_type, num_total_row, ncol, 1, parameters, nullptr, out);
} else {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
// NOTE(review): the trailing 1 presumably selects row-major layout - confirm against RowFunctionFromDenseMatric
auto get_row_fun = RowFunctionFromDenseMatric(data, num_sample_row, ncol, data_type, 1);
// one value list and one row-index list per column
std::vector<std::vector<double>> sample_values(ncol);
std::vector<std::vector<int>> sample_idx(ncol);
for (int i = 0; i < num_sample_row; ++i) {
auto row = get_row_fun(i);
for (size_t idx = 0; idx < row.size(); ++idx) {
// entries with |value| <= kEpsilon are not recorded
if (std::fabs(row[idx]) > kEpsilon) {
sample_values[idx].emplace_back(row[idx]);
sample_idx[idx].emplace_back(i);
}
}
}
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
}
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledCSR(const void* indptr, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromSampledColumn(double** sample_data,
int indptr_type, int** sample_indices,
const int32_t* indices, int32_t ncol,
const void* data, const int* num_per_col,
int data_type, int32_t num_sample_row,
int64_t nindptr, int32_t num_total_row,
int64_t n_sample_elem, const char* parameters,
int64_t num_col, DatasetHandle* out) {
int64_t num_total_row, API_BEGIN();
const char* parameters, auto param = ConfigBase::Str2Map(parameters);
DatasetHandle* out) { IOConfig io_config;
if (nindptr - 1 == num_total_row) { io_config.Set(param);
return LGBM_DatasetCreateFromCSR(indptr, indptr_type, indices, data, DatasetLoader loader(io_config, nullptr, 1, nullptr);
data_type, nindptr, n_sample_elem, num_col, parameters, nullptr, out); *out = loader.CostructFromSampleData(sample_data, sample_indices, ncol, num_per_col,
} else { num_sample_row,
API_BEGIN(); static_cast<data_size_t>(num_total_row));
auto param = ConfigBase::Str2Map(parameters); API_END();
IOConfig io_config;
io_config.Set(param);
auto get_row_fun = RowFunctionFromCSR(indptr, indptr_type, indices, data, data_type, nindptr, n_sample_elem);
int32_t num_sample_row = static_cast<int32_t>(nindptr - 1);
std::vector<std::vector<double>> sample_values(num_col);
std::vector<std::vector<int>> sample_idx(num_col);
for (int i = 0; i < num_sample_row; ++i) {
auto row = get_row_fun(i);
for (std::pair<int, double>& inner_data : row) {
if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
sample_values.resize(inner_data.first + 1);
sample_idx.resize(inner_data.first + 1);
}
if (std::fabs(inner_data.second) > kEpsilon) {
sample_values[inner_data.first].emplace_back(inner_data.second);
sample_idx[inner_data.first].emplace_back(i);
}
}
}
CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr);
*out = loader.CostructFromSampleData(sample_values, sample_idx,
num_sample_row,
static_cast<data_size_t>(num_total_row));
API_END();
}
} }
LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference, LIGHTGBM_C_EXPORT int LGBM_DatasetCreateByReference(const DatasetHandle reference,
int64_t num_total_row, int64_t num_total_row,
DatasetHandle* out) { DatasetHandle* out) {
...@@ -480,7 +422,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data, ...@@ -480,7 +422,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromMat(const void* data,
} }
} }
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
...@@ -539,7 +485,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr, ...@@ -539,7 +485,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSR(const void* indptr,
} }
CHECK(num_col >= static_cast<int>(sample_values.size())); CHECK(num_col >= static_cast<int>(sample_values.size()));
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
...@@ -593,7 +543,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, ...@@ -593,7 +543,11 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
} }
} }
DatasetLoader loader(io_config, nullptr, 1, nullptr); DatasetLoader loader(io_config, nullptr, 1, nullptr);
ret.reset(loader.CostructFromSampleData(sample_values, sample_idx, sample_cnt, nrow)); ret.reset(loader.CostructFromSampleData(Common::Vector2Ptr<double>(sample_values),
Common::Vector2Ptr<int>(sample_idx),
static_cast<int>(sample_values.size()),
Common::VectorSize<double>(sample_values).data(),
sample_cnt, nrow));
} else { } else {
ret.reset(new Dataset(nrow)); ret.reset(new Dataset(nrow));
ret->CreateValid( ret->CreateValid(
...@@ -1123,54 +1077,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle, ...@@ -1123,54 +1077,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
API_END(); API_END();
} }
/*!
* \brief Allocate an array of len elements with new[] and store it in *out.
* \param len number of elements to allocate
* \param type element type tag: C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64, C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param out receives the new array; left unwritten when type matches none of the four tags
* \return whatever API_END() returns
*/
LIGHTGBM_C_EXPORT int LGBM_AllocateArray(int64_t len, int type, ArrayHandle* out) {
API_BEGIN();
if (type == C_API_DTYPE_FLOAT32) {
*out = new float[len];
} else if (type == C_API_DTYPE_FLOAT64) {
*out = new double[len];
} else if (type == C_API_DTYPE_INT32) {
*out = new int32_t[len];
} else if (type == C_API_DTYPE_INT64) {
*out = new int64_t[len];
}
// NOTE(review): there is no else branch, so an unrecognised type falls through without touching *out
API_END();
}
/*!
* \brief Copy len elements of type T from src to dst using std::memcpy
* \param dst destination buffer
* \param src source buffer
* \param len number of elements (not bytes) to copy
*/
template<typename T>
void Copy(T* dst, const T* src, int64_t len) {
  const auto num_bytes = sizeof(T) * len;
  std::memcpy(dst, src, num_bytes);
}
/*!
* \brief Copy len elements from src into arr, starting at element start_idx of arr.
* \param arr destination array, cast to the element type selected by type
* \param type element type tag: C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64, C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \param start_idx element offset into arr where the copy begins
* \param src source buffer, cast to the same element type as arr
* \param len number of elements to copy
* \return whatever API_END() returns
*/
LIGHTGBM_C_EXPORT int LGBM_CopyToArray(ArrayHandle arr, int type, int64_t start_idx, const void* src, int64_t len) {
API_BEGIN();
// no bounds check is done here; nothing is copied when type matches none of the four tags
if (type == C_API_DTYPE_FLOAT32) {
Copy<float>(static_cast<float*>(arr) + start_idx, static_cast<const float*>(src), len);
} else if (type == C_API_DTYPE_FLOAT64) {
Copy<double>(static_cast<double*>(arr) + start_idx, static_cast<const double*>(src), len);
} else if (type == C_API_DTYPE_INT32) {
Copy<int32_t>(static_cast<int32_t*>(arr) + start_idx, static_cast<const int32_t*>(src), len);
} else if (type == C_API_DTYPE_INT64) {
Copy<int64_t>(static_cast<int64_t*>(arr) + start_idx, static_cast<const int64_t*>(src), len);
}
API_END();
}
/*!
* \brief Release an array with delete[] after casting it to the element type selected by type.
* \param arr array to release
* \param type element type tag: C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64, C_API_DTYPE_INT32 or C_API_DTYPE_INT64
* \return whatever API_END() returns
*/
LIGHTGBM_C_EXPORT int LGBM_FreeArray(ArrayHandle arr, int type) {
API_BEGIN();
// nothing is released when type matches none of the four tags
if (type == C_API_DTYPE_FLOAT32) {
delete[] static_cast<float*>(arr);
} else if (type == C_API_DTYPE_FLOAT64) {
delete[] static_cast<double*>(arr);
} else if (type == C_API_DTYPE_INT32) {
delete[] static_cast<int32_t*>(arr);
} else if (type == C_API_DTYPE_INT64) {
delete[] static_cast<int64_t*>(arr);
}
API_END();
}
// ---- start of some help functions // ---- start of some help functions
std::function<std::vector<double>(int row_idx)> std::function<std::vector<double>(int row_idx)>
......
...@@ -68,36 +68,35 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin ...@@ -68,36 +68,35 @@ bool NeedFilter(std::vector<int>& cnt_in_bin, int total_cnt, int filter_cnt, Bin
return true; return true;
} }
void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt, void BinMapper::FindBin(double* values, int num_sample_values, size_t total_sample_cnt,
int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type) { int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type) {
bin_type_ = bin_type; bin_type_ = bin_type;
default_bin_ = 0; default_bin_ = 0;
std::vector<double>& raw_values = values; int zero_cnt = static_cast<int>(total_sample_cnt - num_sample_values);
int zero_cnt = static_cast<int>(total_sample_cnt - raw_values.size());
// find distinct_values first // find distinct_values first
std::vector<double> distinct_values; std::vector<double> distinct_values;
std::vector<int> counts; std::vector<int> counts;
std::sort(raw_values.begin(), raw_values.end()); std::sort(values, values + num_sample_values);
// push zero in the front // push zero in the front
if (raw_values.empty() || (raw_values[0] > 0.0f && zero_cnt > 0)) { if (num_sample_values || (values[0] > 0.0f && zero_cnt > 0)) {
distinct_values.push_back(0.0f); distinct_values.push_back(0.0f);
counts.push_back(zero_cnt); counts.push_back(zero_cnt);
} }
if (!raw_values.empty()) { if (num_sample_values > 0) {
distinct_values.push_back(raw_values[0]); distinct_values.push_back(values[0]);
counts.push_back(1); counts.push_back(1);
} }
for (size_t i = 1; i < raw_values.size(); ++i) { for (int i = 1; i < num_sample_values; ++i) {
if (raw_values[i] != raw_values[i - 1]) { if (values[i] != values[i - 1]) {
if (raw_values[i - 1] < 0.0f && raw_values[i] > 0.0f) { if (values[i - 1] < 0.0f && values[i] > 0.0f) {
distinct_values.push_back(0.0f); distinct_values.push_back(0.0f);
counts.push_back(zero_cnt); counts.push_back(zero_cnt);
} }
distinct_values.push_back(raw_values[i]); distinct_values.push_back(values[i]);
counts.push_back(1); counts.push_back(1);
} else { } else {
++counts.back(); ++counts.back();
...@@ -105,20 +104,20 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt, ...@@ -105,20 +104,20 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt,
} }
// push zero in the back // push zero in the back
if (!raw_values.empty() && raw_values.back() < 0.0f && zero_cnt > 0) { if (num_sample_values > 0 && values[num_sample_values - 1] < 0.0f && zero_cnt > 0) {
distinct_values.push_back(0.0f); distinct_values.push_back(0.0f);
counts.push_back(zero_cnt); counts.push_back(zero_cnt);
} }
min_val_ = distinct_values.front(); min_val_ = distinct_values.front();
max_val_ = distinct_values.back(); max_val_ = distinct_values.back();
std::vector<int> cnt_in_bin; std::vector<int> cnt_in_bin;
int num_values = static_cast<int>(distinct_values.size()); int num_distinct_values = static_cast<int>(distinct_values.size());
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
if (num_values <= max_bin) { if (num_distinct_values <= max_bin) {
// use distinct value is enough // use distinct value is enough
bin_upper_bound_.clear(); bin_upper_bound_.clear();
int cur_cnt_inbin = 0; int cur_cnt_inbin = 0;
for (int i = 0; i < num_values - 1; ++i) { for (int i = 0; i < num_distinct_values - 1; ++i) {
cur_cnt_inbin += counts[i]; cur_cnt_inbin += counts[i];
if (cur_cnt_inbin >= min_data_in_bin) { if (cur_cnt_inbin >= min_data_in_bin) {
bin_upper_bound_.push_back((distinct_values[i] + distinct_values[i + 1]) / 2); bin_upper_bound_.push_back((distinct_values[i] + distinct_values[i + 1]) / 2);
...@@ -137,14 +136,14 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt, ...@@ -137,14 +136,14 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt,
} }
double mean_bin_size = static_cast<double>(total_sample_cnt) / max_bin; double mean_bin_size = static_cast<double>(total_sample_cnt) / max_bin;
if (zero_cnt > mean_bin_size) { if (zero_cnt > mean_bin_size) {
int non_zero_cnt = static_cast<int>(raw_values.size()); int non_zero_cnt = num_sample_values;
max_bin = std::min(max_bin, 1 + static_cast<int>(non_zero_cnt / min_data_in_bin)); max_bin = std::min(max_bin, 1 + static_cast<int>(non_zero_cnt / min_data_in_bin));
} }
// mean size for one bin // mean size for one bin
int rest_bin_cnt = max_bin; int rest_bin_cnt = max_bin;
int rest_sample_cnt = static_cast<int>(total_sample_cnt); int rest_sample_cnt = static_cast<int>(total_sample_cnt);
std::vector<bool> is_big_count_value(num_values, false); std::vector<bool> is_big_count_value(num_distinct_values, false);
for (int i = 0; i < num_values; ++i) { for (int i = 0; i < num_distinct_values; ++i) {
if (counts[i] >= mean_bin_size) { if (counts[i] >= mean_bin_size) {
is_big_count_value[i] = true; is_big_count_value[i] = true;
--rest_bin_cnt; --rest_bin_cnt;
...@@ -158,7 +157,7 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt, ...@@ -158,7 +157,7 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt,
int bin_cnt = 0; int bin_cnt = 0;
lower_bounds[bin_cnt] = distinct_values[0]; lower_bounds[bin_cnt] = distinct_values[0];
int cur_cnt_inbin = 0; int cur_cnt_inbin = 0;
for (int i = 0; i < num_values - 1; ++i) { for (int i = 0; i < num_distinct_values - 1; ++i) {
if (!is_big_count_value[i]) { if (!is_big_count_value[i]) {
rest_sample_cnt -= counts[i]; rest_sample_cnt -= counts[i];
} }
...@@ -207,7 +206,7 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt, ...@@ -207,7 +206,7 @@ void BinMapper::FindBin(std::vector<double>& values, size_t total_sample_cnt,
} }
// sort by counts // sort by counts
Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true); Common::SortForPair<int, int>(counts_int, distinct_values_int, 0, true);
// will ingore the categorical of small counts // will ignore the categorical of small counts
const int cut_cnt = static_cast<int>(total_sample_cnt * 0.98f); const int cut_cnt = static_cast<int>(total_sample_cnt * 0.98f);
categorical_2_bin_.clear(); categorical_2_bin_.clear();
bin_2_categorical_.clear(); bin_2_categorical_.clear();
......
...@@ -45,7 +45,8 @@ std::vector<std::vector<int>> NoGroup( ...@@ -45,7 +45,8 @@ std::vector<std::vector<int>> NoGroup(
void Dataset::Construct( void Dataset::Construct(
std::vector<std::unique_ptr<BinMapper>>& bin_mappers, std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
const std::vector<std::vector<int>>&, int**,
const int*,
size_t, size_t,
const IOConfig& io_config) { const IOConfig& io_config) {
num_total_features_ = static_cast<int>(bin_mappers.size()); num_total_features_ = static_cast<int>(bin_mappers.size());
......
...@@ -177,7 +177,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -177,7 +177,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
dataset->metadata_.Init(filename); dataset->metadata_.Init(filename);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data to memory // read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
dataset->num_data_ = static_cast<data_size_t>(text_data.size()); dataset->num_data_ = static_cast<data_size_t>(text_data.size());
// sample data // sample data
auto sample_data = SampleTextDataFromMemory(text_data); auto sample_data = SampleTextDataFromMemory(text_data);
...@@ -263,11 +263,11 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -263,11 +263,11 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) { Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
FILE* file; FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb"); fopen_s(&file, bin_filename, "rb");
#else #else
file = fopen(bin_filename, "rb"); file = fopen(bin_filename, "rb");
#endif #endif
dataset->data_filename_ = data_filename; dataset->data_filename_ = data_filename;
if (file == NULL) { if (file == NULL) {
Log::Fatal("Could not read binary data from %s", bin_filename); Log::Fatal("Could not read binary data from %s", bin_filename);
...@@ -276,7 +276,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -276,7 +276,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// buffer to read binary file // buffer to read binary file
size_t buffer_size = 16 * 1024 * 1024; size_t buffer_size = 16 * 1024 * 1024;
auto buffer = std::vector<char>(buffer_size); auto buffer = std::vector<char>(buffer_size);
// check token // check token
size_t size_of_token = std::strlen(Dataset::binary_file_token); size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file);
...@@ -356,7 +356,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -356,7 +356,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// group_feature_start_ // group_feature_start_
const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr); const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
dataset->group_feature_start_.clear(); dataset->group_feature_start_.clear();
for (int i = 0; i < dataset->num_groups_ ; ++i) { for (int i = 0; i < dataset->num_groups_; ++i) {
dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]); dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
} }
mem_ptr += sizeof(int) * (dataset->num_groups_); mem_ptr += sizeof(int) * (dataset->num_groups_);
...@@ -464,10 +464,10 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -464,10 +464,10 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt); Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
} }
dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>( dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
new FeatureGroup(buffer.data(), new FeatureGroup(buffer.data(),
*num_global_data, *num_global_data,
*used_data_indices) *used_data_indices)
)); ));
} }
dataset->feature_groups_.shrink_to_fit(); dataset->feature_groups_.shrink_to_fit();
fclose(file); fclose(file);
...@@ -475,22 +475,22 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -475,22 +475,22 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
return dataset.release(); return dataset.release();
} }
Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& sample_values, Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
std::vector<std::vector<int>>& sample_indices, int** sample_indices, int num_col, const int* num_per_col,
size_t total_sample_size, data_size_t num_data) { size_t total_sample_size, data_size_t num_data) {
std::vector<std::unique_ptr<BinMapper>> bin_mappers(sample_values.size()); std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_col);
// fill feature_names_ if not header // fill feature_names_ if not header
if (feature_names_.empty()) { if (feature_names_.empty()) {
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < num_col; ++i) {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "Column_" << i; str_buf << "Column_" << i;
feature_names_.push_back(str_buf.str()); feature_names_.push_back(str_buf.str());
} }
} }
const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / num_data * sample_values.size()); const data_size_t filter_cnt = static_cast<data_size_t>(static_cast<double>(0.95 * io_config_.min_data_in_leaf) / num_data * num_col);
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < num_col; ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
bin_mappers[i] = nullptr; bin_mappers[i] = nullptr;
continue; continue;
...@@ -500,12 +500,12 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>& ...@@ -500,12 +500,12 @@ Dataset* DatasetLoader::CostructFromSampleData(std::vector<std::vector<double>>&
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i], total_sample_size, bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type); io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type);
} }
auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data)); auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
dataset->feature_names_ = feature_names_; dataset->feature_names_ = feature_names_;
dataset->Construct(bin_mappers, sample_indices, total_sample_size, io_config_); dataset->Construct(bin_mappers, sample_indices, num_per_col, total_sample_size, io_config_);
return dataset.release(); return dataset.release();
} }
...@@ -521,7 +521,7 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) { ...@@ -521,7 +521,7 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) {
} }
if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) { if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_, Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
static_cast<int>(dataset->feature_names_.size())); static_cast<int>(dataset->feature_names_.size()));
} }
bool is_feature_order_by_group = true; bool is_feature_order_by_group = true;
int last_group = -1; int last_group = -1;
...@@ -547,8 +547,8 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) { ...@@ -547,8 +547,8 @@ void DatasetLoader::CheckDataset(const Dataset* dataset) {
} }
std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata, std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
int rank, int num_machines, int* num_global_data, int rank, int num_machines, int* num_global_data,
std::vector<data_size_t>* used_data_indices) { std::vector<data_size_t>* used_data_indices) {
TextReader<data_size_t> text_reader(filename, io_config_.has_header); TextReader<data_size_t> text_reader(filename, io_config_.has_header);
used_data_indices->clear(); used_data_indices->clear();
if (num_machines == 1 || io_config_.is_pre_partition) { if (num_machines == 1 || io_config_.is_pre_partition) {
...@@ -706,7 +706,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -706,7 +706,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
// if only one machine, find bin locally // if only one machine, find bin locally
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) { for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
bin_mappers[i] = nullptr; bin_mappers[i] = nullptr;
...@@ -717,8 +717,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -717,8 +717,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[i], sample_data.size(), bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type); sample_data.size(), io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type);
} }
} else { } else {
// if have multi-machines, need to find bin distributed // if have multi-machines, need to find bin distributed
...@@ -738,7 +738,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -738,7 +738,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
start[i + 1] = start[i] + len[i]; start[i + 1] = start[i] + len[i];
} }
len[num_machines - 1] = total_num_feature - start[num_machines - 1]; len[num_machines - 1] = total_num_feature - start[num_machines - 1];
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < len[rank]; ++i) { for (int i = 0; i < len[rank]; ++i) {
if (ignore_features_.count(start[rank] + i) > 0) { if (ignore_features_.count(start[rank] + i) > 0) {
continue; continue;
...@@ -748,8 +748,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -748,8 +748,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
bin_type = BinType::CategoricalBin; bin_type = BinType::CategoricalBin;
} }
bin_mappers[i].reset(new BinMapper()); bin_mappers[i].reset(new BinMapper());
bin_mappers[i]->FindBin(sample_values[start[rank] + i], sample_data.size(), bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast<int>(sample_values[i].size()),
io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type); sample_data.size(), io_config_.max_bin, io_config_.min_data_in_bin, filter_cnt, bin_type);
} }
// get max_bin // get max_bin
int local_max_bin = 0; int local_max_bin = 0;
...@@ -764,7 +764,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -764,7 +764,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
Network::Allreduce(reinterpret_cast<char*>(&local_max_bin), Network::Allreduce(reinterpret_cast<char*>(&local_max_bin),
sizeof(local_max_bin), sizeof(local_max_bin), sizeof(local_max_bin), sizeof(local_max_bin),
reinterpret_cast<char*>(&max_bin), reinterpret_cast<char*>(&max_bin),
[] (const char* src, char* dst, int len) { [](const char* src, char* dst, int len) {
int used_size = 0; int used_size = 0;
const int type_size = sizeof(int); const int type_size = sizeof(int);
const int *p1; const int *p1;
...@@ -788,7 +788,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -788,7 +788,7 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
auto output_buffer = std::vector<char>(buffer_size); auto output_buffer = std::vector<char>(buffer_size);
// find local feature bins and copy to buffer // find local feature bins and copy to buffer
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < len[rank]; ++i) { for (int i = 0; i < len[rank]; ++i) {
if (ignore_features_.count(start[rank] + i) > 0) { if (ignore_features_.count(start[rank] + i) > 0) {
continue; continue;
...@@ -815,7 +815,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, ...@@ -815,7 +815,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
} }
} }
sample_values.clear(); sample_values.clear();
dataset->Construct(bin_mappers, sample_indices, sample_data.size(), io_config_); dataset->Construct(bin_mappers, Common::Vector2Ptr<int>(sample_indices),
Common::VectorSize<int>(sample_indices).data(), sample_data.size(), io_config_);
} }
/*! \brief Extract local features from memory */ /*! \brief Extract local features from memory */
...@@ -824,7 +825,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -824,7 +825,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
double tmp_label = 0.0f; double tmp_label = 0.0f;
if (predict_fun_ == nullptr) { if (predict_fun_ == nullptr) {
// if doesn't need to prediction with initial model // if doesn't need to prediction with initial model
#pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < dataset->num_data_; ++i) { for (data_size_t i = 0; i < dataset->num_data_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
oneline_features.clear(); oneline_features.clear();
...@@ -857,7 +858,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -857,7 +858,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
} else { } else {
// if need to prediction with initial model // if need to prediction with initial model
std::vector<double> init_score(dataset->num_data_ * num_class_); std::vector<double> init_score(dataset->num_data_ * num_class_);
#pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < dataset->num_data_; ++i) { for (data_size_t i = 0; i < dataset->num_data_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
oneline_features.clear(); oneline_features.clear();
...@@ -882,7 +883,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat ...@@ -882,7 +883,7 @@ void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>& text_dat
// if is used feature // if is used feature
int group = dataset->feature2group_[feature_idx]; int group = dataset->feature2group_[feature_idx];
int sub_feature = dataset->feature2subfeature_[feature_idx]; int sub_feature = dataset->feature2subfeature_[feature_idx];
dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second); dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
} else { } else {
if (inner_data.first == weight_idx_) { if (inner_data.first == weight_idx_) {
dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second)); dataset->metadata_.SetWeightAt(i, static_cast<float>(inner_data.second));
...@@ -911,7 +912,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* ...@@ -911,7 +912,7 @@ void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser*
(data_size_t start_idx, const std::vector<std::string>& lines) { (data_size_t start_idx, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
double tmp_label = 0.0f; double tmp_label = 0.0f;
#pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
oneline_features.clear(); oneline_features.clear();
...@@ -968,23 +969,23 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { ...@@ -968,23 +969,23 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
FILE* file; FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb"); fopen_s(&file, bin_filename.c_str(), "rb");
#else #else
file = fopen(bin_filename.c_str(), "rb"); file = fopen(bin_filename.c_str(), "rb");
#endif #endif
if (file == NULL) { if (file == NULL) {
bin_filename = std::string(filename); bin_filename = std::string(filename);
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb"); fopen_s(&file, bin_filename.c_str(), "rb");
#else #else
file = fopen(bin_filename.c_str(), "rb"); file = fopen(bin_filename.c_str(), "rb");
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Fatal("cannot open data file %s", bin_filename.c_str()); Log::Fatal("cannot open data file %s", bin_filename.c_str());
} }
} }
size_t buffer_size = 256; size_t buffer_size = 256;
auto buffer = std::vector<char>(buffer_size); auto buffer = std::vector<char>(buffer_size);
...@@ -992,8 +993,8 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { ...@@ -992,8 +993,8 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
size_t size_of_token = std::strlen(Dataset::binary_file_token); size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file);
fclose(file); fclose(file);
if (read_cnt == size_of_token if (read_cnt == size_of_token
&& std::string(buffer.data()) == std::string(Dataset::binary_file_token)) { && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
return bin_filename; return bin_filename;
} else { } else {
return std::string(); return std::string();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment