Unverified Commit 9cd0d3f8 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[Aten] Add nonzero op (#1746)



* add nonzero op

* 111

* fix doc
Co-authored-by: default avatarxiang song(charlie.song) <classicxsong@gmail.com>
parent cc73c60c
......@@ -127,6 +127,9 @@ IdArray NE(int64_t lhs, IdArray rhs);
/*! \brief Stack two arrays (of len L) into a 2*L length array */
IdArray HStack(IdArray arr1, IdArray arr2);
/*! \brief Return the indices of the elements that are non-zero. */
IdArray NonZero(BoolArray bool_arr);
/*!
* \brief Return the data under the index. In numpy notation, A[I]
* \tparam ValueType The type of return value.
......
......@@ -84,6 +84,16 @@ IdArray HStack(IdArray lhs, IdArray rhs) {
return ret;
}
IdArray NonZero(BoolArray bool_arr) {
IdArray ret;
ATEN_XPU_SWITCH(bool_arr->ctx.device_type, XPU, "NonZero", {
ATEN_ID_TYPE_SWITCH(bool_arr->dtype, IdType, {
ret = impl::NonZero<XPU, IdType>(bool_arr);
});
});
return ret;
}
NDArray IndexSelect(NDArray array, IdArray index) {
NDArray ret;
CHECK_SAME_CONTEXT(array, index);
......
......@@ -46,6 +46,9 @@ NDArray IndexSelect(NDArray array, IdArray index);
template <DLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);
......
......@@ -236,6 +236,26 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
///////////////////////////// NonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray NonZero(BoolArray bool_arr) {
const IdType* bool_data = static_cast<IdType*>(bool_arr->data);
CHECK(bool_arr->ndim == 1) << "NonZero only supports 1D array";
std::vector<IdType> nonzero_indices;
for (int64_t i = 0; i < bool_arr->shape[0]; i++) {
if ((bool_data[i]) != 0) {
nonzero_indices.push_back(i);
}
}
return VecToIdArray(nonzero_indices, sizeof(IdType) * 8);
}
// TODO(Allen): Implement GPU version
template IdArray NonZero<kDLCPU, int32_t>(BoolArray bool_arr);
template IdArray NonZero<kDLCPU, int64_t>(BoolArray bool_arr);
} // namespace impl
} // namespace aten
} // namespace dgl
......@@ -68,7 +68,7 @@ bool SaveDGLGraphs(std::string filename, List<GraphData> graph_data,
std::vector<NamedTensor> labels_list) {
auto fs = std::unique_ptr<SeekStream>(dynamic_cast<SeekStream *>(
SeekStream::Create(filename.c_str(), "w", true)));
CHECK(fs) << "File name is not a valid local file name";
CHECK(fs) << "File name " << filename << " is not a valid local file name";
// Write DGL MetaData
const uint64_t kVersion = 1;
......
......@@ -612,3 +612,16 @@ TEST(ArrayTest, CumSum) {
_TestCumSum<int64_t>(GPU);
#endif
}
template <typename IDX>
void _TestNonZero() {
BoolArray a = aten::VecToIdArray(std::vector<IDX>({1, 0, 1, 1, 0, 0, 1}));
IdArray indices = aten::NonZero(a);
IdArray expected = aten::VecToIdArray(std::vector<IDX>({0, 2, 3, 6}));
ASSERT_TRUE(ArrayEQ<IDX>(indices, expected));
}
TEST(ArrayTest, NonZero) {
_TestNonZero<int32_t>();
_TestNonZero<int64_t>();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment