Commit 60c77487 authored by Guolin Ke's avatar Guolin Ke
Browse files

fill nan by 0 in c_api.

parent 57ad0149
...@@ -1094,6 +1094,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1094,6 +1094,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i)); ret[i] = static_cast<double>(*(tmp_ptr + i));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
} }
return ret; return ret;
}; };
...@@ -1102,6 +1105,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1102,6 +1105,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
} }
return ret; return ret;
}; };
...@@ -1114,6 +1120,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1114,6 +1120,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx; auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i)); ret[i] = static_cast<double>(*(tmp_ptr + i));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
} }
return ret; return ret;
}; };
...@@ -1122,6 +1131,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_ ...@@ -1122,6 +1131,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
std::vector<double> ret(num_col); std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) { for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx)); ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
} }
return ret; return ret;
}; };
...@@ -1159,7 +1171,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1159,7 +1171,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) { for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]); if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
} }
return ret; return ret;
}; };
...@@ -1170,7 +1184,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1170,7 +1184,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) { for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]); if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
} }
return ret; return ret;
}; };
...@@ -1184,7 +1200,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1184,7 +1200,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) { for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]); if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
} }
return ret; return ret;
}; };
...@@ -1195,7 +1213,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices, ...@@ -1195,7 +1213,9 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx]; int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1]; int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) { for (int64_t i = start; i < end; ++i) {
ret.emplace_back(indices[i], data_ptr[i]); if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
} }
return ret; return ret;
}; };
...@@ -1220,6 +1240,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1220,6 +1240,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
} }
int idx = static_cast<int>(indices[i]); int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]); double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val); return std::make_pair(idx, val);
}; };
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
...@@ -1233,6 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1233,6 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
} }
int idx = static_cast<int>(indices[i]); int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]); double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val); return std::make_pair(idx, val);
}; };
} }
...@@ -1249,6 +1271,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1249,6 +1271,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
} }
int idx = static_cast<int>(indices[i]); int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]); double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val); return std::make_pair(idx, val);
}; };
} else if (col_ptr_type == C_API_DTYPE_INT64) { } else if (col_ptr_type == C_API_DTYPE_INT64) {
...@@ -1262,6 +1285,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind ...@@ -1262,6 +1285,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
} }
int idx = static_cast<int>(indices[i]); int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]); double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val); return std::make_pair(idx, val);
}; };
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment