Commit 7e186a57 authored by ebernhardson's avatar ebernhardson Committed by Guolin Ke
Browse files

Experimental support for HDFS (#1243)

* Read and write datasets from HDFS.
* Only enabled when cmake is run with -DUSE_HDFS:BOOL=TRUE
* Introduces VirtualFile(Reader|Writer) to abstract VFS differences
parent 7501faa6
......@@ -12,6 +12,7 @@ OPTION(USE_MPI "MPI based parallel learning" OFF)
OPTION(USE_OPENMP "Enable OpenMP" ON)
OPTION(USE_GPU "Enable GPU-acclerated training (EXPERIMENTAL)" OFF)
OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF)
OPTION(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "4.8.2")
......@@ -86,6 +87,15 @@ if(USE_GPU)
ADD_DEFINITIONS(-DUSE_GPU)
endif(USE_GPU)
if(USE_HDFS)
find_package(JNI REQUIRED)
find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED)
find_library(HDFS_LIB NAMES hdfs REQUIRED)
include_directories(${HDFS_INCLUDE_DIR})
ADD_DEFINITIONS(-DUSE_HDFS)
SET(HDFS_CXX_LIBRARIES ${HDFS_LIB} ${JAVA_JVM_LIBRARY})
endif(USE_HDFS)
if(UNIX OR MINGW OR CYGWIN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas")
endif()
......@@ -173,6 +183,11 @@ if(USE_GPU)
TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES})
endif(USE_GPU)
if(USE_HDFS)
TARGET_LINK_LIBRARIES(lightgbm ${HDFS_CXX_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${HDFS_CXX_LIBRARIES})
endif(USE_HDFS)
if(WIN32 AND (MINGW OR CYGWIN))
TARGET_LINK_LIBRARIES(lightgbm Ws2_32)
TARGET_LINK_LIBRARIES(_lightgbm Ws2_32)
......
......@@ -98,7 +98,7 @@ public:
* \brief Save binary data to file
* \param file File want to write
*/
void SaveBinaryToFile(FILE* file) const;
void SaveBinaryToFile(const VirtualFileWriter* writer) const;
/*!
* \brief Mapping bin into feature value
* \param bin
......@@ -308,7 +308,7 @@ public:
* \brief Save binary data to file
* \param file File want to write
*/
virtual void SaveBinaryToFile(FILE* file) const = 0;
virtual void SaveBinaryToFile(const VirtualFileWriter* writer) const = 0;
/*!
* \brief Load from memory
......
......@@ -99,7 +99,7 @@ public:
* \brief Save binary data to file
* \param file File want to write
*/
void SaveBinaryToFile(FILE* file) const;
void SaveBinaryToFile(const VirtualFileWriter* writer) const;
/*!
* \brief Get sizes in byte of this object
......
......@@ -191,13 +191,13 @@ public:
* \brief Save binary data to file
* \param file File want to write
*/
void SaveBinaryToFile(FILE* file) const {
fwrite(&is_sparse_, sizeof(is_sparse_), 1, file);
fwrite(&num_feature_, sizeof(num_feature_), 1, file);
void SaveBinaryToFile(const VirtualFileWriter* writer) const {
writer->Write(&is_sparse_, sizeof(is_sparse_));
writer->Write(&num_feature_, sizeof(num_feature_));
for (int i = 0; i < num_feature_; ++i) {
bin_mappers_[i]->SaveBinaryToFile(file);
bin_mappers_[i]->SaveBinaryToFile(writer);
}
bin_data_->SaveBinaryToFile(file);
bin_data_->SaveBinaryToFile(writer);
}
/*!
* \brief Get sizes in byte of this object
......
#ifndef LIGHTGBM_UTILS_FILE_IO_H_
#define LIGHTGBM_UTILS_FILE_IO_H_

#include <cstddef>
#include <memory>
#include <string>

namespace LightGBM {

/*!
* \brief An interface for writing files from buffers
*/
struct VirtualFileWriter {
  virtual ~VirtualFileWriter() {}
  /*!
  * \brief Initialize the writer
  * \return True when the file is available for writes
  */
  virtual bool Init() = 0;
  /*!
  * \brief Append buffer to file
  * \param data Buffer to write from
  * \param bytes Number of bytes to write from buffer
  * \return Number of bytes written
  */
  virtual size_t Write(const void* data, size_t bytes) const = 0;
  /*!
  * \brief Create appropriate writer for filename
  * \param filename Filename of the data
  * \return File writer instance
  */
  static std::unique_ptr<VirtualFileWriter> Make(const std::string& filename);
  /*!
  * \brief Check filename existence
  * \param filename Filename of the data
  * \return True when the file exists
  */
  static bool Exists(const std::string& filename);
};

/*!
* \brief An interface for reading files into buffers
*/
struct VirtualFileReader {
  virtual ~VirtualFileReader() {}
  /*!
  * \brief Initialize the reader
  * \return True when the file is available for read
  */
  virtual bool Init() = 0;
  /*!
  * \brief Read data into buffer
  * \param buffer Buffer to read data into
  * \param bytes Number of bytes to read
  * \return Number of bytes read
  */
  virtual size_t Read(void* buffer, size_t bytes) const = 0;
  /*!
  * \brief Create appropriate reader for filename
  * \param filename Filename of the data
  * \return File reader instance
  */
  static std::unique_ptr<VirtualFileReader> Make(const std::string& filename);
};

}  // namespace LightGBM

#endif  // LIGHTGBM_UTILS_FILE_IO_H_
......@@ -9,6 +9,7 @@
#include <thread>
#include <memory>
#include <algorithm>
#include "file_io.h"
namespace LightGBM{
......@@ -23,14 +24,8 @@ public:
* \process_fun Process function
*/
static size_t Read(const char* filename, int skip_bytes, const std::function<size_t (const char*, size_t)>& process_fun) {
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, filename, "rb");
#else
file = fopen(filename, "rb");
#endif
if (file == NULL) {
auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) {
return 0;
}
size_t cnt = 0;
......@@ -42,16 +37,17 @@ public:
size_t read_cnt = 0;
if (skip_bytes > 0) {
// skip first k bytes
read_cnt = fread(buffer_process.data(), 1, skip_bytes, file);
read_cnt = reader->Read(buffer_process.data(), skip_bytes);
}
// read first block
read_cnt = fread(buffer_process.data(), 1, buffer_size, file);
read_cnt = reader->Read(buffer_process.data(), buffer_size);
size_t last_read_cnt = 0;
while (read_cnt > 0) {
// start read thread
std::thread read_worker = std::thread(
[file, &buffer_read, buffer_size, &last_read_cnt] {
last_read_cnt = fread(buffer_read.data(), 1, buffer_size, file);
[&reader, &buffer_read, buffer_size, &last_read_cnt] {
last_read_cnt = reader->Read(buffer_read.data(), buffer_size);
}
);
// start process
......@@ -62,8 +58,6 @@ public:
std::swap(buffer_process, buffer_read);
read_cnt = last_read_cnt;
}
// close file
fclose(file);
return cnt;
}
......
......@@ -28,36 +28,29 @@ public:
TextReader(const char* filename, bool is_skip_first_line):
filename_(filename), is_skip_first_line_(is_skip_first_line){
if (is_skip_first_line_) {
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, filename, "r");
#else
file = fopen(filename, "r");
#endif
if (file == NULL) {
auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) {
Log::Fatal("Could not open %s", filename);
}
std::stringstream str_buf;
int read_c = -1;
read_c = fgetc(file);
while (read_c != EOF) {
char tmp_ch = static_cast<char>(read_c);
if (tmp_ch == '\n' || tmp_ch == '\r') {
char read_c;
size_t nread = reader->Read(&read_c, 1);
while (nread == 1) {
if (read_c == '\n' || read_c == '\r') {
break;
}
str_buf << tmp_ch;
str_buf << read_c;
++skip_bytes_;
read_c = fgetc(file);
nread = reader->Read(&read_c, 1);
}
if (static_cast<char>(read_c) == '\r') {
read_c = fgetc(file);
if (read_c == '\r') {
reader->Read(&read_c, 1);
++skip_bytes_;
}
if (static_cast<char>(read_c) == '\n') {
read_c = fgetc(file);
if (read_c == '\n') {
reader->Read(&read_c, 1);
++skip_bytes_;
}
fclose(file);
first_line_ = str_buf.str();
Log::Debug("Skipped header \"%s\" in file %s", first_line_.c_str(), filename_);
}
......@@ -151,25 +144,18 @@ public:
std::vector<char> ReadContent(size_t* out_len) {
std::vector<char> ret;
*out_len = 0;
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, filename_, "rb");
#else
file = fopen(filename_, "rb");
#endif
if (file == NULL) {
auto reader = VirtualFileReader::Make(filename_);
if (!reader->Init()) {
return ret;
}
const size_t buffer_size = 16 * 1024 * 1024;
auto buffer_read = std::vector<char>(buffer_size);
size_t read_cnt = 0;
do {
read_cnt = fread(buffer_read.data(), 1, buffer_size, file);
read_cnt = reader->Read(buffer_read.data(), buffer_size);
ret.insert(ret.end(), buffer_read.begin(), buffer_read.begin() + read_cnt);
*out_len += read_cnt;
} while (read_cnt > 0);
// close file
fclose(file);
return ret;
}
......
......@@ -128,15 +128,8 @@ public:
* \param result_filename Filename of output result
*/
void Predict(const char* data_filename, const char* result_filename, bool has_header) {
FILE* result_file;
#ifdef _MSC_VER
fopen_s(&result_file, result_filename, "w");
#else
result_file = fopen(result_filename, "w");
#endif
if (result_file == NULL) {
auto writer = VirtualFileWriter::Make(result_filename);
if (!writer->Init()) {
Log::Fatal("Prediction results file %s cannot be found.", result_filename);
}
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, has_header, boosting_->MaxFeatureIdx() + 1, boosting_->LabelIdx()));
......@@ -189,7 +182,7 @@ public:
};
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &parser_fun, &result_file]
[this, &parser_fun, &writer]
(data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size());
......@@ -209,11 +202,11 @@ public:
}
OMP_THROW_EX();
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
fprintf(result_file, "%s\n", result_to_write[i].c_str());
writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
writer->Write("\n", 1);
}
};
predict_data_reader.ReadAllAndProcessParallel(process_fun);
fclose(result_file);
}
private:
......
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/file_io.h>
#include <LightGBM/bin.h>
#include "dense_bin.hpp"
......@@ -455,19 +456,19 @@ namespace LightGBM {
}
}
void BinMapper::SaveBinaryToFile(FILE* file) const {
fwrite(&num_bin_, sizeof(num_bin_), 1, file);
fwrite(&missing_type_, sizeof(missing_type_), 1, file);
fwrite(&is_trival_, sizeof(is_trival_), 1, file);
fwrite(&sparse_rate_, sizeof(sparse_rate_), 1, file);
fwrite(&bin_type_, sizeof(bin_type_), 1, file);
fwrite(&min_val_, sizeof(min_val_), 1, file);
fwrite(&max_val_, sizeof(max_val_), 1, file);
fwrite(&default_bin_, sizeof(default_bin_), 1, file);
void BinMapper::SaveBinaryToFile(const VirtualFileWriter* writer) const {
writer->Write(&num_bin_, sizeof(num_bin_));
writer->Write(&missing_type_, sizeof(missing_type_));
writer->Write(&is_trival_, sizeof(is_trival_));
writer->Write(&sparse_rate_, sizeof(sparse_rate_));
writer->Write(&bin_type_, sizeof(bin_type_));
writer->Write(&min_val_, sizeof(min_val_));
writer->Write(&max_val_, sizeof(max_val_));
writer->Write(&default_bin_, sizeof(default_bin_));
if (bin_type_ == BinType::NumericalBin) {
fwrite(bin_upper_bound_.data(), sizeof(double), num_bin_, file);
writer->Write(bin_upper_bound_.data(), sizeof(double) * num_bin_);
} else {
fwrite(bin_2_categorical_.data(), sizeof(int), num_bin_, file);
writer->Write(bin_2_categorical_.data(), sizeof(int) * num_bin_);
}
}
......
......@@ -522,31 +522,20 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
bin_filename = bin_filename_str.c_str();
}
bool is_file_existed = false;
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
if (file != NULL) {
if (VirtualFileWriter::Exists(bin_filename)) {
is_file_existed = true;
Log::Warning("File %s existed, cannot save binary to it", bin_filename);
fclose(file);
}
if (!is_file_existed) {
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "wb");
#else
file = fopen(bin_filename, "wb");
#endif
if (file == NULL) {
auto writer = VirtualFileWriter::Make(bin_filename);
if (!writer->Init()) {
Log::Fatal("Cannot write binary data to %s ", bin_filename);
}
Log::Info("Saving data to binary file %s", bin_filename);
size_t size_of_token = std::strlen(binary_file_token);
fwrite(binary_file_token, sizeof(char), size_of_token, file);
writer->Write(binary_file_token, size_of_token);
// get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
......@@ -555,44 +544,43 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int);
}
fwrite(&size_of_header, sizeof(size_of_header), 1, file);
writer->Write(&size_of_header, sizeof(size_of_header));
// write header
fwrite(&num_data_, sizeof(num_data_), 1, file);
fwrite(&num_features_, sizeof(num_features_), 1, file);
fwrite(&num_total_features_, sizeof(num_total_features_), 1, file);
fwrite(&label_idx_, sizeof(label_idx_), 1, file);
fwrite(used_feature_map_.data(), sizeof(int), num_total_features_, file);
fwrite(&num_groups_, sizeof(num_groups_), 1, file);
fwrite(real_feature_idx_.data(), sizeof(int), num_features_, file);
fwrite(feature2group_.data(), sizeof(int), num_features_, file);
fwrite(feature2subfeature_.data(), sizeof(int), num_features_, file);
fwrite(group_bin_boundaries_.data(), sizeof(uint64_t), num_groups_ + 1, file);
fwrite(group_feature_start_.data(), sizeof(int), num_groups_, file);
fwrite(group_feature_cnt_.data(), sizeof(int), num_groups_, file);
writer->Write(&num_data_, sizeof(num_data_));
writer->Write(&num_features_, sizeof(num_features_));
writer->Write(&num_total_features_, sizeof(num_total_features_));
writer->Write(&label_idx_, sizeof(label_idx_));
writer->Write(used_feature_map_.data(), sizeof(int) * num_total_features_);
writer->Write(&num_groups_, sizeof(num_groups_));
writer->Write(real_feature_idx_.data(), sizeof(int) * num_features_);
writer->Write(feature2group_.data(), sizeof(int) * num_features_);
writer->Write(feature2subfeature_.data(), sizeof(int) * num_features_);
writer->Write(group_bin_boundaries_.data(), sizeof(uint64_t) * (num_groups_ + 1));
writer->Write(group_feature_start_.data(), sizeof(int) * num_groups_);
writer->Write(group_feature_cnt_.data(), sizeof(int) * num_groups_);
// write feature names
for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast<int>(feature_names_[i].size());
fwrite(&str_len, sizeof(int), 1, file);
writer->Write(&str_len, sizeof(int));
const char* c_str = feature_names_[i].c_str();
fwrite(c_str, sizeof(char), str_len, file);
writer->Write(c_str, sizeof(char) * str_len);
}
// get size of meta data
size_t size_of_metadata = metadata_.SizesInByte();
fwrite(&size_of_metadata, sizeof(size_of_metadata), 1, file);
writer->Write(&size_of_metadata, sizeof(size_of_metadata));
// write meta data
metadata_.SaveBinaryToFile(file);
metadata_.SaveBinaryToFile(writer.get());
// write feature data
for (int i = 0; i < num_groups_; ++i) {
// get size of feature
size_t size_of_feature = feature_groups_[i]->SizesInByte();
fwrite(&size_of_feature, sizeof(size_of_feature), 1, file);
writer->Write(&size_of_feature, sizeof(size_of_feature));
// write feature
feature_groups_[i]->SaveBinaryToFile(file);
feature_groups_[i]->SaveBinaryToFile(writer.get());
}
fclose(file);
}
}
......
......@@ -264,14 +264,9 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
auto dataset = std::unique_ptr<Dataset>(new Dataset());
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename, "rb");
#endif
auto reader = VirtualFileReader::Make(bin_filename);
dataset->data_filename_ = data_filename;
if (file == NULL) {
if (!reader->Init()) {
Log::Fatal("Could not read binary data from %s", bin_filename);
}
......@@ -281,8 +276,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// check token
size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file);
if (read_cnt != size_of_token) {
size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
if (read_cnt != sizeof(char) * size_of_token) {
Log::Fatal("Binary file error: token has the wrong size");
}
if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
......@@ -290,9 +285,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
// read size of header
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file);
read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != 1) {
if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: header has the wrong size");
}
......@@ -304,7 +299,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size);
}
// read header
read_cnt = fread(buffer.data(), 1, size_of_head, file);
read_cnt = reader->Read(buffer.data(), size_of_head);
if (read_cnt != size_of_head) {
Log::Fatal("Binary file error: header is incorrect");
......@@ -389,9 +384,9 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
// read size of meta data
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file);
read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != 1) {
if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: meta data has the wrong size");
}
......@@ -403,7 +398,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size);
}
// read meta data
read_cnt = fread(buffer.data(), 1, size_of_metadata, file);
read_cnt = reader->Read(buffer.data(), size_of_metadata);
if (read_cnt != size_of_metadata) {
Log::Fatal("Binary file error: meta data is incorrect");
......@@ -451,8 +446,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// read feature data
for (int i = 0; i < dataset->num_groups_; ++i) {
// read feature size
read_cnt = fread(buffer.data(), sizeof(size_t), 1, file);
if (read_cnt != 1) {
read_cnt = reader->Read(buffer.data(), sizeof(size_t));
if (read_cnt != sizeof(size_t)) {
Log::Fatal("Binary file error: feature %d has the wrong size", i);
}
size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
......@@ -462,7 +457,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
buffer.resize(buffer_size);
}
read_cnt = fread(buffer.data(), 1, size_of_feature, file);
read_cnt = reader->Read(buffer.data(), size_of_feature);
if (read_cnt != size_of_feature) {
Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
......@@ -474,7 +469,6 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
));
}
dataset->feature_groups_.shrink_to_fit();
fclose(file);
dataset->is_finish_load_ = true;
return dataset.release();
}
......@@ -1068,22 +1062,12 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
std::string bin_filename(filename);
bin_filename.append(".bin");
FILE* file;
auto reader = VirtualFileReader::Make(bin_filename.c_str());
#ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb");
#else
file = fopen(bin_filename.c_str(), "rb");
#endif
if (file == NULL) {
if (!reader->Init()) {
bin_filename = std::string(filename);
#ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb");
#else
file = fopen(bin_filename.c_str(), "rb");
#endif
if (file == NULL) {
reader = VirtualFileReader::Make(bin_filename.c_str());
if (!reader->Init()) {
Log::Fatal("cannot open data file %s", bin_filename.c_str());
}
}
......@@ -1092,8 +1076,7 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
auto buffer = std::vector<char>(buffer_size);
// read size of token
size_t size_of_token = std::strlen(Dataset::binary_file_token);
size_t read_cnt = fread(buffer.data(), sizeof(char), size_of_token, file);
fclose(file);
size_t read_cnt = reader->Read(buffer.data(), size_of_token);
if (read_cnt == size_of_token
&& std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
return bin_filename;
......
......@@ -302,8 +302,8 @@ public:
}
}
void SaveBinaryToFile(FILE* file) const override {
fwrite(data_.data(), sizeof(VAL_T), num_data_, file);
void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
writer->Write(data_.data(), sizeof(VAL_T) * num_data_);
}
size_t SizesInByte() const override {
......
......@@ -367,8 +367,8 @@ public:
}
}
void SaveBinaryToFile(FILE* file) const override {
fwrite(data_.data(), sizeof(uint8_t), data_.size(), file);
void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
writer->Write(data_.data(), sizeof(uint8_t) * data_.size());
}
size_t SizesInByte() const override {
......
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/file_io.h>
#include <algorithm>
#include <sstream>
#include <unordered_map>
#ifdef USE_HDFS
#include <hdfs.h>
#endif
namespace LightGBM{
struct LocalFile : VirtualFileReader, VirtualFileWriter {
LocalFile(const std::string& filename, const std::string& mode) : filename_(filename), mode_(mode) {}
virtual ~LocalFile() {
if (file_ != NULL) {
fclose(file_);
}
}
bool Init() {
if (file_ == NULL) {
#if _MSC_VER
fopen_s(&file_, filename_.c_str(), mode_.c_str());
#else
file_ = fopen(filename_.c_str(), mode_.c_str());
#endif
}
return file_ != NULL;
}
bool Exists() const {
LocalFile file(filename_, "rb");
return file.Init();
}
size_t Read(void* buffer, size_t bytes) const {
return fread(buffer, 1, bytes, file_);
}
size_t Write(const void* buffer, size_t bytes) const {
return fwrite(buffer, bytes, 1, file_) == 1 ? bytes : 0;
}
private:
FILE* file_ = NULL;
const std::string filename_;
const std::string mode_;
};
// URI scheme prefix; filenames starting with this are routed to the HDFS implementation.
const std::string kHdfsProto = "hdfs://";
#ifdef USE_HDFS
/*!
* \brief VirtualFileReader/VirtualFileWriter backed by libhdfs.
*
* One HdfsFile wraps a single file on HDFS. Namenode connections are
* cached per "host:port" in fs_cache_ and never disconnected.
* NOTE(review): fs_cache_ is read and written without synchronization —
* this assumes readers/writers are created from a single thread; confirm
* against the callers.
*/
struct HdfsFile : VirtualFileReader, VirtualFileWriter {
  HdfsFile(const std::string& filename, int flags) : filename_(filename), flags_(flags) {}
  ~HdfsFile() {
    if (file_ != NULL) {
      hdfsCloseFile(fs_, file_);
    }
  }
  /*!
  * \brief Connect to the namenode (cached) and open the file; idempotent.
  * Writers (O_WRONLY) may target a non-existing file; readers require
  * the file to already exist on HDFS.
  * \return True when the file handle is usable
  */
  bool Init() {
    if (file_ == NULL) {
      if (fs_ == NULL) {
        fs_ = getHdfsFS(filename_);
      }
      if (fs_ != NULL && (flags_ == O_WRONLY || 0 == hdfsExists(fs_, filename_.c_str()))) {
        file_ = hdfsOpenFile(fs_, filename_.c_str(), flags_, 0, 0, 0);
      }
    }
    return file_ != NULL;
  }
  /*! \brief Check existence on HDFS (hdfsExists returns 0 when present) */
  bool Exists() const {
    if (fs_ == NULL) {
      fs_ = getHdfsFS(filename_);
    }
    return fs_ != NULL && 0 == hdfsExists(fs_, filename_.c_str());
  }
  size_t Read(void* data, size_t bytes) const {
    return FileOperation<void*>(data, bytes, &hdfsRead);
  }
  size_t Write(const void* data, size_t bytes) const {
    return FileOperation<const void*>(data, bytes, &hdfsWrite);
  }

 private:
  // Common signature of hdfsRead / hdfsWrite.
  template <typename BufferType>
  using fileOp = tSize(*)(hdfsFS, hdfsFile, BufferType, tSize);
  /*!
  * \brief Drive `op` until `bytes` are transferred, EOF (op returns 0),
  * or a non-EINTR error (fatal). Requests are chunked so each call's
  * length fits in tSize.
  * \return Number of bytes actually transferred
  */
  template <typename BufferType>
  inline size_t FileOperation(BufferType data, size_t bytes, fileOp<BufferType> op) const {
    char* buffer = (char *)data;
    size_t remain = bytes;
    while (remain != 0) {
      size_t nmax = static_cast<size_t>(std::numeric_limits<tSize>::max());
      tSize ret = op(fs_, file_, buffer, std::min(nmax, remain));
      if (ret > 0) {
        size_t n = static_cast<size_t>(ret);
        remain -= n;
        buffer += n;
      } else if (ret == 0) {
        break;  // EOF / no progress possible
      } else if (errno != EINTR) {
        // Interrupted calls (EINTR) are simply retried; anything else aborts.
        Log::Fatal("Failed hdfs file operation [%s]", strerror(errno));
      }
    }
    return bytes - remain;
  }
  /*!
  * \brief Resolve (and cache) the hdfsFS for the namenode embedded in `uri`.
  * Expects "hdfs://host:port/..."; returns NULL with a warning otherwise.
  */
  static hdfsFS getHdfsFS(const std::string& uri) {
    size_t end = uri.find("/", kHdfsProto.length());
    if (uri.find(kHdfsProto) != 0 || end == std::string::npos) {
      Log::Warning("Bad hdfs uri, no namenode found [%s]", uri.c_str());
      return NULL;
    }
    std::string hostport = uri.substr(kHdfsProto.length(), end - kHdfsProto.length());
    if (fs_cache_.count(hostport) == 0) {
      fs_cache_[hostport] = makeHdfsFs(hostport);
    }
    return fs_cache_[hostport];
  }
  /*!
  * \brief Connect to a namenode given "host:port"; returns NULL (with a
  * warning) when parsing leaves unread input or the connection fails.
  */
  static hdfsFS makeHdfsFs(const std::string& hostport) {
    std::istringstream iss(hostport);
    std::string host;
    tPort port = 0;
    std::getline(iss, host, ':');
    iss >> port;
    hdfsFS fs = iss.eof() ? hdfsConnect(host.c_str(), port) : NULL;
    if (fs == NULL) {
      Log::Warning("Could not connect to hdfs namenode [%s]", hostport.c_str());
    }
    return fs;
  }
  // mutable: Exists()/Read()/Write() are const but may lazily connect.
  mutable hdfsFS fs_ = NULL;
  hdfsFile file_ = NULL;
  const std::string filename_;
  const int flags_;
  // Shared namenode connection cache, keyed by "host:port".
  static std::unordered_map<std::string, hdfsFS> fs_cache_;
};
std::unordered_map<std::string, hdfsFS> HdfsFile::fs_cache_ = std::unordered_map<std::string, hdfsFS>();
#define WITH_HDFS(x) x
#else
#define WITH_HDFS(x) Log::Fatal("HDFS Support not enabled.")
#endif // USE_HDFS
/*!
* \brief Factory: choose the reader implementation from the filename scheme.
* "hdfs://" paths get an HDFS-backed reader (fatal when HDFS support is
* compiled out); everything else is a local file opened in "rb" mode.
*/
std::unique_ptr<VirtualFileReader> VirtualFileReader::Make(const std::string& filename) {
  const bool use_hdfs = (filename.find(kHdfsProto) == 0);
  if (use_hdfs) {
    WITH_HDFS(return std::unique_ptr<VirtualFileReader>(new HdfsFile(filename, O_RDONLY)));
  } else {
    return std::unique_ptr<VirtualFileReader>(new LocalFile(filename, "rb"));
  }
}
/*!
* \brief Factory: choose the writer implementation from the filename scheme.
* "hdfs://" paths get an HDFS-backed writer (fatal when HDFS support is
* compiled out); everything else is a local file opened in "wb" mode.
*/
std::unique_ptr<VirtualFileWriter> VirtualFileWriter::Make(const std::string& filename) {
  const bool use_hdfs = (filename.find(kHdfsProto) == 0);
  if (use_hdfs) {
    WITH_HDFS(return std::unique_ptr<VirtualFileWriter>(new HdfsFile(filename, O_WRONLY)));
  } else {
    return std::unique_ptr<VirtualFileWriter>(new LocalFile(filename, "wb"));
  }
}
/*!
* \brief Check whether `filename` exists, dispatching on the URI scheme
* the same way as Make() (fatal for "hdfs://" when HDFS is compiled out).
*/
bool VirtualFileWriter::Exists(const std::string& filename) {
  const bool use_hdfs = (filename.find(kHdfsProto) == 0);
  if (use_hdfs) {
    WITH_HDFS(HdfsFile file(filename, O_RDONLY); return file.Exists());
  } else {
    return LocalFile(filename, "rb").Exists();
  }
}
} // namespace LightGBM
......@@ -505,16 +505,16 @@ void Metadata::LoadFromMemory(const void* memory) {
LoadQueryWeights();
}
void Metadata::SaveBinaryToFile(FILE* file) const {
fwrite(&num_data_, sizeof(num_data_), 1, file);
fwrite(&num_weights_, sizeof(num_weights_), 1, file);
fwrite(&num_queries_, sizeof(num_queries_), 1, file);
fwrite(label_.data(), sizeof(label_t), num_data_, file);
void Metadata::SaveBinaryToFile(const VirtualFileWriter* writer) const {
writer->Write(&num_data_, sizeof(num_data_));
writer->Write(&num_weights_, sizeof(num_weights_));
writer->Write(&num_queries_, sizeof(num_queries_));
writer->Write(label_.data(), sizeof(label_t) * num_data_);
if (!weights_.empty()) {
fwrite(weights_.data(), sizeof(label_t), num_weights_, file);
writer->Write(weights_.data(), sizeof(label_t) * num_weights_);
}
if (!query_boundaries_.empty()) {
fwrite(query_boundaries_.data(), sizeof(data_size_t), num_queries_ + 1, file);
writer->Write(query_boundaries_.data(), sizeof(data_size_t) * (num_queries_ + 1));
}
}
......
......@@ -69,29 +69,50 @@ enum DataType {
LIBSVM
};
/*!
* \brief Extract one line from `ss`, refilling it from `reader` whenever the
* current chunk ends before a newline, so lines may span read-buffer
* boundaries.
* \param ss Stream holding the most recently read chunk
* \param line Output line (delimiter not included)
* \param reader Source used to fetch the next chunk on underflow
* \param buffer Scratch buffer reused for every refill
* \param buffer_size Number of bytes to request per refill
*/
void getline(std::stringstream& ss, std::string& line, const VirtualFileReader* reader, std::vector<char>& buffer, size_t buffer_size) {
  std::getline(ss, line);
  // eof() on the stringstream means the chunk ran out before a newline was
  // seen; keep appending continuation chunks until a delimiter or real EOF.
  while (ss.eof()) {
    size_t read_len = reader->Read(buffer.data(), buffer_size);
    // read_len is unsigned, so the original "<= 0" was a misleading
    // tautology; 0 means end of input.
    if (read_len == 0) {
      break;
    }
    ss.clear();
    ss.str(std::string(buffer.data(), read_len));
    std::string tmp;
    std::getline(ss, tmp);
    line += tmp;
  }
}
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) {
std::ifstream tmp_file;
tmp_file.open(filename);
if (!tmp_file.is_open()) {
auto reader = VirtualFileReader::Make(filename);
if (!reader->Init()) {
Log::Fatal("Data file %s doesn't exist'", filename);
}
std::string line1, line2;
size_t buffer_size = 64 * 1024;
auto buffer = std::vector<char>(buffer_size);
size_t read_len = reader->Read(buffer.data(), buffer_size);
if (read_len <= 0) {
Log::Fatal("Data file %s couldn't be read", filename);
}
std::stringstream tmp_file(std::string(buffer.data(), read_len));
if (has_header) {
if (!tmp_file.eof()) {
std::getline(tmp_file, line1);
getline(tmp_file, line1, reader.get(), buffer, buffer_size);
}
}
if (!tmp_file.eof()) {
std::getline(tmp_file, line1);
getline(tmp_file, line1, reader.get(), buffer, buffer_size);
} else {
Log::Fatal("Data file %s should have at least one line", filename);
}
if (!tmp_file.eof()) {
std::getline(tmp_file, line2);
getline(tmp_file, line2, reader.get(), buffer, buffer_size);
} else {
Log::Warning("Data file %s only has one line", filename);
}
tmp_file.close();
int comma_cnt = 0, comma_cnt2 = 0;
int tab_cnt = 0, tab_cnt2 = 0;
int colon_cnt = 0, colon_cnt2 = 0;
......
......@@ -318,10 +318,10 @@ public:
fast_index_.shrink_to_fit();
}
void SaveBinaryToFile(FILE* file) const override {
fwrite(&num_vals_, sizeof(num_vals_), 1, file);
fwrite(deltas_.data(), sizeof(uint8_t), num_vals_ + 1, file);
fwrite(vals_.data(), sizeof(VAL_T), num_vals_, file);
void SaveBinaryToFile(const VirtualFileWriter* writer) const override {
writer->Write(&num_vals_, sizeof(num_vals_));
writer->Write(deltas_.data(), sizeof(uint8_t) * (num_vals_ + 1));
writer->Write(vals_.data(), sizeof(VAL_T) * num_vals_);
}
size_t SizesInByte() const override {
......
......@@ -24,7 +24,7 @@ class FileLoader(object):
self.params[key] = value
def load_dataset(self, suffix, is_sparse=False):
filename = os.path.join(self.directory, self.prefix + suffix)
filename = self.path(suffix)
if is_sparse:
X, Y = load_svmlight_file(filename, dtype=np.float64, zero_based=True)
return X, Y, filename
......@@ -45,6 +45,23 @@ class FileLoader(object):
np.testing.assert_array_almost_equal(y_pred, cpp_pred, decimal=5)
np.testing.assert_array_almost_equal(y_pred, sk_pred, decimal=5)
def file_load_check(self, lgb_train, name):
lgb_train_f = lgb.Dataset(self.path(name), params=self.params).construct()
for f in ('num_data', 'num_feature', 'get_label', 'get_weight', 'get_init_score', 'get_group'):
a = getattr(lgb_train, f)()
b = getattr(lgb_train_f, f)()
if a is None and b is None:
pass
elif a is None:
assert np.all(b == 1), f
elif isinstance(b, (list, np.ndarray)):
np.testing.assert_array_almost_equal(a, b)
else:
assert a == b, f
def path(self, suffix):
return os.path.join(self.directory, self.prefix + suffix)
class TestEngine(unittest.TestCase):
......@@ -58,6 +75,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, sample_weight=weight_train)
sk_pred = gbm.predict_proba(X_test)[:, 1]
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_multiclass(self):
fd = FileLoader('../../examples/multiclass_classification', 'multiclass')
......@@ -68,6 +86,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train)
sk_pred = gbm.predict_proba(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_regression(self):
fd = FileLoader('../../examples/regression', 'regression')
......@@ -79,6 +98,7 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, init_score=init_score_train)
sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
def test_lambdarank(self):
fd = FileLoader('../../examples/lambdarank', 'rank')
......@@ -90,3 +110,4 @@ class TestEngine(unittest.TestCase):
gbm.fit(X_train, y_train, group=group_train)
sk_pred = gbm.predict(X_test)
fd.train_predict_check(lgb_train, X_test, X_test_fn, sk_pred)
fd.file_load_check(lgb_train, '.train')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment