Unverified Commit b0137deb authored by chjinche's avatar chjinche Committed by GitHub
Browse files

Add customized parser support (#4782)

* add customized parser support

* fix typo of parser_config_file description

* make delimiter as parameter of JoinedLines
parent 843d380d
......@@ -117,6 +117,8 @@ MLflow (experiment tracking, model monitoring framework): https://github.com/mlf
`{mlr3extralearners}` (R `{mlr3}`-compliant interface): https://github.com/mlr-org/mlr3extralearners
lightgbm-transform (feature transformation binding): https://github.com/microsoft/lightgbm-transform
Support
-------
......
......@@ -850,6 +850,14 @@ Dataset Parameters
- **Note**: setting this to ``true`` may lead to much slower text parsing
- ``parser_config_file`` :raw-html:`<a id="parser_config_file" title="Permalink to this parameter" href="#parser_config_file">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
- path to a ``.json`` file that specifies the initialization configuration of the customized parser
- see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples
- **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issues page <https://github.com/microsoft/lightgbm-transform/issues>`__
Predict Parameters
~~~~~~~~~~~~~~~~~~
......
......@@ -314,6 +314,8 @@ class LIGHTGBM_EXPORT Boosting {
static Boosting* CreateBoosting(const std::string& type, const char* filename);
virtual bool IsLinear() const { return false; }
/*! \brief Get the customized parser config content held by this model; callers treat an empty string as "no customized parser" */
virtual std::string ParserConfigStr() const = 0;
};
class GBDTBase : public Boosting {
......
......@@ -721,6 +721,11 @@ struct Config {
// desc = **Note**: setting this to ``true`` may lead to much slower text parsing
bool precise_float_parser = false;
// desc = path to a ``.json`` file that specifies the initialization configuration of the customized parser
// desc = see `lightgbm-transform <https://github.com/microsoft/lightgbm-transform>`__ for usage examples
// desc = **Note**: ``lightgbm-transform`` is not maintained by LightGBM's maintainers. Bug reports or feature requests should go to `issues page <https://github.com/microsoft/lightgbm-transform/issues>`__
std::string parser_config_file = "";
#pragma endregion
#pragma region Predict Parameters
......
......@@ -15,6 +15,7 @@
#include <string>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <unordered_set>
......@@ -254,6 +255,14 @@ class Parser {
public:
typedef const char* (*AtofFunc)(const char* p, double* out);
/*! \brief Default constructor */
Parser() {}
/*!
* \brief Constructor for customized parser. The constructor accepts the config content rather than a file path, because the config needs to be saved/loaded along with the model string
*/
explicit Parser(std::string) {}
/*! \brief virtual destructor */
virtual ~Parser() {}
......@@ -271,12 +280,58 @@ class Parser {
/*!
* \brief Create an object of parser, will auto choose the format depend on file
* \param filename One Filename of data
* \param header whether input file contains header
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
* \param precise_float_parser using precise floating point number parsing if true
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser);
/*!
* \brief Create an object of parser; uses the customized parser if a parser config is given, otherwise auto-chooses the format depending on the file
* \param filename One Filename of data
* \param header whether input file contains header
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param label_idx index of label column
* \param precise_float_parser using precise floating point number parsing if true
* \param parser_config_str Customized parser config content
* \return Object of parser
*/
static Parser* CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser,
std::string parser_config_str);
/*!
* \brief Generate parser config str used for custom parser initialization, may save values of label id and header
* \param filename One Filename of data
* \param parser_config_filename One Filename of parser config
* \param header whether input file contains header
* \param label_idx index of label column
* \return Parser config str
*/
static std::string GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx);
};
/*!
 * \brief Registry of customized parser creators, keyed by parser class name.
 *        The constructor is private; the only instance is the one handed out by getInstance().
 */
class ParserFactory {
 public:
  ~ParserFactory() {}
  /*! \brief Access the factory instance shared by all callers */
  static ParserFactory& getInstance();
  /*!
   * \brief Associate a creator callback with a parser class name
   * \param class_name Name under which the creator is stored
   * \param objc Callback that builds a parser from parser config content
   */
  void Register(std::string class_name, std::function<Parser*(std::string)> objc);
  /*!
   * \brief Build a parser through the creator stored under class_name
   * \param class_name Name of the registered parser class
   * \param config_str Parser config content handed to the creator
   * \return Parser produced by the creator
   */
  Parser* getObject(std::string class_name, std::string config_str);

 private:
  ParserFactory() {}
  /*! \brief Registered creators, indexed by parser class name */
  std::map<std::string, std::function<Parser*(std::string)>> object_map_;
};
/*! \brief Interface for parser reflector, used by customized parser */
class ParserReflector {
public:
ParserReflector(std::string class_name, std::function<Parser*(std::string)> objc) {
ParserFactory::getInstance().Register(class_name, objc);
}
virtual ~ParserReflector() {}
};
/*! \brief The main class of data set,
......@@ -605,6 +660,9 @@ class Dataset {
/*! \brief Get names of current data set */
inline const std::vector<std::string>& feature_names() const { return feature_names_; }
/*! \brief Get content of parser config file */
inline const std::string parser_config_str() const { return parser_config_str_; }
inline void set_feature_names(const std::vector<std::string>& feature_names) {
if (feature_names.size() != static_cast<size_t>(num_total_features_)) {
Log::Fatal("Size of feature_names error, should equal with total number of features");
......@@ -722,6 +780,7 @@ class Dataset {
/*! map feature (inner index) to its index in the list of numeric (non-categorical) features */
std::vector<int> numeric_feature_map_;
int num_numeric_features_;
std::string parser_config_str_;
};
} // namespace LightGBM
......
......@@ -8,6 +8,7 @@
#if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__)))
#include <LightGBM/utils/common_legacy_solaris.h>
#endif
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
......@@ -62,6 +63,8 @@ namespace LightGBM {
namespace Common {
using json11::Json;
/*!
* Imbues the stream with the C locale.
*/
......@@ -200,6 +203,28 @@ inline static std::vector<std::string> Split(const char* c_str, const char* deli
return ret;
}
inline static std::string GetFromParserConfig(std::string config_str, std::string key) {
// parser config should follow json format.
std::string err;
Json config_json = Json::parse(config_str, &err);
if (!err.empty()) {
Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str());
}
return config_json[key].string_value();
}
inline static std::string SaveToParserConfig(std::string config_str, std::string key, std::string value) {
std::string err;
Json config_json = Json::parse(config_str, &err);
if (!err.empty()) {
Log::Fatal("Invalid parser config: %s. Please check if follow json format.", err.c_str());
}
CHECK(config_json.is_object());
std::map<std::string, Json> config_map = config_json.object_items();
config_map.insert(std::pair<std::string, Json>(key, Json(value)));
return Json(config_map).dump();
}
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
int sign;
......
......@@ -84,6 +84,17 @@ class TextReader {
* \return Text data, store in std::vector by line
*/
inline std::vector<std::string>& Lines() { return lines_; }
/*!
 * \brief Get joined text data that was read from file
 * \param delimiter String appended after every line, including the last one
 * \return All stored lines concatenated into one std::string, each followed by delimiter
 */
inline std::string JoinedLines(std::string delimiter = "\n") {
  std::string joined;
  // iterate by reference: the lines only need to be read, not copied
  for (const auto& line : lines_) {
    joined += line;
    joined += delimiter;
  }
  return joined;
}
INDEX_T ReadAllAndProcess(const std::function<void(INDEX_T, const char*, size_t)>& process_fun) {
last_line_ = "";
......
......@@ -8,6 +8,7 @@
#include <LightGBM/boosting.h>
#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>
......@@ -167,7 +168,7 @@ class Predictor {
}
auto label_idx = header ? -1 : boosting_->LabelIdx();
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(data_filename, header, boosting_->MaxFeatureIdx() + 1, label_idx,
precise_float_parser));
precise_float_parser, boosting_->ParserConfigStr()));
if (parser == nullptr) {
Log::Fatal("Could not recognize the data format of data file %s", data_filename);
......@@ -179,7 +180,8 @@ class Predictor {
TextReader<data_size_t> predict_data_reader(data_filename, header);
std::vector<int> feature_remapper(parser->NumFeatures(), -1);
bool need_adjust = false;
if (header) {
// skip raw feature remapping if trained model has parser config str which may contain actual feature names.
if (header && boosting_->ParserConfigStr().empty()) {
std::string first_line = predict_data_reader.first_line();
std::vector<std::string> header_words = Common::Split(first_line.c_str(), "\t,");
std::unordered_map<std::string, int> header_mapper;
......
......@@ -120,6 +120,8 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
monotone_constraints_ = config->monotone_constraints;
// get parser config file content
parser_config_str_ = train_data_->parser_config_str();
// if need bagging, create buffer
ResetBaggingConfig(config_.get(), true);
......@@ -730,6 +732,7 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
label_idx_ = train_data_->label_idx();
feature_names_ = train_data_->feature_names();
feature_infos_ = train_data_->feature_infos();
parser_config_str_ = train_data_->parser_config_str();
tree_learner_->ResetTrainingData(train_data, is_constant_hessian_);
ResetBaggingConfig(config_.get(), true);
......
......@@ -394,6 +394,8 @@ class GBDT : public GBDTBase {
bool IsLinear() const override { return linear_tree_; }
/*! \brief Get the parser config content stored with this model */
std::string ParserConfigStr() const override { return parser_config_str_; }
protected:
virtual bool GetIsConstHessian(const ObjectiveFunction* objective_function) {
if (objective_function != nullptr) {
......@@ -483,6 +485,8 @@ class GBDT : public GBDTBase {
std::vector<std::unique_ptr<Tree>> models_;
/*! \brief Max feature index of training data*/
int max_feature_idx_;
/*! \brief Parser config file content */
std::string parser_config_str_ = "";
#ifdef USE_CUDA
/*! \brief First order derivative of training data */
......
......@@ -16,7 +16,7 @@
namespace LightGBM {
const char* kModelVersion = "v3";
const char* kModelVersion = "v4";
std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const {
std::stringstream str_buf;
......@@ -399,6 +399,11 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int
ss << loaded_parameter_ << "\n";
ss << "end of parameters" << '\n';
}
if (!parser_config_str_.empty()) {
ss << "\nparser:" << '\n';
ss << parser_config_str_ << "\n";
ss << "end of parser" << '\n';
}
return ss.str();
}
......@@ -568,7 +573,7 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
num_init_iteration_ = num_iteration_for_pred_;
iter_ = 0;
bool is_inparameter = false;
bool is_inparameter = false, is_inparser = false;
std::stringstream ss;
Common::C_stringstream(ss);
while (p < end) {
......@@ -594,6 +599,28 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if (!ss.str().empty()) {
loaded_parameter_ = ss.str();
}
ss.clear();
ss.str("");
while (p < end) {
auto line_len = Common::GetLine(p);
if (line_len > 0) {
std::string cur_line(p, line_len);
if (cur_line == std::string("parser:")) {
is_inparser = true;
} else if (cur_line == std::string("end of parser")) {
p += line_len;
p = Common::SkipNewLine(p);
break;
} else if (is_inparser) {
ss << cur_line << "\n";
}
}
p += line_len;
p = Common::SkipNewLine(p);
}
parser_config_str_ = ss.str();
ss.clear();
ss.str("");
return true;
}
......
......@@ -272,6 +272,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"forcedbins_filename",
"save_binary",
"precise_float_parser",
"parser_config_file",
"start_iteration_predict",
"num_iteration_predict",
"predict_raw_score",
......@@ -540,6 +541,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetBool(params, "precise_float_parser", &precise_float_parser);
GetString(params, "parser_config_file", &parser_config_file);
GetInt(params, "start_iteration_predict", &start_iteration_predict);
GetInt(params, "num_iteration_predict", &num_iteration_predict);
......@@ -723,6 +726,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
str_buf << "[precise_float_parser: " << precise_float_parser << "]\n";
str_buf << "[parser_config_file: " << parser_config_file << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
......
......@@ -42,6 +42,18 @@ void DatasetLoader::SetHeader(const char* filename) {
if (config_.header) {
std::string first_line = text_reader.first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t,");
} else if (!config_.parser_config_file.empty()) {
// support getting the header from the parser config, so that the label name to id mapping logic below can be used.
TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
parser_config_reader.ReadAllLines();
std::string parser_config_str = parser_config_reader.JoinedLines();
if (!parser_config_str.empty()) {
std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
if (!header_in_parser_config.empty()) {
Log::Info("Get raw column names from parser config.");
feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
}
}
}
// load label idx first
......@@ -71,6 +83,15 @@ void DatasetLoader::SetHeader(const char* filename) {
}
}
if (!config_.parser_config_file.empty()) {
// if parser config file exists, feature names may be changed after customized parser applied.
// clear here so could use default filled feature names during dataset construction.
// may improve by saving real feature names defined in parser in the future.
if (!feature_names_.empty()) {
feature_names_.clear();
}
}
if (!feature_names_.empty()) {
// erase label column name
feature_names_.erase(feature_names_.begin() + label_idx_);
......@@ -196,8 +217,9 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
auto bin_filename = CheckCanLoadFromBin(filename);
bool is_load_from_binary = false;
if (bin_filename.size() == 0) {
dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_);
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
config_.precise_float_parser));
config_.precise_float_parser, dataset->parser_config_str_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
......@@ -257,8 +279,6 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
return dataset.release();
}
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
......@@ -269,7 +289,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
auto bin_filename = CheckCanLoadFromBin(filename);
if (bin_filename.size() == 0) {
auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
config_.precise_float_parser));
config_.precise_float_parser, dataset->parser_config_str_));
if (parser == nullptr) {
Log::Fatal("Could not recognize data format of %s", filename);
}
......@@ -1010,7 +1030,11 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
categorical_features_);
// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
// skip label check if user input parser config file,
// because label id is got from raw features while dataset features are consistent with customized parser.
if (dataset->parser_config_str_.empty()) {
CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
}
CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);
......@@ -1383,8 +1407,6 @@ std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
}
}
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
const std::unordered_set<int>& categorical_features) {
std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
......
......@@ -4,8 +4,10 @@
*/
#include "parser.hpp"
#include <functional>
#include <string>
#include <algorithm>
#include <map>
#include <memory>
namespace LightGBM {
......@@ -230,6 +232,30 @@ DataType GetDataType(const char* filename, bool header,
return type;
}
// parser factory implementation.
/*! \brief Return the shared factory; it is a function-local static, so every call yields the same object */
ParserFactory& ParserFactory::getInstance() {
  static ParserFactory instance;
  return instance;
}
/*!
 * \brief Store a creator callback under a parser class name
 * \param class_name Name the creator is registered under
 * \param m_objc Creator callback; an empty std::function is ignored
 */
void ParserFactory::Register(std::string class_name, std::function<Parser*(std::string)> m_objc) {
  // an empty callback cannot create a parser, so there is nothing to store
  if (!m_objc) {
    return;
  }
  // emplace keeps a creator that is already registered under class_name
  object_map_.emplace(class_name, m_objc);
}
/*!
 * \brief Create a parser through the creator registered under class_name
 * \param class_name Name of the registered parser class
 * \param config_str Parser config content passed to the creator
 * \return Result of the creator; an unknown class name is reported through Log::Fatal
 */
Parser* ParserFactory::getObject(std::string class_name, std::string config_str) {
  const auto creator_it = object_map_.find(class_name);
  if (creator_it == object_map_.end()) {
    Log::Fatal("Cannot find parser class '%s', please register first or check config format.", class_name.c_str());
    return nullptr;
  }
  return creator_it->second(config_str);
}
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser) {
const int n_read_line = 32;
auto lines = ReadKLineFromFile(filename, header, n_read_line);
......@@ -258,4 +284,34 @@ Parser* Parser::CreateParser(const char* filename, bool header, int num_features
return ret.release();
}
/*!
 * \brief Create a parser: the customized one named in parser_config_str when that string is
 *        non-empty, otherwise the built-in parser chosen by the 5-argument overload
 * \return Raw pointer to the created parser, as produced by the factory or the built-in overload
 */
Parser* Parser::CreateParser(const char* filename, bool header, int num_features, int label_idx, bool precise_float_parser, std::string parser_config_str) {
  // no customized parser configured: fall back to the built-in format detection
  if (parser_config_str.empty()) {
    return CreateParser(filename, header, num_features, label_idx, precise_float_parser);
  }
  // customized parser add-on: the config names the registered class to instantiate
  const std::string class_name = Common::GetFromParserConfig(parser_config_str, "className");
  Log::Info("Custom parser class name: %s", class_name.c_str());
  return ParserFactory::getInstance().getObject(class_name, parser_config_str);
}
/*!
 * \brief Build the parser config content used to initialize a customized parser
 * \param filename Data file; its first line is stored as "header" when header is true
 *        and the config has no non-empty "header" entry
 * \param parser_config_filename Path of the parser config file; null or empty means
 *        no customized parser is configured
 * \param header Whether the data file contains a header line
 * \param label_idx Label column index, stored as "labelId" when the config has no
 *        non-empty "labelId" entry
 * \return Parser config content, or an empty string when no config is available
 */
std::string Parser::GenerateParserConfigStr(const char* filename, const char* parser_config_filename, bool header, int label_idx) {
  // parser_config_file defaults to "", so skip the file read when nothing is configured
  if (parser_config_filename == nullptr || parser_config_filename[0] == '\0') {
    return std::string();
  }
  TextReader<data_size_t> parser_config_reader(parser_config_filename, false);
  parser_config_reader.ReadAllLines();
  std::string parser_config_str = parser_config_reader.JoinedLines();
  if (parser_config_str.empty()) {
    return parser_config_str;
  }
  // save header to parser config in case needed.
  if (header && Common::GetFromParserConfig(parser_config_str, "header").empty()) {
    TextReader<data_size_t> text_reader(filename, header);
    parser_config_str = Common::SaveToParserConfig(parser_config_str, "header", text_reader.first_line());
  }
  // save label id to parser config in case needed.
  if (Common::GetFromParserConfig(parser_config_str, "labelId").empty()) {
    parser_config_str = Common::SaveToParserConfig(parser_config_str, "labelId", std::to_string(label_idx));
  }
  return parser_config_str;
}
} // namespace LightGBM
......@@ -557,3 +557,15 @@ def test_init_score_for_multiclass_classification(init_score_type):
ds = lgb.Dataset(data, init_score=init_score).construct()
np.testing.assert_equal(ds.get_field('init_score'), init_score)
np.testing.assert_equal(ds.init_score, init_score)
def test_smoke_custom_parser(tmp_path):
    """Constructing a Dataset whose parser config names an unregistered class must raise."""
    base_dir = Path(__file__).absolute().parents[2]
    data_path = base_dir / 'examples' / 'binary_classification' / 'binary.train'
    # the config content is JSON regardless of the file extension
    parser_config_file = tmp_path / 'parser.ini'
    parser_config_file.write_text('{"className": "dummy", "id": "1"}')
    data = lgb.Dataset(data_path, params={"parser_config_file": parser_config_file})
    expected_error = "Cannot find parser class 'dummy', please register first or check config format"
    with pytest.raises(lgb.basic.LightGBMError, match=expected_error):
        data.construct()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment