"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "dd3be8dead4115cf7227b0ba40447f4228cbfcd9"
Commit 0612dcc0 authored by Guolin Ke's avatar Guolin Ke
Browse files

support get subset of dataset

parent 522e9993
...@@ -57,7 +57,7 @@ DllExport const char* LGBM_GetLastError(); ...@@ -57,7 +57,7 @@ DllExport const char* LGBM_GetLastError();
* \param out a loaded dataset * \param out a loaded dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromFile(const char* filename, DllExport int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
...@@ -77,7 +77,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, ...@@ -77,7 +77,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
* \param out created dataset * \param out created dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -104,7 +104,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -104,7 +104,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
* \param out created dataset * \param out created dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -128,7 +128,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -128,7 +128,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
* \param out created dataset * \param out created dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
*/ */
DllExport int LGBM_CreateDatasetFromMat(const void* data, DllExport int LGBM_DatasetCreateFromMat(const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
...@@ -137,6 +137,22 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -137,6 +137,22 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
/*!
* \brief Create subset of a data
* \param full_data the full dataset
* \param used_row_indices Indices used in subset
* \param num_used_row_indices len of used_row_indices
* \param parameters additional parameters
* \param out subset of data
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_DatasetGetSubset(
const DatesetHandle* full_data,
const int32_t* used_row_indices,
const int32_t num_used_row_indices,
const char* parameters,
DatesetHandle* out);
/*! /*!
* \brief free space for dataset * \brief free space for dataset
* \return 0 when succeed, -1 when failure happens * \return 0 when succeed, -1 when failure happens
......
...@@ -330,6 +330,8 @@ public: ...@@ -330,6 +330,8 @@ public:
} }
} }
Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const;
void FinishLoad(); void FinishLoad();
bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element); bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
...@@ -396,8 +398,6 @@ private: ...@@ -396,8 +398,6 @@ private:
int num_class_; int num_class_;
/*! \brief Store some label level data*/ /*! \brief Store some label level data*/
Metadata metadata_; Metadata metadata_;
/*! \brief True if dataset is loaded from binary file */
bool is_loading_from_binfile_;
/*! \brief index of label column */ /*! \brief index of label column */
int label_idx_ = 0; int label_idx_ = 0;
/*! \brief store feature names */ /*! \brief store feature names */
......
...@@ -80,6 +80,9 @@ public: ...@@ -80,6 +80,9 @@ public:
unsigned int bin = bin_mapper_->ValueToBin(value); unsigned int bin = bin_mapper_->ValueToBin(value);
bin_data_->Push(tid, line_idx, bin); bin_data_->Push(tid, line_idx, bin);
} }
inline void PushBin(int tid, data_size_t line_idx, unsigned int bin) {
bin_data_->Push(tid, line_idx, bin);
}
inline void FinishLoad() { bin_data_->FinishLoad(); } inline void FinishLoad() { bin_data_->FinishLoad(); }
/*! \brief Index of this feature */ /*! \brief Index of this feature */
inline int feature_index() const { return feature_index_; } inline int feature_index() const { return feature_index_; }
......
...@@ -438,7 +438,7 @@ class Dataset(object): ...@@ -438,7 +438,7 @@ class Dataset(object):
if params["has_header"].lower() == "true" or params["header"].lower() == "true": if params["has_header"].lower() == "true" or params["header"].lower() == "true":
self.data_has_header = True self.data_has_header = True
self.handle = ctypes.c_void_p() self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_CreateDatasetFromFile( _safe_call(_LIB.LGBM_DatasetCreateFromFile(
c_str(data), c_str(data),
c_str(params_str), c_str(params_str),
ref_dataset, ref_dataset,
...@@ -521,7 +521,7 @@ class Dataset(object): ...@@ -521,7 +521,7 @@ class Dataset(object):
data = np.array(mat.reshape(mat.size), dtype=np.float32) data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data = c_float_array(data) ptr_data, type_ptr_data = c_float_array(data)
_safe_call(_LIB.LGBM_CreateDatasetFromMat( _safe_call(_LIB.LGBM_DatasetCreateFromMat(
ptr_data, ptr_data,
type_ptr_data, type_ptr_data,
mat.shape[0], mat.shape[0],
...@@ -542,7 +542,7 @@ class Dataset(object): ...@@ -542,7 +542,7 @@ class Dataset(object):
ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr) ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
ptr_data, type_ptr_data = c_float_array(csr.data) ptr_data, type_ptr_data = c_float_array(csr.data)
_safe_call(_LIB.LGBM_CreateDatasetFromCSR( _safe_call(_LIB.LGBM_DatasetCreateFromCSR(
ptr_indptr, ptr_indptr,
type_ptr_indptr, type_ptr_indptr,
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)), csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
......
...@@ -204,7 +204,7 @@ DllExport const char* LGBM_GetLastError() { ...@@ -204,7 +204,7 @@ DllExport const char* LGBM_GetLastError() {
return LastErrorMsg(); return LastErrorMsg();
} }
DllExport int LGBM_CreateDatasetFromFile(const char* filename, DllExport int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
...@@ -223,7 +223,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename, ...@@ -223,7 +223,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
API_END(); API_END();
} }
DllExport int LGBM_CreateDatasetFromMat(const void* data, DllExport int LGBM_DatasetCreateFromMat(const void* data,
int data_type, int data_type,
int32_t nrow, int32_t nrow,
int32_t ncol, int32_t ncol,
...@@ -272,7 +272,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data, ...@@ -272,7 +272,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
API_END(); API_END();
} }
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type, int indptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -334,7 +334,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr, ...@@ -334,7 +334,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
API_END(); API_END();
} }
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type, int col_ptr_type,
const int32_t* indices, const int32_t* indices,
const void* data, const void* data,
...@@ -384,6 +384,26 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr, ...@@ -384,6 +384,26 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
API_END(); API_END();
} }
DllExport int LGBM_DatasetGetSubset(
const DatesetHandle* full_data,
const int32_t* used_row_indices,
const int32_t num_used_row_indices,
const char* parameters,
DatesetHandle* out) {
API_BEGIN();
auto param = ConfigBase::Str2Map(parameters);
IOConfig io_config;
io_config.Set(param);
auto full_dataset = reinterpret_cast<const Dataset*>(*full_data);
auto ret = std::unique_ptr<Dataset>(
full_dataset->Subset(used_row_indices,
num_used_row_indices,
io_config.is_enable_sparse));
ret->FinishLoad();
*out = ret.release();
API_END();
}
DllExport int LGBM_DatasetFree(DatesetHandle handle) { DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN(); API_BEGIN();
delete reinterpret_cast<Dataset*>(handle); delete reinterpret_cast<Dataset*>(handle);
......
...@@ -19,13 +19,11 @@ const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______ ...@@ -19,13 +19,11 @@ const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______
Dataset::Dataset() { Dataset::Dataset() {
num_class_ = 1; num_class_ = 1;
num_data_ = 0; num_data_ = 0;
is_loading_from_binfile_ = false;
} }
Dataset::Dataset(data_size_t num_data, int num_class) { Dataset::Dataset(data_size_t num_data, int num_class) {
num_class_ = num_class; num_class_ = num_class;
num_data_ = num_data; num_data_ = num_data;
is_loading_from_binfile_ = false;
metadata_.Init(num_data_, num_class_, -1, -1); metadata_.Init(num_data_, num_class_, -1, -1);
} }
...@@ -59,6 +57,18 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars ...@@ -59,6 +57,18 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
feature_names_ = dataset->feature_names_; feature_names_ = dataset->feature_names_;
} }
Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const {
auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_indices, num_class_));
ret->CopyFeatureMapperFrom(this, is_enable_sparse);
#pragma omp parallel for schedule(guided)
for (int fidx = 0; fidx < num_features_; ++fidx) {
auto iterator = features_[fidx]->bin_data()->GetIterator(0);
for (data_size_t i = 0; i < num_used_indices; ++i) {
ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
}
}
}
bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) { bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) {
std::string name(field_name); std::string name(field_name);
name = Common::Trim(name); name = Common::Trim(name);
...@@ -118,15 +128,27 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int** ...@@ -118,15 +128,27 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
} }
void Dataset::SaveBinaryFile(const char* bin_filename) { void Dataset::SaveBinaryFile(const char* bin_filename) {
bool is_file_existed = false;
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
if (file != NULL) {
is_file_existed = true;
Log::Warning("File %s existed, cannot save binary to it", bin_filename);
fclose(file);
}
if (!is_loading_from_binfile_) { if (!is_file_existed) {
std::string bin_filename_str(data_filename_); std::string bin_filename_str(data_filename_);
// if not pass a filename, just append ".bin" of original file // if not pass a filename, just append ".bin" of original file
if (bin_filename == nullptr || bin_filename[0] == '\0') { if (bin_filename == nullptr || bin_filename[0] == '\0') {
bin_filename_str.append(".bin"); bin_filename_str.append(".bin");
bin_filename = bin_filename_str.c_str(); bin_filename = bin_filename_str.c_str();
} }
FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename, "wb"); fopen_s(&file, bin_filename, "wb");
#else #else
......
...@@ -142,18 +142,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -142,18 +142,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Please use an additional query file or pre-partition the data"); Please use an additional query file or pre-partition the data");
} }
} }
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_)); auto dataset = std::unique_ptr<Dataset>(new Dataset());
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
data_size_t num_global_data = 0; data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices; std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset());
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
auto bin_filename = CheckCanLoadFromBin(filename); auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) { if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data to memory // read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices);
...@@ -197,18 +197,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -197,18 +197,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) { Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
data_size_t num_global_data = 0; data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices; std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
auto bin_filename = CheckCanLoadFromBin(filename); auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) { if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data in memory // read data in memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
...@@ -407,7 +407,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int ...@@ -407,7 +407,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int
} }
dataset->features_.shrink_to_fit(); dataset->features_.shrink_to_fit();
fclose(file); fclose(file);
dataset->is_loading_from_binfile_ = true;
return dataset.release(); return dataset.release();
} }
......
...@@ -16,6 +16,8 @@ def LoadDll(): ...@@ -16,6 +16,8 @@ def LoadDll():
LIB = LoadDll() LIB = LoadDll()
LIB.LGBM_GetLastError.restype = ctypes.c_char_p
dtype_float32 = 0 dtype_float32 = 0
dtype_float64 = 1 dtype_float64 = 1
dtype_int32 = 2 dtype_int32 = 2
...@@ -33,9 +35,10 @@ def test_load_from_file(filename, reference): ...@@ -33,9 +35,10 @@ def test_load_from_file(filename, reference):
if reference != None: if reference != None:
ref = ctypes.byref(reference) ref = ctypes.byref(reference)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
LIB.LGBM_CreateDatasetFromFile(c_str(filename), LIB.LGBM_DatasetCreateFromFile(c_str(filename),
c_str('max_bin=15'), c_str('max_bin=15'),
ref, ctypes.byref(handle) ) ref, ctypes.byref(handle) )
print(LIB.LGBM_GetLastError())
num_data = ctypes.c_long() num_data = ctypes.c_long()
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) ) LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_long() num_feature = ctypes.c_long()
...@@ -46,15 +49,6 @@ def test_load_from_file(filename, reference): ...@@ -46,15 +49,6 @@ def test_load_from_file(filename, reference):
def test_save_to_binary(handle, filename): def test_save_to_binary(handle, filename):
LIB.LGBM_DatasetSaveBinary(handle, c_str(filename)) LIB.LGBM_DatasetSaveBinary(handle, c_str(filename))
def test_load_from_binary(filename):
handle = ctypes.c_void_p()
LIB.LGBM_CreateDatasetFromBinaryFile(c_str(filename), ctypes.byref(handle) )
num_data = ctypes.c_long()
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_long()
LIB.LGBM_DatasetGetNumFeature(handle, ctypes.byref(num_feature) )
print ('#data:%d #feature:%d' %(num_data.value, num_feature.value) )
return handle
def test_load_from_csr(filename, reference): def test_load_from_csr(filename, reference):
data = [] data = []
...@@ -72,7 +66,7 @@ def test_load_from_csr(filename, reference): ...@@ -72,7 +66,7 @@ def test_load_from_csr(filename, reference):
if reference != None: if reference != None:
ref = ctypes.byref(reference) ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromCSR(c_array(ctypes.c_int, csr.indptr), LIB.LGBM_DatasetCreateFromCSR(c_array(ctypes.c_int, csr.indptr),
dtype_int32, dtype_int32,
c_array(ctypes.c_int, csr.indices), c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
...@@ -107,7 +101,7 @@ def test_load_from_csc(filename, reference): ...@@ -107,7 +101,7 @@ def test_load_from_csc(filename, reference):
if reference != None: if reference != None:
ref = ctypes.byref(reference) ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromCSC(c_array(ctypes.c_int, csr.indptr), LIB.LGBM_DatasetCreateFromCSC(c_array(ctypes.c_int, csr.indptr),
dtype_int32, dtype_int32,
c_array(ctypes.c_int, csr.indices), c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
...@@ -142,7 +136,7 @@ def test_load_from_mat(filename, reference): ...@@ -142,7 +136,7 @@ def test_load_from_mat(filename, reference):
if reference != None: if reference != None:
ref = ctypes.byref(reference) ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)), LIB.LGBM_DatasetCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
dtype_float64, dtype_float64,
mat.shape[0], mat.shape[0],
mat.shape[1], mat.shape[1],
...@@ -170,7 +164,7 @@ def test_dataset(): ...@@ -170,7 +164,7 @@ def test_dataset():
test_free_dataset(test) test_free_dataset(test)
test_save_to_binary(train, 'train.binary.bin') test_save_to_binary(train, 'train.binary.bin')
test_free_dataset(train) test_free_dataset(train)
train = test_load_from_binary('train.binary.bin') train = test_load_from_file('train.binary.bin', None)
test_free_dataset(train) test_free_dataset(train)
def test_booster(): def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None) train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment