Commit 60c77487 authored by Guolin Ke's avatar Guolin Ke
Browse files

fill nan by 0 in c_api.

parent 57ad0149
......@@ -1094,6 +1094,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
}
return ret;
};
......@@ -1102,6 +1105,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
}
return ret;
};
......@@ -1114,6 +1120,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
auto tmp_ptr = data_ptr + static_cast<size_t>(num_col) * row_idx;
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(tmp_ptr + i));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
}
return ret;
};
......@@ -1122,6 +1131,9 @@ RowFunctionFromDenseMatric(const void* data, int num_row, int num_col, int data_
std::vector<double> ret(num_col);
for (int i = 0; i < num_col; ++i) {
ret[i] = static_cast<double>(*(data_ptr + static_cast<size_t>(num_row) * i + row_idx));
if (std::isnan(ret[i])) {
ret[i] = 0.0f;
}
}
return ret;
};
......@@ -1159,8 +1171,10 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) {
if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
}
return ret;
};
} else if (indptr_type == C_API_DTYPE_INT64) {
......@@ -1170,8 +1184,10 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) {
if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
}
return ret;
};
}
......@@ -1184,8 +1200,10 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) {
if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
}
return ret;
};
} else if (indptr_type == C_API_DTYPE_INT64) {
......@@ -1195,8 +1213,10 @@ RowFunctionFromCSR(const void* indptr, int indptr_type, const int32_t* indices,
int64_t start = ptr_indptr[idx];
int64_t end = ptr_indptr[idx + 1];
for (int64_t i = start; i < end; ++i) {
if (!std::isnan(data_ptr[i])) {
ret.emplace_back(indices[i], data_ptr[i]);
}
}
return ret;
};
}
......@@ -1220,6 +1240,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val);
};
} else if (col_ptr_type == C_API_DTYPE_INT64) {
......@@ -1233,6 +1254,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val);
};
}
......@@ -1249,6 +1271,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val);
};
} else if (col_ptr_type == C_API_DTYPE_INT64) {
......@@ -1262,6 +1285,7 @@ IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* ind
}
int idx = static_cast<int>(indices[i]);
double val = static_cast<double>(data_ptr[i]);
if (std::isnan(val)) { val = 0.0f; }
return std::make_pair(idx, val);
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment