Commit f42e6c36 authored by Guolin Ke's avatar Guolin Ke
Browse files

[R] fix EncodeChar

parent 8a5ec366
...@@ -273,23 +273,21 @@ public: ...@@ -273,23 +273,21 @@ public:
return ret; return ret;
} }
#pragma warning(disable : 4996)
int GetEvalNames(char** out_strs) const { int GetEvalNames(char** out_strs) const {
int idx = 0; int idx = 0;
for (const auto& metric : train_metric_) { for (const auto& metric : train_metric_) {
for (const auto& name : metric->GetName()) { for (const auto& name : metric->GetName()) {
std::strcpy(out_strs[idx], name.c_str()); std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
++idx; ++idx;
} }
} }
return idx; return idx;
} }
#pragma warning(disable : 4996)
int GetFeatureNames(char** out_strs) const { int GetFeatureNames(char** out_strs) const {
int idx = 0; int idx = 0;
for (const auto& name : boosting_->FeatureNames()) { for (const auto& name : boosting_->FeatureNames()) {
std::strcpy(out_strs[idx], name.c_str()); std::memcpy(out_strs[idx], name.c_str(), name.size() + 1);
++idx; ++idx;
} }
return idx; return idx;
...@@ -719,7 +717,6 @@ int LGBM_DatasetSetFeatureNames( ...@@ -719,7 +717,6 @@ int LGBM_DatasetSetFeatureNames(
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_DatasetGetFeatureNames( int LGBM_DatasetGetFeatureNames(
DatasetHandle handle, DatasetHandle handle,
char** feature_names, char** feature_names,
...@@ -729,7 +726,7 @@ int LGBM_DatasetGetFeatureNames( ...@@ -729,7 +726,7 @@ int LGBM_DatasetGetFeatureNames(
auto inside_feature_name = dataset->feature_names(); auto inside_feature_name = dataset->feature_names();
*num_feature_names = static_cast<int>(inside_feature_name.size()); *num_feature_names = static_cast<int>(inside_feature_name.size());
for (int i = 0; i < *num_feature_names; ++i) { for (int i = 0; i < *num_feature_names; ++i) {
std::strcpy(feature_names[i], inside_feature_name[i].c_str()); std::memcpy(feature_names[i], inside_feature_name[i].c_str(), inside_feature_name[i].size() + 1);
} }
API_END(); API_END();
} }
...@@ -1138,7 +1135,6 @@ int LGBM_BoosterSaveModel(BoosterHandle handle, ...@@ -1138,7 +1135,6 @@ int LGBM_BoosterSaveModel(BoosterHandle handle,
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_BoosterSaveModelToString(BoosterHandle handle, int LGBM_BoosterSaveModelToString(BoosterHandle handle,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
...@@ -1149,12 +1145,11 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle, ...@@ -1149,12 +1145,11 @@ int LGBM_BoosterSaveModelToString(BoosterHandle handle,
std::string model = ref_booster->SaveModelToString(num_iteration); std::string model = ref_booster->SaveModelToString(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str()); std::memcpy(out_str, model.c_str(), *out_len);
} }
API_END(); API_END();
} }
#pragma warning(disable : 4996)
int LGBM_BoosterDumpModel(BoosterHandle handle, int LGBM_BoosterDumpModel(BoosterHandle handle,
int num_iteration, int num_iteration,
int64_t buffer_len, int64_t buffer_len,
...@@ -1165,7 +1160,7 @@ int LGBM_BoosterDumpModel(BoosterHandle handle, ...@@ -1165,7 +1160,7 @@ int LGBM_BoosterDumpModel(BoosterHandle handle,
std::string model = ref_booster->DumpModel(num_iteration); std::string model = ref_booster->DumpModel(num_iteration);
*out_len = static_cast<int64_t>(model.size()) + 1; *out_len = static_cast<int64_t>(model.size()) + 1;
if (*out_len <= buffer_len) { if (*out_len <= buffer_len) {
std::strcpy(out_str, model.c_str()); std::memcpy(out_str, model.c_str(), *out_len);
} }
API_END(); API_END();
} }
......
...@@ -33,15 +33,14 @@ ...@@ -33,15 +33,14 @@
using namespace LightGBM; using namespace LightGBM;
LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len) { LGBM_SE EncodeChar(LGBM_SE dest, const char* src, LGBM_SE buf_len, LGBM_SE actual_len) {
int str_len = static_cast<int>(std::strlen(src)); size_t str_len = std::strlen(src);
R_INT_PTR(actual_len)[0] = str_len; if (str_len > INT32_MAX) {
Log::Fatal("Don't support large string in R-package.");
}
R_INT_PTR(actual_len)[0] = static_cast<int>(str_len);
if (R_AS_INT(buf_len) < str_len) { return dest; } if (R_AS_INT(buf_len) < str_len) { return dest; }
auto ptr = R_CHAR_PTR(dest); auto ptr = R_CHAR_PTR(dest);
int i = 0; std::memcpy(ptr, src, str_len);
while (src[i] != '\0') {
ptr[i] = src[i];
++i;
}
return dest; return dest;
} }
...@@ -604,15 +603,7 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle, ...@@ -604,15 +603,7 @@ LGBM_SE LGBM_BoosterSaveModelToString_R(LGBM_SE handle,
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterSaveModelToString(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
R_API_END(); R_API_END();
} }
...@@ -626,14 +617,6 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle, ...@@ -626,14 +617,6 @@ LGBM_SE LGBM_BoosterDumpModel_R(LGBM_SE handle,
int64_t out_len = 0; int64_t out_len = 0;
std::vector<char> inner_char_buf(R_AS_INT(buffer_len)); std::vector<char> inner_char_buf(R_AS_INT(buffer_len));
CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data())); CHECK_CALL(LGBM_BoosterDumpModel(R_GET_PTR(handle), R_AS_INT(num_iteration), R_AS_INT(buffer_len), &out_len, inner_char_buf.data()));
if (out_len < R_AS_INT(buffer_len)) {
EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len); EncodeChar(out_str, inner_char_buf.data(), buffer_len, actual_len);
} else {
if (out_len <= INT32_MAX) {
R_INT_PTR(actual_len)[0] = static_cast<int>(out_len);
} else {
Log::Fatal("Don't support large model in R package.");
}
}
R_API_END(); R_API_END();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment