Commit 7e186a57 authored by ebernhardson's avatar ebernhardson Committed by Guolin Ke
Browse files

Experimental support for HDFS (#1243)

* Read and write datasets from HDFS.
* Only enabled when cmake is run with -DUSE_HDFS:BOOL=TRUE
* Introduces VirtualFile(Reader|Writer) to abstract VFS differences
parent 7501faa6
...@@ -12,6 +12,7 @@ OPTION(USE_MPI "MPI based parallel learning" OFF) ...@@ -12,6 +12,7 @@ OPTION(USE_MPI "MPI based parallel learning" OFF)
OPTION(USE_OPENMP "Enable OpenMP" ON) OPTION(USE_OPENMP "Enable OpenMP" ON)
OPTION(USE_GPU "Enable GPU-acclerated training (EXPERIMENTAL)" OFF) OPTION(USE_GPU "Enable GPU-acclerated training (EXPERIMENTAL)" OFF)
OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF) OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF)
OPTION(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "4.8.2") if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "4.8.2")
...@@ -86,6 +87,15 @@ if(USE_GPU) ...@@ -86,6 +87,15 @@ if(USE_GPU)
ADD_DEFINITIONS(-DUSE_GPU) ADD_DEFINITIONS(-DUSE_GPU)
endif(USE_GPU) endif(USE_GPU)
if(USE_HDFS)
find_package(JNI REQUIRED)
find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED)
find_library(HDFS_LIB NAMES hdfs REQUIRED)
include_directories(${HDFS_INCLUDE_DIR})
ADD_DEFINITIONS(-DUSE_HDFS)
SET(HDFS_CXX_LIBRARIES ${HDFS_LIB} ${JAVA_JVM_LIBRARY})
endif(USE_HDFS)
if(UNIX OR MINGW OR CYGWIN) if(UNIX OR MINGW OR CYGWIN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas")
endif() endif()
...@@ -173,6 +183,11 @@ if(USE_GPU) ...@@ -173,6 +183,11 @@ if(USE_GPU)
TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES}) TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES})
endif(USE_GPU) endif(USE_GPU)
if(USE_HDFS)
TARGET_LINK_LIBRARIES(lightgbm ${HDFS_CXX_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${HDFS_CXX_LIBRARIES})
endif(USE_HDFS)
if(WIN32 AND (MINGW OR CYGWIN)) if(WIN32 AND (MINGW OR CYGWIN))
TARGET_LINK_LIBRARIES(lightgbm Ws2_32) TARGET_LINK_LIBRARIES(lightgbm Ws2_32)
TARGET_LINK_LIBRARIES(_lightgbm Ws2_32) TARGET_LINK_LIBRARIES(_lightgbm Ws2_32)
......
...@@ -98,7 +98,7 @@ public: ...@@ -98,7 +98,7 @@ public:
* \brief Save binary data to file * \brief Save binary data to file
* \param file File want to write * \param file File want to write
*/ */
void SaveBinaryToFile(FILE* file) const; void SaveBinaryToFile(const VirtualFileWriter* writer) const;
/*! /*!
* \brief Mapping bin into feature value * \brief Mapping bin into feature value
* \param bin * \param bin
...@@ -308,7 +308,7 @@ public: ...@@ -308,7 +308,7 @@ public:
* \brief Save binary data to file * \brief Save binary data to file
* \param file File want to write * \param file File want to write
*/ */
virtual void SaveBinaryToFile(FILE* file) const = 0; virtual void SaveBinaryToFile(const VirtualFileWriter* writer) const = 0;
/*! /*!
* \brief Load from memory * \brief Load from memory
......
...@@ -99,7 +99,7 @@ public: ...@@ -99,7 +99,7 @@ public:
* \brief Save binary data to file * \brief Save binary data to file
* \param file File want to write * \param file File want to write
*/ */
void SaveBinaryToFile(FILE* file) const; void SaveBinaryToFile(const VirtualFileWriter* writer) const;
/*! /*!
* \brief Get sizes in byte of this object * \brief Get sizes in byte of this object
......
...@@ -191,13 +191,13 @@ public: ...@@ -191,13 +191,13 @@ public:
* \brief Save binary data to file * \brief Save binary data to file
* \param file File want to write * \param file File want to write
*/ */
void SaveBinaryToFile(FILE* file) const { void SaveBinaryToFile(const VirtualFileWriter* writer) const {
fwrite(&is_sparse_, sizeof(is_sparse_), 1, file); writer->Write(&is_sparse_, sizeof(is_sparse_));
fwrite(&num_feature_, sizeof(num_feature_), 1, file); writer->Write(&num_feature_, sizeof(num_feature_));
for (int i = 0; i < num_feature_; ++i) { for (int i = 0; i < num_feature_; ++i) {
bin_mappers_[i]->SaveBinaryToFile(file); bin_mappers_[i]->SaveBinaryToFile(writer);
} }
bin_data_->SaveBinaryToFile(file); bin_data_->SaveBinaryToFile(writer);
} }
/*! /*!
* \brief Get sizes in byte of this object * \brief Get sizes in byte of this object
......
#ifndef LIGHTGBM_UTILS_FILE_IO_H_
#define LIGHTGBM_UTILS_FILE_IO_H_

#include <cstddef>
#include <memory>
#include <string>

namespace LightGBM {

/*!
* \brief An interface for writing files from buffers
*/
struct VirtualFileWriter {
  virtual ~VirtualFileWriter() {}
  /*!
  * \brief Initialize the writer
  * \return True when the file is available for writes
  */
  virtual bool Init() = 0;
  /*!
  * \brief Append buffer to file
  * \param data Buffer to write from
  * \param bytes Number of bytes to write from buffer
  * \return Number of bytes written
  */
  virtual size_t Write(const void* data, size_t bytes) const = 0;
  /*!
  * \brief Create appropriate writer for filename
  * \param filename Filename of the data
  * \return File writer instance
  */
  static std::unique_ptr<VirtualFileWriter> Make(const std::string& filename);
  /*!
  * \brief Check filename existence
  * \param filename Filename of the data
  * \return True when the file exists
  */
  static bool Exists(const std::string& filename);
};

/*!
* \brief An interface for reading files into buffers
*/
struct VirtualFileReader {
  virtual ~VirtualFileReader() {}
  /*!
  * \brief Initialize the reader
  * \return True when the file is available for read
  */
  virtual bool Init() = 0;
  /*!
  * \brief Read data into buffer
  * \param buffer Buffer to read data into
  * \param bytes Number of bytes to read
  * \return Number of bytes read
  */
  virtual size_t Read(void* buffer, size_t bytes) const = 0;
  /*!
  * \brief Create appropriate reader for filename
  * \param filename Filename of the data
  * \return File reader instance
  */
  static std::unique_ptr<VirtualFileReader> Make(const std::string& filename);
};

}  // namespace LightGBM

#endif  // LIGHTGBM_UTILS_FILE_IO_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <thread> #include <thread>
#include <memory> #include <memory>
#include <algorithm> #include <algorithm>
#include "file_io.h"
namespace LightGBM{ namespace LightGBM{
...@@ -23,14 +24,8 @@ public: ...@@ -23,14 +24,8 @@ public:
* \process_fun Process function * \process_fun Process function
*/ */
static size_t Read(const char* filename, int skip_bytes, const std::function<size_t (const char*, size_t)>& process_fun) { static size_t Read(const char* filename, int skip_bytes, const std::function<size_t (const char*, size_t)>& process_fun) {
FILE* file; auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) {
#ifdef _MSC_VER
fopen_s(&file, filename, "rb");
#else
file = fopen(filename, "rb");
#endif
if (file == NULL) {
return 0; return 0;
} }
size_t cnt = 0; size_t cnt = 0;
...@@ -42,16 +37,17 @@ public: ...@@ -42,16 +37,17 @@ public:
size_t read_cnt = 0; size_t read_cnt = 0;
if (skip_bytes > 0) { if (skip_bytes > 0) {
// skip first k bytes // skip first k bytes
read_cnt = fread(buffer_process.data(), 1, skip_bytes, file); read_cnt = reader->Read(buffer_process.data(), skip_bytes);
} }
// read first block // read first block
read_cnt = fread(buffer_process.data(), 1, buffer_size, file); read_cnt = reader->Read(buffer_process.data(), buffer_size);
size_t last_read_cnt = 0; size_t last_read_cnt = 0;
while (read_cnt > 0) { while (read_cnt > 0) {
// start read thread // start read thread
std::thread read_worker = std::thread( std::thread read_worker = std::thread(
[file, &buffer_read, buffer_size, &last_read_cnt] { [&reader, &buffer_read, buffer_size, &last_read_cnt] {
last_read_cnt = fread(buffer_read.data(), 1, buffer_size, file); last_read_cnt = reader->Read(buffer_read.data(), buffer_size);
} }
); );
// start process // start process
...@@ -62,8 +58,6 @@ public: ...@@ -62,8 +58,6 @@ public:
std::swap(buffer_process, buffer_read); std::swap(buffer_process, buffer_read);
read_cnt = last_read_cnt; read_cnt = last_read_cnt;
} }
// close file
fclose(file);
return cnt; return cnt;
} }
......
...@@ -28,36 +28,29 @@ public: ...@@ -28,36 +28,29 @@ public:
TextReader(const char* filename, bool is_skip_first_line): TextReader(const char* filename, bool is_skip_first_line):
filename_(filename), is_skip_first_line_(is_skip_first_line){ filename_(filename), is_skip_first_line_(is_skip_first_line){
if (is_skip_first_line_) { if (is_skip_first_line_) {
FILE* file; auto reader = VirtualFileReader::Make(filename);
#ifdef _MSC_VER if (!reader->Init()) {
fopen_s(&file, filename, "r");
#else
file = fopen(filename, "r");
#endif
if (file == NULL) {
Log::Fatal("Could not open %s", filename); Log::Fatal("Could not open %s", filename);
} }
std::stringstream str_buf; std::stringstream str_buf;
int read_c = -1; char read_c;
read_c = fgetc(file); size_t nread = reader->Read(&read_c, 1);
while (read_c != EOF) { while (nread == 1) {
char tmp_ch = static_cast<char>(read_c); if (read_c == '\n' || read_c == '\r') {
if (tmp_ch == '\n' || tmp_ch == '\r') {
break; break;
} }
str_buf << tmp_ch; str_buf << read_c;
++skip_bytes_; ++skip_bytes_;
read_c = fgetc(file); nread = reader->Read(&read_c, 1);
} }
if (static_cast<char>(read_c) == '\r') { if (read_c == '\r') {
read_c = fgetc(file); reader->Read(&read_c, 1);
++skip_bytes_; ++skip_bytes_;
} }
if (static_cast<char>(read_c) == '\n') { if (read_c == '\n') {
read_c = fgetc(file); reader->Read(&read_c, 1);
++skip_bytes_; ++skip_bytes_;
} }
fclose(file);
first_line_ = str_buf.str(); first_line_ = str_buf.str();
Log::Debug("Skipped header \"%s\" in file %s", first_line_.c_str(), filename_); Log::Debug("Skipped header \"%s\" in file %s", first_line_.c_str(), filename_);
} }
...@@ -151,25 +144,18 @@ public: ...@@ -151,25 +144,18 @@ public:
std::vector<char> ReadContent(size_t* out_len) { std::vector<char> ReadContent(size_t* out_len) {
std::vector<char> ret; std::vector<char> ret;
*out_len = 0; *out_len = 0;
FILE* file; auto reader = VirtualFileReader::Make(filename_);
#ifdef _MSC_VER if (!reader->Init()) {
fopen_s(&file, filename_, "rb");
#else
file = fopen(filename_, "rb");
#endif
if (file == NULL) {
return ret; return ret;
} }
const size_t buffer_size = 16 * 1024 * 1024; const size_t buffer_size = 16 * 1024 * 1024;
auto buffer_read = std::vector<char>(buffer_size); auto buffer_read = std::vector<char>(buffer_size);
size_t read_cnt = 0; size_t read_cnt = 0;
do { do {
read_cnt = fread(buffer_read.data(), 1, buffer_size, file); read_cnt = reader->Read(buffer_read.data(), buffer_size);
ret.insert(ret.end(), buffer_read.begin(), buffer_read.begin() + read_cnt); ret.insert(ret.end(), buffer_read.begin(), buffer_read.begin() + read_cnt);
*out_len += read_cnt; *out_len += read_cnt;
} while (read_cnt > 0); } while (read_cnt > 0);
// close file
fclose(file);
return ret; return ret;
} }
......
...@@ -128,15 +128,8 @@ public: ...@@ -128,15 +128,8 @@ public:
* \param result_filename Filename of output result * \param result_filename Filename of output result
*/ */
void Predict(const char* data_filename, const char* result_filename, bool has_header) { void Predict(const char* data_filename, const char* result_filename, bool has_header) {
FILE* result_file; auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
#ifdef _MSC_VER
fopen_s(&result_file, result_filename, "w");
#else
result_file = fopen(result_filename, "w");
#endif
if (result_file == NULL) {
Log::Fatal("Prediction results file %s cannot be found.", result_filename); Log::Fatal("Prediction results file %s cannot be found.", result_filename);
} }
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx())); auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
...@@ -189,7 +182,7 @@ public: ...@@ -189,7 +182,7 @@ public:
}; };
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &parser_fun, &result_file] [this, &parser_fun, &writer]
(data_size_t, const std::vector<std::string>& lines) { (data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features; std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size()); std::vector<std::string> result_to_write(lines.size());
...@@ -209,11 +202,11 @@ public: ...@@ -209,11 +202,11 @@ public:
} }
OMP_THROW_EX(); OMP_THROW_EX();
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) { for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
fprintf(result_file, "%s\n", result_to_write[i].c_str()); writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
writer->Write("\n", 1);
} }
}; };
predict_data_reader.ReadAllAndProcessParallel(process_fun); predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file);
} }
private: private:
......
#include <LightGBM/utils/common.h> #include <LightGBM/utils/common.h>
#include <LightGBM/utils/file_io.h>
#include <LightGBM/bin.h> #include <LightGBM/bin.h>
#include "dense_bin.hpp" #include "dense_bin.hpp"
...@@ -455,19 +456,19 @@ namespace LightGBM { ...@@ -455,19 +456,19 @@ namespace LightGBM {
} }
} }
void BinMapper::SaveBinaryToFile(FILE* file) const { void BinMapper::SaveBinaryToFile(const VirtualFileWriter* writer) const {
fwrite(&num_bin_, sizeof(num_bin_), 1, file); writer->Write(&num_bin_, sizeof(num_bin_));
fwrite(&missing_type_, sizeof(missing_type_), 1, file); writer->Write(&missing_type_, sizeof(missing_type_));
fwrite(&is_trival_, sizeof(is_trival_), 1, file); writer->Write(&is_trival_, sizeof(is_trival_));
fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file); writer->Write(&sparse_rate_, sizeof(sparse_rate_));
fwrite(&bin_type_, sizeof(bin_type_), 1, file); writer->Write(&bin_type_, sizeof(bin_type_));
fwrite(&min_val_, sizeof(min_val_), 1, file); writer->Write(&min_val_, sizeof(min_val_));
fwrite(&max_val_, sizeof(max_val_), 1, file); writer->Write(&max_val_, sizeof(max_val_));
fwrite(&default_bin_, sizeof(default_bin_), 1, file); writer->Write(&default_bin_, sizeof(default_bin_));
if (bin_type_ == BinType::NumericalBin) { if (bin_type_ == BinType::NumericalBin) {
fwrite(bin_upper_bound_.data(), sizeof(double), num_bin_, file); writer->Write(bin_upper_bound_.data(), sizeof(double) * num_bin_);
} else { } else {
fwrite(bin_2_categorical_.data(), sizeof(int), num_bin_, file); writer->Write(bin_2_categorical_.data(), sizeof(int) * num_bin_);
} }
} }
......
...@@ -522,31 +522,20 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -522,31 +522,20 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
bin_filename = bin_filename_str.c_str(); bin_filename = bin_filename_str.c_str();
} }
bool is_file_existed = false; bool is_file_existed = false;
FILE* file;
#ifdef _MSC_VER if (VirtualFileWriter::Exists(bin_filename)) {
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
if (file != NULL) {
is_file_existed = true; is_file_existed = true;
Log::Warning("File %s existed, cannot save binary to it", bin_filename); Log::Warning("File %s existed, cannot save binary to it", bin_filename);
fclose(file);
} }
if (!is_file_existed) { if (!is_file_existed) {
#ifdef _MSC_VER auto writer = VirtualFileWriter::Make(bin_filename);
fopen_s(&file, bin_filename, "wb"); if (!writer->Init()) {
#else
file = fopen(bin_filename, "wb");
#endif
if (file == NULL) {
Log::Fatal("Cannot write binary data to %s ", bin_filename); Log::Fatal("Cannot write binary data to %s ", bin_filename);
} }
Log::Info("Saving data to binary file %s", bin_filename); Log::Info("Saving data to binary file %s", bin_filename);
size_t size_of_token = std::strlen(binary_file_token); size_t size_of_token = std::strlen(binary_file_token);
fwrite(binary_file_token, sizeof(char), size_of_token, file); writer->Write(binary_file_token, size_of_token);
// get size of header // get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_) size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
...@@ -555,44 +544,43 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -555,44 +544,43 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int); size_of_header += feature_names_[i].size() + sizeof(int);
} }
fwrite(&size_of_header, sizeof(size_of_header), 1, file); writer->Write(&size_of_header, sizeof(size_of_header));
// write header // write header
fwrite(&num_data_, sizeof(num_data_), 1, file); writer->Write(&num_data_, sizeof(num_data_));
fwrite(&num_features_, sizeof(num_features_), 1, file); writer->Write(&num_features_, sizeof(num_features_));
fwrite(&num_total_features_, sizeof(num_total_features_), 1, file); writer->Write(&num_total_features_, sizeof(num_total_features_));
fwrite(&label_idx_, sizeof(label_idx_), 1, file); writer->Write(&label_idx_, sizeof(label_idx_));
fwrite(used_feature_map_.data(), sizeof(int), num_total_features_, file); writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_);
fwrite(&num_groups_, sizeof(num_groups_), 1, file); writer->Write(&num_groups_, sizeof(num_groups_));
fwrite(real_feature_idx_.data(), sizeof(int), num_features_, file); writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_);
fwrite(feature2group_.data(), sizeof(int), num_features_, file); writer->Write(feature2group_.data(), sizeof(int) * num_features_);
fwrite(feature2subfeature_.data(), sizeof(int), num_features_, file); writer->Write(feature2subfeature_.data(), sizeof(int) * num_features_);
fwrite(group_bin_boundaries_.data(), sizeof(uint64_t), num_groups_ + 1, file); writer->Write(group_bin_boundaries_.data(), sizeof(uint64_t) * (num_groups_ + 1));
fwrite(group_feature_start_.data(), sizeof(int), num_groups_, file); writer->Write(group_feature_start_.data(), sizeof(int) * num_groups_);
fwrite(group_feature_cnt_.data(), sizeof(int), num_groups_, file); writer->Write(group_feature_cnt_.data(), sizeof(int) * num_groups_);
// write feature names // write feature names
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast<int>(feature_names_[i].size()); int str_len = static_cast<int>(feature_names_[i].size());
fwrite(&str_len, sizeof(int), 1, file); writer->Write(&str_len, sizeof(int));
const char* c_str = feature_names_[i].c_str(); const char* c_str = feature_names_[i].c_str();
fwrite(c_str, sizeof(char), str_len, file); writer->Write(c_str, sizeof(char) * str_len);
} }
// get size of meta data // get size of meta data
size_t size_of_metadata = metadata_.SizesInByte(); size_t size_of_metadata = metadata_.SizesInByte();
fwrite(&size_of_metadata, sizeof(size_of_metadata), 1, file); writer->Write(&size_of_metadata, sizeof(size_of_metadata));
// write meta data // write meta data
metadata_.SaveBinaryToFile(file); metadata_.SaveBinaryToFile(writer.get());
// write feature data // write feature data
for (int i = 0; i < num_groups_; ++i) { for (int i = 0; i < num_groups_; ++i) {
// get size of feature // get size of feature
size_t size_of_feature = feature_groups_[i]->SizesInByte(); size_t size_of_feature = feature_groups_[i]->SizesInByte();
fwrite(&size_of_feature, sizeof(size_of_feature), 1, file); writer->Write(&size_of_feature, sizeof(size_of_feature));
// write feature // write feature
feature_groups_[i]->SaveBinaryToFile(file); feature_groups_[i]->SaveBinaryToFile(writer.get());
} }
fclose(file);
} }
} }
......
...@@ -264,14 +264,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -264,14 +264,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) { Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
FILE* file; auto reader = VirtualFileReader::Make(bin_filename);
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
dataset->data_filename_ = data_filename; dataset->data_filename_ = data_filename;
if (file == NULL) { if (!reader->Init()) {
Log::Fatal("Could not read binary data from %s", bin_filename); Log::Fatal("Could not read binary data from %s", bin_filename);
} }
...@@ -281,8 +276,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -281,8 +276,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// check token // check token
size_t size_of_token = std::strlen(Dataset::binary_file_token); size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
if (read_cnt != size_of_token) { if (read_cnt != sizeof(char) * size_of_token) {
Log::Fatal("Binary file error: token has the wrong size"); Log::Fatal("Binary file error: token has the wrong size");
} }
if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) { if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
...@@ -290,9 +285,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -290,9 +285,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
} }
// read size of header // read size of header
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file); read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != 1) { if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: header has the wrong size"); Log::Fatal("Binary file error: header has the wrong size");
} }
...@@ -304,7 +299,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -304,7 +299,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size); buffer.resize(buffer_size);
} }
// read header // read header
read_cnt = fread(buffer.data(), 1, size_of_head, file); read_cnt = reader->Read(buffer.data(), size_of_head);
if (read_cnt != size_of_head) { if (read_cnt != size_of_head) {
Log::Fatal("Binary file error: header is incorrect"); Log::Fatal("Binary file error: header is incorrect");
...@@ -389,9 +384,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -389,9 +384,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
} }
// read size of meta data // read size of meta data
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file); read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != 1) { if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: meta data has the wrong size"); Log::Fatal("Binary file error: meta data has the wrong size");
} }
...@@ -403,7 +398,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -403,7 +398,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size); buffer.resize(buffer_size);
} }
// read meta data // read meta data
read_cnt = fread(buffer.data(), 1, size_of_metadata, file); read_cnt = reader->Read(buffer.data(), size_of_metadata);
if (read_cnt != size_of_metadata) { if (read_cnt != size_of_metadata) {
Log::Fatal("Binary file error: meta data is incorrect"); Log::Fatal("Binary file error: meta data is incorrect");
...@@ -451,8 +446,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -451,8 +446,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// read feature data // read feature data
for (int i = 0; i < dataset->num_groups_; ++i) { for (int i = 0; i < dataset->num_groups_; ++i) {
// read feature size // read feature size
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file); read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != 1) { if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: feature %d has the wrong size", i); Log::Fatal("Binary file error: feature %d has the wrong size", i);
} }
size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data())); size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
...@@ -462,7 +457,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -462,7 +457,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size); buffer.resize(buffer_size);
} }
read_cnt = fread(buffer.data(), 1, size_of_feature, file); read_cnt = reader->Read(buffer.data(), size_of_feature);
if (read_cnt != size_of_feature) { if (read_cnt != size_of_feature) {
Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt); Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
...@@ -474,7 +469,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -474,7 +469,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
)); ));
} }
dataset->feature_groups_.shrink_to_fit(); dataset->feature_groups_.shrink_to_fit();
fclose(file);
dataset->is_finish_load_ = true; dataset->is_finish_load_ = true;
return dataset.release(); return dataset.release();
} }
...@@ -1068,22 +1062,12 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { ...@@ -1068,22 +1062,12 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
std::string bin_filename(filename); std::string bin_filename(filename);
bin_filename.append(".bin"); bin_filename.append(".bin");
FILE* file; auto reader = VirtualFileReader::Make(bin_filename.c_str());
#ifdef _MSC_VER if (!reader->Init()) {
fopen_s(&file, bin_filename.c_str(), "rb");
#else
file = fopen(bin_filename.c_str(), "rb");
#endif
if (file == NULL) {
bin_filename = std::string(filename); bin_filename = std::string(filename);
#ifdef _MSC_VER reader = VirtualFileReader::Make(bin_filename.c_str());
fopen_s(&file, bin_filename.c_str(), "rb"); if (!reader->Init()) {
#else
file = fopen(bin_filename.c_str(), "rb");
#endif
if (file == NULL) {
Log::Fatal("cannot open data file %s", bin_filename.c_str()); Log::Fatal("cannot open data file %s", bin_filename.c_str());
} }
} }
...@@ -1092,8 +1076,7 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) { ...@@ -1092,8 +1076,7 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
auto buffer = std::vector<char>(buffer_size); auto buffer = std::vector<char>(buffer_size);
// read size of token // read size of token
size_t size_of_token = std::strlen(Dataset::binary_file_token); size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file); size_t read_cnt = reader->Read(buffer.data(), size_of_token);
fclose(file);
if (read_cnt == size_of_token if (read_cnt == size_of_token
&& std::string(buffer.data()) == std::string(Dataset::binary_file_token)) { && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
return bin_filename; return bin_filename;
......
...@@ -302,8 +302,8 @@ public: ...@@ -302,8 +302,8 @@ public:
} }
} }
void SaveBinaryToFile(FILE* file) const override { void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
fwrite(data_.data(), sizeof(VAL_T), num_data_, file); writer->Write(data_.data(), sizeof(VAL_T) * num_data_);
} }
size_t SizesInByte() const override { size_t SizesInByte() const override {
......
...@@ -367,8 +367,8 @@ public: ...@@ -367,8 +367,8 @@ public:
} }
} }
void SaveBinaryToFile(FILE* file) const override { void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
fwrite(data_.data(), sizeof(uint8_t), data_.size(), file); writer->Write(data_.data(), sizeof(uint8_t) * data_.size());
} }
size_t SizesInByte() const override { size_t SizesInByte() const override {
......
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/file_io.h>

#include <algorithm>
#include <cerrno>
#include <cstring>
#include <limits>
#include <sstream>
#include <unordered_map>

#ifdef USE_HDFS
#include <fcntl.h>
#include <hdfs.h>
#endif
namespace LightGBM{
struct LocalFile : VirtualFileReader, VirtualFileWriter {
LocalFile(const std::string& filename, const std::string& mode) : filename_(filename), mode_(mode) {}
virtual ~LocalFile() {
if (file_ != NULL) {
fclose(file_);
}
}
bool Init() {
if (file_ == NULL) {
#if _MSC_VER
fopen_s(&file_, filename_.c_str(), mode_.c_str());
#else
file_ = fopen(filename_.c_str(), mode_.c_str());
#endif
}
return file_ != NULL;
}
bool Exists() const {
LocalFile file(filename_, "rb");
return file.Init();
}
size_t Read(void* buffer, size_t bytes) const {
return fread(buffer, 1, bytes, file_);
}
size_t Write(const void* buffer, size_t bytes) const {
return fwrite(buffer, bytes, 1, file_) == 1 ? bytes : 0;
}
private:
FILE* file_ = NULL;
const std::string filename_;
const std::string mode_;
};
const std::string kHdfsProto = "hdfs://";
#ifdef USE_HDFS
/*!
 * \brief VFS adapter backed by libhdfs. One object acts as both reader and
 *        writer; the role is fixed by the open flags given at construction.
 *
 * Namenode connections are shared process-wide through fs_cache_ and are never
 * disconnected. NOTE(review): fs_cache_ is unsynchronized — assumes files are
 * only constructed from a single thread; confirm against callers.
 */
struct HdfsFile : VirtualFileReader, VirtualFileWriter {
  HdfsFile(const std::string& filename, int flags) : filename_(filename), flags_(flags) {}
  ~HdfsFile() {
    if (file_ != NULL) {
      hdfsCloseFile(fs_, file_);
    }
    // fs_ is intentionally not disconnected: the connection lives in fs_cache_.
  }
  /*!
   * \brief Lazily connects to the namenode and opens the file.
   *        For reads the file must already exist; for O_WRONLY it is opened
   *        (created) unconditionally.
   * \return true when the file handle is ready for Read/Write.
   */
  bool Init() {
    if (file_ == NULL) {
      if (fs_ == NULL) {
        fs_ = getHdfsFS(filename_);
      }
      if (fs_ != NULL && (flags_ == O_WRONLY || 0 == hdfsExists(fs_, filename_.c_str()))) {
        file_ = hdfsOpenFile(fs_, filename_.c_str(), flags_, 0, 0, 0);
      }
    }
    return file_ != NULL;
  }
  // True when the namenode is reachable and the path exists (hdfsExists returns 0).
  bool Exists() const {
    if (fs_ == NULL) {
      fs_ = getHdfsFS(filename_);
    }
    return fs_ != NULL && 0 == hdfsExists(fs_, filename_.c_str());
  }
  // Reads up to `bytes` into `data`; returns the number actually read (may be
  // short at end of file).
  size_t Read(void* data, size_t bytes) const {
    return FileOperation<void*>(data, bytes, &hdfsRead);
  }
  // Writes up to `bytes` from `data`; returns the number actually written.
  size_t Write(const void* data, size_t bytes) const {
    return FileOperation<const void*>(data, bytes, &hdfsWrite);
  }
 private:
  // Signature shared by hdfsRead and hdfsWrite.
  template <typename BufferType>
  using fileOp = tSize(*)(hdfsFS, hdfsFile, BufferType, tSize);
  /*!
   * \brief Drives `op` until `bytes` are transferred, chunking requests so each
   *        fits in tSize, retrying on EINTR, and stopping early when the
   *        operation reports 0 (e.g. end of file).
   * \return number of bytes actually transferred (bytes - remain).
   */
  template <typename BufferType>
  inline size_t FileOperation(BufferType data, size_t bytes, fileOp<BufferType> op) const {
    char* buffer = (char *)data;
    size_t remain = bytes;
    while (remain != 0) {
      size_t nmax = static_cast<size_t>(std::numeric_limits<tSize>::max());
      tSize ret = op(fs_, file_, buffer, std::min(nmax, remain));
      if (ret > 0) {
        size_t n = static_cast<size_t>(ret);
        remain -= n;
        buffer += n;
      } else if (ret == 0) {
        break;  // no progress possible (EOF for reads); return what we have
      } else if (errno != EINTR) {
        // ret < 0 and not an interrupted call: unrecoverable; Log::Fatal does not return.
        Log::Fatal("Failed hdfs file operation [%s]", strerror(errno));
      }
      // ret < 0 with errno == EINTR: interrupted system call, retry the chunk.
    }
    return bytes - remain;
  }
  /*!
   * \brief Resolves "hdfs://host:port/..." to a cached namenode connection.
   *        Warns and returns NULL when the URI has no "host:port" authority.
   *        One connection per host:port is kept for the life of the process.
   */
  static hdfsFS getHdfsFS(const std::string& uri) {
    size_t end = uri.find("/", kHdfsProto.length());
    if (uri.find(kHdfsProto) != 0 || end == std::string::npos) {
      Log::Warning("Bad hdfs uri, no namenode found [%s]", uri.c_str());
      return NULL;
    }
    std::string hostport = uri.substr(kHdfsProto.length(), end - kHdfsProto.length());
    if (fs_cache_.count(hostport) == 0) {
      fs_cache_[hostport] = makeHdfsFs(hostport);
    }
    return fs_cache_[hostport];
  }
  /*!
   * \brief Connects to the namenode named by "host[:port]".
   *        Trailing garbage after the port (iss not at eof) yields NULL.
   *        NOTE(review): an authority without ":port" connects with port 0 —
   *        presumably libhdfs treats 0 as "use the configured default"; verify.
   */
  static hdfsFS makeHdfsFs(const std::string& hostport) {
    std::istringstream iss(hostport);
    std::string host;
    tPort port = 0;
    std::getline(iss, host, ':');
    iss >> port;
    hdfsFS fs = iss.eof() ? hdfsConnect(host.c_str(), port) : NULL;
    if (fs == NULL) {
      Log::Warning("Could not connect to hdfs namenode [%s]", hostport.c_str());
    }
    return fs;
  }
  // mutable: Exists()/Init() lazily establish the connection from const context.
  mutable hdfsFS fs_ = NULL;
  hdfsFile file_ = NULL;
  const std::string filename_;
  const int flags_;
  // Process-wide "host:port" -> connection cache shared by all HdfsFile objects.
  static std::unordered_map<std::string, hdfsFS> fs_cache_;
};
// Definition of HdfsFile's static namenode-connection cache. Default-initialize
// directly: copy-initializing from an explicitly default-constructed temporary
// was redundant.
std::unordered_map<std::string, hdfsFS> HdfsFile::fs_cache_;
#define WITH_HDFS(x) x
#else
#define WITH_HDFS(x) Log::Fatal("HDFS Support not enabled.")
#endif // USE_HDFS
/*!
 * \brief Factory for file readers. Paths starting with "hdfs://" are served by
 *        the HDFS backend (available only when built with USE_HDFS; otherwise
 *        WITH_HDFS aborts with a fatal log); all other paths are opened as
 *        local binary files.
 */
std::unique_ptr<VirtualFileReader> VirtualFileReader::Make(const std::string& filename) {
  const bool on_hdfs = (filename.compare(0, kHdfsProto.length(), kHdfsProto) == 0);
  if (on_hdfs) {
    WITH_HDFS(return std::unique_ptr<VirtualFileReader>(new HdfsFile(filename, O_RDONLY)));
  } else {
    return std::unique_ptr<VirtualFileReader>(new LocalFile(filename, "rb"));
  }
}
/*!
 * \brief Factory for file writers. Paths starting with "hdfs://" are served by
 *        the HDFS backend (available only when built with USE_HDFS; otherwise
 *        WITH_HDFS aborts with a fatal log); all other paths are opened as
 *        local binary files for writing.
 */
std::unique_ptr<VirtualFileWriter> VirtualFileWriter::Make(const std::string& filename) {
  const bool on_hdfs = (filename.compare(0, kHdfsProto.length(), kHdfsProto) == 0);
  if (on_hdfs) {
    WITH_HDFS(return std::unique_ptr<VirtualFileWriter>(new HdfsFile(filename, O_WRONLY)));
  } else {
    return std::unique_ptr<VirtualFileWriter>(new LocalFile(filename, "wb"));
  }
}
bool VirtualFileWriter::Exists(const std::string& filename) {
if (0 == filename.find(kHdfsProto)) {
WITH_HDFS(HdfsFile file(filename, O_RDONLY); return file.Exists());
} else {
LocalFile file(filename, "rb");
return file.Exists();
}
}
} // namespace LightGBM
...@@ -505,16 +505,16 @@ void Metadata::LoadFromMemory(const void* memory) { ...@@ -505,16 +505,16 @@ void Metadata::LoadFromMemory(const void* memory) {
LoadQueryWeights(); LoadQueryWeights();
} }
void Metadata::SaveBinaryToFile(FILE* file) const { void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
fwrite(&num_data_, sizeof(num_data_), 1, file); writer->Write(&num_data_, sizeof(num_data_));
fwrite(&num_weights_, sizeof(num_weights_), 1, file); writer->Write(&num_weights_, sizeof(num_weights_));
fwrite(&num_queries_, sizeof(num_queries_), 1, file); writer->Write(&num_queries_, sizeof(num_queries_));
fwrite(label_.data(), sizeof(label_t), num_data_, file); writer->Write(label_.data(), sizeof(label_t) * num_data_);
if (!weights_.empty()) { if (!weights_.empty()) {
fwrite(weights_.data(), sizeof(label_t), num_weights_, file); writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
} }
if (!query_boundaries_.empty()) { if (!query_boundaries_.empty()) {
fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file); writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
} }
} }
......
...@@ -69,29 +69,50 @@ enum DataType { ...@@ -69,29 +69,50 @@ enum DataType {
LIBSVM LIBSVM
}; };
/*!
 * \brief Reads one logical line from a buffered stream, refilling the stream
 *        from `reader` whenever the buffer runs out before a newline.
 * \param ss          Stream over the most recently read chunk of the file.
 * \param line        Output: the assembled line (without the newline).
 * \param reader      Source used to fetch the next chunk.
 * \param buffer      Scratch buffer to read chunks into.
 * \param buffer_size Number of bytes to request per Read call.
 */
void getline(std::stringstream& ss, std::string& line, const VirtualFileReader* reader, std::vector<char>& buffer, size_t buffer_size) {
  std::getline(ss, line);
  // eof means the chunk ended before a '\n': the line may continue in the
  // next chunk, so keep appending until a newline or end of file.
  while (ss.eof()) {
    size_t read_len = reader->Read(buffer.data(), buffer_size);
    // Fix: read_len is size_t, so the original "<= 0" was a tautological
    // unsigned comparison; zero bytes is the only end-of-input signal.
    if (read_len == 0) {
      break;  // end of file: return whatever partial line was assembled
    }
    ss.clear();
    ss.str(std::string(buffer.data(), read_len));
    std::string tmp;
    std::getline(ss, tmp);
    line += tmp;
  }
}
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) { Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) {
std::ifstream tmp_file; auto reader = VirtualFileReader::Make(filename);
tmp_file.open(filename); if (!reader->Init()) {
if (!tmp_file.is_open()) {
Log::Fatal("Data file %s doesn't exist'", filename); Log::Fatal("Data file %s doesn't exist'", filename);
} }
std::string line1, line2; std::string line1, line2;
size_t buffer_size = 64 * 1024;
auto buffer = std::vector<char>(buffer_size);
size_t read_len = reader->Read(buffer.data(), buffer_size);
if (read_len <= 0) {
Log::Fatal("Data file %s couldn't be read", filename);
}
std::stringstream tmp_file(std::string(buffer.data(), read_len));
if (has_header) { if (has_header) {
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line1); getline(tmp_file, line1, reader.get(), buffer, buffer_size);
} }
} }
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line1); getline(tmp_file, line1, reader.get(), buffer, buffer_size);
} else { } else {
Log::Fatal("Data file %s should have at least one line", filename); Log::Fatal("Data file %s should have at least one line", filename);
} }
if (!tmp_file.eof()) { if (!tmp_file.eof()) {
std::getline(tmp_file, line2); getline(tmp_file, line2, reader.get(), buffer, buffer_size);
} else { } else {
Log::Warning("Data file %s only has one line", filename); Log::Warning("Data file %s only has one line", filename);
} }
tmp_file.close();
int comma_cnt = 0, comma_cnt2 = 0; int comma_cnt = 0, comma_cnt2 = 0;
int tab_cnt = 0, tab_cnt2 = 0; int tab_cnt = 0, tab_cnt2 = 0;
int colon_cnt = 0, colon_cnt2 = 0; int colon_cnt = 0, colon_cnt2 = 0;
......
...@@ -318,10 +318,10 @@ public: ...@@ -318,10 +318,10 @@ public:
fast_index_.shrink_to_fit(); fast_index_.shrink_to_fit();
} }
void SaveBinaryToFile(FILE* file) const override { void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
fwrite(&num_vals_, sizeof(num_vals_), 1, file); writer->Write(&num_vals_, sizeof(num_vals_));
fwrite(deltas_.data(), sizeof(uint8_t), num_vals_ + 1, file); writer->Write(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
fwrite(vals_.data(), sizeof(VAL_T), num_vals_, file); writer->Write(vals_.data(), sizeof(VAL_T) * num_vals_);
} }
size_t SizesInByte() const override { size_t SizesInByte() const override {
......
...@@ -24,7 +24,7 @@ class FileLoader(object): ...@@ -24,7 +24,7 @@ class FileLoader(object):
self.params[key] = value self.params[key] = value
def load_dataset(self, suffix, is_sparse=False): def load_dataset(self, suffix, is_sparse=False):
filename = os.path.join(self.directory, self.prefix + suffix) filename = self.path(suffix)
if is_sparse: if is_sparse:
X, Y = load_svmlight_file(filename, dtype=np.float64, zero_based=True) X, Y = load_svmlight_file(filename, dtype=np.float64, zero_based=True)
return X, Y, filename return X, Y, filename
...@@ -45,6 +45,23 @@ class FileLoader(object): ...@@ -45,6 +45,23 @@ class FileLoader(object):
np.testing.assert_array_almost_equal(y_pred, cpp_pred, decimal=5) np.testing.assert_array_almost_equal(y_pred, cpp_pred, decimal=5)
np.testing.assert_array_almost_equal(y_pred, sk_pred, decimal=5) np.testing.assert_array_almost_equal(y_pred, sk_pred, decimal=5)
def file_load_check(self, lgb_train, name):
    """Verify that a Dataset constructed from the text file at path(name)
    reports the same metadata as the in-memory Dataset ``lgb_train``."""
    lgb_train_f = lgb.Dataset(self.path(name), params=self.params).construct()
    for f in ('num_data', 'num_feature', 'get_label', 'get_weight', 'get_init_score', 'get_group'):
        a = getattr(lgb_train, f)()
        b = getattr(lgb_train_f, f)()
        if a is None and b is None:
            pass
        elif a is None:
            # file-loaded dataset materializes defaults (e.g. unit weights)
            # where the in-memory dataset reports None
            assert np.all(b == 1), f
        elif isinstance(b, (list, np.ndarray)):
            np.testing.assert_array_almost_equal(a, b)
        else:
            assert a == b, f
def path(self, suffix):
    """Return the dataset file path: ``<directory>/<prefix><suffix>``."""
    return os.path.join(self.directory, self.prefix + suffix)
class TestEngine(unittest.TestCase): class TestEngine(unittest.TestCase):
...@@ -58,6 +75,7 @@ class TestEngine(unittest.TestCase): ...@@ -58,6 +75,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, sample_weight=weight_train) gbm.fit(X_train, y_train, sample_weight=weight_train)
sk_pred = gbm.predict_proba(X_test)[:, 1] sk_pred = gbm.predict_proba(X_test)[:, 1]
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_multiclass(self): def test_multiclass(self):
fd = FileLoader('../../examples/multiclass_classification', 'multiclass') fd = FileLoader('../../examples/multiclass_classification', 'multiclass')
...@@ -68,6 +86,7 @@ class TestEngine(unittest.TestCase): ...@@ -68,6 +86,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train) gbm.fit(X_train, y_train)
sk_pred = gbm.predict_proba(X_test) sk_pred = gbm.predict_proba(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_regression(self): def test_regression(self):
fd = FileLoader('../../examples/regression', 'regression') fd = FileLoader('../../examples/regression', 'regression')
...@@ -79,6 +98,7 @@ class TestEngine(unittest.TestCase): ...@@ -79,6 +98,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, init_score=init_score_train) gbm.fit(X_train, y_train, init_score=init_score_train)
sk_pred = gbm.predict(X_test) sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_lambdarank(self): def test_lambdarank(self):
fd = FileLoader('../../examples/lambdarank', 'rank') fd = FileLoader('../../examples/lambdarank', 'rank')
...@@ -90,3 +110,4 @@ class TestEngine(unittest.TestCase): ...@@ -90,3 +110,4 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, group=group_train) gbm.fit(X_train, y_train, group=group_train)
sk_pred = gbm.predict(X_test) sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred) fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment