Commit 0612dcc0 authored by Guolin Ke
Browse files

support get subset of dataset

parent 522e9993
......@@ -57,7 +57,7 @@ DllExport const char* LGBM_GetLastError();
* \param out a loaded dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......@@ -77,7 +77,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
......@@ -104,7 +104,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
......@@ -128,7 +128,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
* \param out created dataset
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromMat(const void* data,
DllExport int LGBM_DatasetCreateFromMat(const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
......@@ -137,6 +137,22 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief Create a new dataset that contains a subset of the rows of an existing dataset
* \param full_data pointer to the handle of the full dataset; the full dataset is only read
* \param used_row_indices row indices (into the full dataset) to copy; row i of the subset is row used_row_indices[i] of the full dataset
* \param num_used_row_indices number of entries in used_row_indices, which is also the number of rows of the subset
* \param parameters additional parameters, given as a parameter string
* \param out receives the handle of the newly created subset; free it with LGBM_DatasetFree
* \return 0 when succeed, -1 when failure happens
*/
DllExport int LGBM_DatasetGetSubset(
const DatesetHandle* full_data,
const int32_t* used_row_indices,
const int32_t num_used_row_indices,
const char* parameters,
DatesetHandle* out);
/*!
* \brief free space for dataset
* \return 0 when succeed, -1 when failure happens
......
......@@ -330,6 +330,8 @@ public:
}
}
/*!
* \brief Create a new dataset holding only the selected rows of this dataset
* \param used_indices row indices of this dataset to copy, in the order they should appear in the result
* \param num_used_indices number of entries in used_indices
* \param is_enable_sparse forwarded to the feature-mapper copy of the new dataset
* \return pointer to the newly allocated dataset
*/
Dataset* Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const;
void FinishLoad();
bool SetFloatField(const char* field_name, const float* field_data, data_size_t num_element);
......@@ -396,8 +398,6 @@ private:
int num_class_;
/*! \brief Store some label level data*/
Metadata metadata_;
/*! \brief True if dataset is loaded from binary file */
bool is_loading_from_binfile_;
/*! \brief index of label column */
int label_idx_ = 0;
/*! \brief store feature names */
......
......@@ -80,6 +80,9 @@ public:
unsigned int bin = bin_mapper_->ValueToBin(value);
bin_data_->Push(tid, line_idx, bin);
}
/*!
* \brief Push an already computed bin index for one row; no value-to-bin mapping is applied
* \param tid forwarded to bin_data_->Push (presumably the id of the calling thread -- confirm against Bin::Push)
* \param line_idx index of the row being written
* \param bin bin index to store for that row
*/
inline void PushBin(int tid, data_size_t line_idx, unsigned int bin) {
bin_data_->Push(tid, line_idx, bin);
}
inline void FinishLoad() { bin_data_->FinishLoad(); }
/*! \brief Index of this feature */
inline int feature_index() const { return feature_index_; }
......
......@@ -438,7 +438,7 @@ class Dataset(object):
if params["has_header"].lower() == "true" or params["header"].lower() == "true":
self.data_has_header = True
self.handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_CreateDatasetFromFile(
_safe_call(_LIB.LGBM_DatasetCreateFromFile(
c_str(data),
c_str(params_str),
ref_dataset,
......@@ -521,7 +521,7 @@ class Dataset(object):
data = np.array(mat.reshape(mat.size), dtype=np.float32)
ptr_data, type_ptr_data = c_float_array(data)
_safe_call(_LIB.LGBM_CreateDatasetFromMat(
_safe_call(_LIB.LGBM_DatasetCreateFromMat(
ptr_data,
type_ptr_data,
mat.shape[0],
......@@ -542,7 +542,7 @@ class Dataset(object):
ptr_indptr, type_ptr_indptr = c_int_array(csr.indptr)
ptr_data, type_ptr_data = c_float_array(csr.data)
_safe_call(_LIB.LGBM_CreateDatasetFromCSR(
_safe_call(_LIB.LGBM_DatasetCreateFromCSR(
ptr_indptr,
type_ptr_indptr,
csr.indices.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
......
......@@ -204,7 +204,7 @@ DllExport const char* LGBM_GetLastError() {
return LastErrorMsg();
}
DllExport int LGBM_CreateDatasetFromFile(const char* filename,
DllExport int LGBM_DatasetCreateFromFile(const char* filename,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out) {
......@@ -223,7 +223,7 @@ DllExport int LGBM_CreateDatasetFromFile(const char* filename,
API_END();
}
DllExport int LGBM_CreateDatasetFromMat(const void* data,
DllExport int LGBM_DatasetCreateFromMat(const void* data,
int data_type,
int32_t nrow,
int32_t ncol,
......@@ -272,7 +272,7 @@ DllExport int LGBM_CreateDatasetFromMat(const void* data,
API_END();
}
DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
DllExport int LGBM_DatasetCreateFromCSR(const void* indptr,
int indptr_type,
const int32_t* indices,
const void* data,
......@@ -334,7 +334,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const void* indptr,
API_END();
}
DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
DllExport int LGBM_DatasetCreateFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
......@@ -384,6 +384,26 @@ DllExport int LGBM_CreateDatasetFromCSC(const void* col_ptr,
API_END();
}
/*!
* \brief C API entry point: build a dataset from selected rows of an existing one
* \param full_data pointer to the handle of the source dataset
* \param used_row_indices row indices of the source dataset to copy
* \param num_used_row_indices number of entries in used_row_indices
* \param parameters parameter string; it is parsed into an IOConfig and only
*        is_enable_sparse is consumed here
* \param out receives the handle of the new dataset
*/
DllExport int LGBM_DatasetGetSubset(
  const DatesetHandle* full_data,
  const int32_t* used_row_indices,
  const int32_t num_used_row_indices,
  const char* parameters,
  DatesetHandle* out) {
  API_BEGIN();
  // Parse the parameter string into an IO configuration.
  auto parsed_params = ConfigBase::Str2Map(parameters);
  IOConfig subset_config;
  subset_config.Set(parsed_params);

  // The handle is an opaque pointer to a Dataset owned by the caller.
  const Dataset* source = reinterpret_cast<const Dataset*>(*full_data);

  // Hold the new dataset in a unique_ptr until it has been finalized.
  std::unique_ptr<Dataset> subset(
    source->Subset(used_row_indices,
                   num_used_row_indices,
                   subset_config.is_enable_sparse));
  subset->FinishLoad();

  // Ownership passes to the caller through the output handle.
  *out = subset.release();
  API_END();
}
DllExport int LGBM_DatasetFree(DatesetHandle handle) {
API_BEGIN();
delete reinterpret_cast<Dataset*>(handle);
......
......@@ -19,13 +19,11 @@ const char* Dataset::binary_file_token = "______LightGBM_Binary_File_Token______
/*! \brief Construct an empty dataset: one class, zero rows, not flagged as loaded from a binary file */
Dataset::Dataset() {
num_class_ = 1;
num_data_ = 0;
is_loading_from_binfile_ = false;
}
/*!
* \brief Construct a dataset for a known number of rows and classes and initialize its metadata
* \param num_data number of rows, stored in num_data_
* \param num_class number of classes, stored in num_class_
*/
Dataset::Dataset(data_size_t num_data, int num_class) {
num_class_ = num_class;
num_data_ = num_data;
is_loading_from_binfile_ = false;
// NOTE(review): the meaning of the two trailing -1 arguments is not visible here -- confirm against Metadata::Init
metadata_.Init(num_data_, num_class_, -1, -1);
}
......@@ -59,6 +57,18 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset, bool is_enable_spars
feature_names_ = dataset->feature_names_;
}
/*!
* \brief Create a new dataset that holds only the selected rows of this dataset
*
* The new dataset gets a copy of this dataset's feature mappers, then for every
* feature the bin of row used_indices[i] is pushed as row i of the new dataset.
* FinishLoad() is not called here; that is left to the caller.
*
* \param used_indices row indices of this dataset to copy, in output order
* \param num_used_indices number of entries in used_indices
* \param is_enable_sparse forwarded to CopyFeatureMapperFrom
* \return newly allocated dataset; the caller takes ownership
*/
Dataset* Dataset::Subset(const data_size_t* used_indices, data_size_t num_used_indices, bool is_enable_sparse) const {
  auto ret = std::unique_ptr<Dataset>(new Dataset(num_used_indices, num_class_));
  ret->CopyFeatureMapperFrom(this, is_enable_sparse);
  // NOTE(review): only bin data is copied below; no label/metadata values are
  // copied from this dataset in this function -- confirm callers set them.
  // Each iteration writes to a different feature of the new dataset.
#pragma omp parallel for schedule(guided)
  for (int fidx = 0; fidx < num_features_; ++fidx) {
    // NOTE(review): if GetIterator returns an owning raw pointer it is never
    // released here -- confirm ownership against the Bin interface.
    auto iterator = features_[fidx]->bin_data()->GetIterator(0);
    for (data_size_t i = 0; i < num_used_indices; ++i) {
      ret->features_[fidx]->PushBin(0, i, iterator->Get(used_indices[i]));
    }
  }
  // Hand the finished dataset to the caller. Without this return the function
  // would flow off its end (undefined behavior) and the unique_ptr would
  // destroy the dataset that was just built.
  return ret.release();
}
bool Dataset::SetFloatField(const char* field_name, const float* field_data, data_size_t num_element) {
std::string name(field_name);
name = Common::Trim(name);
......@@ -118,15 +128,27 @@ bool Dataset::GetIntField(const char* field_name, int64_t* out_len, const int**
}
void Dataset::SaveBinaryFile(const char* bin_filename) {
bool is_file_existed = false;
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
if (file != NULL) {
is_file_existed = true;
Log::Warning("File %s existed, cannot save binary to it", bin_filename);
fclose(file);
}
if (!is_loading_from_binfile_) {
if (!is_file_existed) {
std::string bin_filename_str(data_filename_);
// if not pass a filename, just append ".bin" of original file
if (bin_filename == nullptr || bin_filename[0] == '\0') {
bin_filename_str.append(".bin");
bin_filename = bin_filename_str.c_str();
}
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "wb");
#else
......
......@@ -142,18 +142,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Please use an additional query file or pre-partition the data");
}
}
auto dataset = std::unique_ptr<Dataset>(new Dataset());
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset());
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
if (!io_config_.use_two_round_loading) {
// read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines,&num_global_data, &used_data_indices);
......@@ -197,18 +197,18 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset());
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, io_config_.has_header, 0, label_idx_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset());
dataset->data_filename_ = filename;
dataset->num_class_ = io_config_.num_class;
dataset->metadata_.Init(filename, dataset->num_class_);
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
if (!io_config_.use_two_round_loading) {
// read data in memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
......@@ -407,7 +407,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* bin_filename, int rank, int
}
dataset->features_.shrink_to_fit();
fclose(file);
dataset->is_loading_from_binfile_ = true;
return dataset.release();
}
......
......@@ -16,6 +16,8 @@ def LoadDll():
LIB = LoadDll()
LIB.LGBM_GetLastError.restype = ctypes.c_char_p
dtype_float32 = 0
dtype_float64 = 1
dtype_int32 = 2
......@@ -33,9 +35,10 @@ def test_load_from_file(filename, reference):
if reference != None:
ref = ctypes.byref(reference)
handle = ctypes.c_void_p()
LIB.LGBM_CreateDatasetFromFile(c_str(filename),
LIB.LGBM_DatasetCreateFromFile(c_str(filename),
c_str('max_bin=15'),
ref, ctypes.byref(handle) )
print(LIB.LGBM_GetLastError())
num_data = ctypes.c_long()
LIB.LGBM_DatasetGetNumData(handle, ctypes.byref(num_data) )
num_feature = ctypes.c_long()
......@@ -46,15 +49,6 @@ def test_load_from_file(filename, reference):
def test_save_to_binary(handle, filename):
    """Ask the library to save the dataset behind ``handle`` to ``filename``."""
    target_path = c_str(filename)
    LIB.LGBM_DatasetSaveBinary(handle, target_path)
def test_load_from_binary(filename):
    """Create a dataset from the binary file ``filename``, print its size, and return its handle."""
    dataset = ctypes.c_void_p()
    LIB.LGBM_CreateDatasetFromBinaryFile(c_str(filename), ctypes.byref(dataset))

    # Query the row count, then the feature count, through output parameters.
    row_count = ctypes.c_long()
    LIB.LGBM_DatasetGetNumData(dataset, ctypes.byref(row_count))
    feature_count = ctypes.c_long()
    LIB.LGBM_DatasetGetNumFeature(dataset, ctypes.byref(feature_count))

    summary = '#data:%d #feature:%d' % (row_count.value, feature_count.value)
    print(summary)
    return dataset
def test_load_from_csr(filename, reference):
data = []
......@@ -72,7 +66,7 @@ def test_load_from_csr(filename, reference):
if reference != None:
ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromCSR(c_array(ctypes.c_int, csr.indptr),
LIB.LGBM_DatasetCreateFromCSR(c_array(ctypes.c_int, csr.indptr),
dtype_int32,
c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
......@@ -107,7 +101,7 @@ def test_load_from_csc(filename, reference):
if reference != None:
ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromCSC(c_array(ctypes.c_int, csr.indptr),
LIB.LGBM_DatasetCreateFromCSC(c_array(ctypes.c_int, csr.indptr),
dtype_int32,
c_array(ctypes.c_int, csr.indices),
csr.data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
......@@ -142,7 +136,7 @@ def test_load_from_mat(filename, reference):
if reference != None:
ref = ctypes.byref(reference)
LIB.LGBM_CreateDatasetFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
LIB.LGBM_DatasetCreateFromMat(data.ctypes.data_as(ctypes.POINTER(ctypes.c_void_p)),
dtype_float64,
mat.shape[0],
mat.shape[1],
......@@ -170,7 +164,7 @@ def test_dataset():
test_free_dataset(test)
test_save_to_binary(train, 'train.binary.bin')
test_free_dataset(train)
train = test_load_from_binary('train.binary.bin')
train = test_load_from_file('train.binary.bin', None)
test_free_dataset(train)
def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment