Commit c9d7ec24 authored by Guolin Ke's avatar Guolin Ke
Browse files

add support for csc format

parent 494fef34
...@@ -79,6 +79,32 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -79,6 +79,32 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
/*!
* \brief create a dataset from CSC format
* \param col_ptr pointer to column headers (start offset of each column inside indices/data)
* \param indices row index of each nonzero element
* \param data value of each nonzero element
* \param float_type 0 for float_32 1 for float_64
* \param ncol_ptr number of entries in col_ptr, i.e. number of columns in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
const int32_t* indices,
const void* data,
int float_type,
uint64_t ncol_ptr,
uint64_t nelem,
uint64_t num_row,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
......
...@@ -267,6 +267,15 @@ public: ...@@ -267,6 +267,15 @@ public:
} }
} }
inline void PushOneCol(int tid, data_size_t col_idx, const std::vector<std::pair<int, double>>& feature_values) {
int feature_idx = used_feature_map_[col_idx];
if (feature_idx >= 0) {
for (auto& inner_data : feature_values) {
features_[feature_idx]->PushData(tid, inner_data.first, inner_data.second);
}
}
}
inline void SetNumData(data_size_t num_data) { inline void SetNumData(data_size_t num_data) {
num_data_ = num_data; num_data_ = num_data;
} }
......
...@@ -462,6 +462,52 @@ GetRowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void* ...@@ -462,6 +462,52 @@ GetRowFunctionFromCSR(const int32_t* indptr, const int32_t* indices, const void*
} }
} }
/*!
* \brief Build an accessor that returns one column of a CSC matrix as (row index, value) pairs
* \param col_ptr column headers; column idx occupies [col_ptr[idx], col_ptr[idx + 1]) in indices/data
* \param indices row index of each stored element
* \param data stored values, read as float when float_type == 0 and as double otherwise
* \param float_type 0 selects float; any other value selects double
* \param ncol_ptr number of entries in col_ptr
* \param nelem number of stored elements
* \return function mapping a column index to the entries of that column
* \note The returned function keeps copies of the raw pointers; the caller must keep
*       the underlying arrays alive for as long as the function is used.
*/
inline std::function<std::vector<std::pair<int, double>>(int idx)>
GetColFunctionFromCSC(const int32_t* col_ptr, const int32_t* indices, const void* data, int float_type, uint64_t ncol_ptr, uint64_t nelem) {
  if (float_type == 0) {
    const float* dptr = reinterpret_cast<const float*>(data);
    // Capture by value: the closure outlives this call, so references to the
    // parameters or to dptr would dangle as soon as the function returns.
    return [col_ptr, indices, dptr, ncol_ptr, nelem](int idx) {
      CHECK(idx >= 0 && static_cast<uint64_t>(idx) + 1 < ncol_ptr);
      std::vector<std::pair<int, double>> ret;
      const int32_t start = col_ptr[idx];
      const int32_t end = col_ptr[idx + 1];
      // end is exclusive, so end == nelem is valid (it is the case for the last column)
      CHECK(start >= 0 && end >= 0 && static_cast<uint64_t>(end) <= nelem);
      if (end > start) {
        ret.reserve(static_cast<size_t>(end - start));
      }
      for (int32_t i = start; i < end; ++i) {
        ret.emplace_back(indices[i], dptr[i]);
      }
      return ret;
    };
  } else {
    const double* dptr = reinterpret_cast<const double*>(data);
    // Capture by value for the same lifetime reason as in the float branch.
    return [col_ptr, indices, dptr, ncol_ptr, nelem](int idx) {
      CHECK(idx >= 0 && static_cast<uint64_t>(idx) + 1 < ncol_ptr);
      std::vector<std::pair<int, double>> ret;
      const int32_t start = col_ptr[idx];
      const int32_t end = col_ptr[idx + 1];
      // end is exclusive, so end == nelem is valid (it is the case for the last column)
      CHECK(start >= 0 && end >= 0 && static_cast<uint64_t>(end) <= nelem);
      if (end > start) {
        ret.reserve(static_cast<size_t>(end - start));
      }
      for (int32_t i = start; i < end; ++i) {
        ret.emplace_back(indices[i], dptr[i]);
      }
      return ret;
    };
  }
}
/*!
* \brief Pick the values of one sparse column at the requested row positions
* \param data (row index, value) entries of the column, walked front to back in a single pass
* \param indices row indices to look up, processed in the given order
* \return one value per entry of indices: the stored value when the entry under the
*         scan cursor has exactly that row index, otherwise 0
* \note The scan cursor only moves forward and is shared across all lookups.
*/
inline std::vector<double> SampleFromOneColumn(const std::vector<std::pair<int, double>>& data, const std::vector<size_t>& indices) {
  std::vector<double> sampled;
  auto cursor = data.begin();
  const auto last = data.end();
  for (const size_t wanted_row : indices) {
    // skip stored entries that lie before the wanted row
    for (; cursor != last && static_cast<size_t>(cursor->first) < wanted_row; ++cursor) {
    }
    const bool hit = (cursor != last) && (static_cast<size_t>(cursor->first) == wanted_row);
    sampled.push_back(hit ? cursor->second : 0);
  }
  return sampled;
}
} // namespace Common } // namespace Common
......
...@@ -193,6 +193,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -193,6 +193,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
int float_type, int float_type,
uint64_t nindptr, uint64_t nindptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_col,
const char* parameters, const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out) { DatesetHandle* out) {
...@@ -229,6 +230,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -229,6 +230,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
sample_values[inner_data.first][i] = inner_data.second; sample_values[inner_data.first][i] = inner_data.second;
} }
} }
CHECK(num_col >= sample_values.size());
ret = loader.CostructFromSampleData(sample_values, nrow); ret = loader.CostructFromSampleData(sample_values, nrow);
} else { } else {
ret = new Dataset(); ret = new Dataset();
...@@ -249,6 +251,54 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr, ...@@ -249,6 +251,54 @@ DllExport int LGBM_CreateDatasetFromCSR(const int32_t* indptr,
return 0; return 0;
} }
/*!
* \brief Create a dataset from a matrix in CSC format
*
* Without a reference dataset, rows are sampled, the sampled values of every column
* are collected, and the dataset is constructed from that sample. With a reference
* dataset, the feature metadata is copied from it instead. In both cases every
* column is then pushed into the dataset and loading is finished.
*
* \param col_ptr column headers, ncol_ptr entries
* \param indices row index of each stored element
* \param data stored values (float when float_type == 0, double otherwise)
* \param float_type 0 for float_32, otherwise float_64
* \param ncol_ptr number of entries in col_ptr (number of columns + 1); must be > 0
* \param nelem number of stored elements
* \param num_row number of rows; used as given (narrowed to int32_t)
* \param parameters additional parameters, parsed by OverallConfig::LoadFromString
* \param reference dataset whose feature metadata is reused; nullptr to build from sampled data
* \param out receives the created dataset
* \return 0 on the normal path
*/
DllExport int LGBM_CreateDatasetFromCSC(const int32_t* col_ptr,
  const int32_t* indices,
  const void* data,
  int float_type,
  uint64_t ncol_ptr,
  uint64_t nelem,
  uint64_t num_row,
  const char* parameters,
  const DatesetHandle* reference,
  DatesetHandle* out) {
  // ncol_ptr is unsigned: with ncol_ptr == 0 the expression ncol_ptr - 1 would wrap
  // around to a huge value, so reject it before deriving the column count.
  CHECK(ncol_ptr > 0);
  // signed column count, computed once, so the OpenMP loops compare int against int
  const int num_col = static_cast<int>(ncol_ptr - 1);
  OverallConfig config;
  config.LoadFromString(parameters);
  DatasetLoader loader(config.io_config, nullptr);
  Dataset* ret = nullptr;
  auto get_col_fun = Common::GetColFunctionFromCSC(col_ptr, indices, data, float_type, ncol_ptr, nelem);
  int32_t nrow = static_cast<int32_t>(num_row);
  if (reference == nullptr) {
    Log::Warning("Construct from CSC format is not efficient");
    // sample data first
    Random rand(config.io_config.data_random_seed);
    const size_t sample_cnt = static_cast<size_t>(nrow < config.io_config.bin_construct_sample_cnt ? nrow : config.io_config.bin_construct_sample_cnt);
    auto sample_indices = rand.Sample(nrow, sample_cnt);
    std::vector<std::vector<double>> sample_values(num_col);
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      auto cur_col = get_col_fun(i);
      sample_values[i] = Common::SampleFromOneColumn(cur_col, sample_indices);
    }
    ret = loader.CostructFromSampleData(sample_values, nrow);
  } else {
    ret = new Dataset();
    // need to set num_data first
    ret->SetNumData(nrow);
    reinterpret_cast<const Dataset*>(*reference)->CopyFeatureMetadataTo(ret, config.io_config.is_enable_sparse);
  }
  // push every column into the dataset
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i < num_col; ++i) {
    const int tid = omp_get_thread_num();
    auto one_col = get_col_fun(i);
    ret->PushOneCol(tid, i, one_col);
  }
  ret->FinishLoad();
  *out = ret;
  return 0;
}
DllExport int LGBM_DatasetFree(DatesetHandle* handle) { DllExport int LGBM_DatasetFree(DatesetHandle* handle) {
auto dataset = reinterpret_cast<Dataset*>(*handle); auto dataset = reinterpret_cast<Dataset*>(*handle);
delete dataset; delete dataset;
...@@ -268,7 +318,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle, ...@@ -268,7 +318,7 @@ DllExport int LGBM_DatasetSetField(DatesetHandle handle,
uint64_t num_element, uint64_t num_element,
int type) { int type) {
auto dataset = reinterpret_cast<Dataset*>(handle); auto dataset = reinterpret_cast<Dataset*>(handle);
dataset->SetField(field_name, field_data, num_element, type); dataset->SetField(field_name, field_data, static_cast<int32_t>(num_element), type);
return 0; return 0;
} }
......
...@@ -30,16 +30,16 @@ ...@@ -30,16 +30,16 @@
</PropertyGroup> </PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'"> <PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Debug_mpi|x64'">
<PlatformToolset>v120</PlatformToolset> <PlatformToolset>v140</PlatformToolset>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> <PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<PlatformToolset>v120</PlatformToolset> <PlatformToolset>v140</PlatformToolset>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> <PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<PlatformToolset>v120</PlatformToolset> <PlatformToolset>v140</PlatformToolset>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Release_mpi|x64'"> <PropertyGroup Label="Configuration" Condition="'$(Configuration)|$(Platform)'=='Release_mpi|x64'">
<PlatformToolset>v120</PlatformToolset> <PlatformToolset>v140</PlatformToolset>
</PropertyGroup> </PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings"> <ImportGroup Label="ExtensionSettings">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment