/**
 *  Copyright (c) 2022 by Contributors
 * @file python_binding.cc
 * @brief DGL sparse library Python binding.
 */
// dgl_headers.h is fenced off from clang-format so include sorting cannot
// move it. NOTE(review): presumably the headers below rely on it being
// included first — confirm before reordering.
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

#include <sparse/elementwise_op.h>
#include <sparse/matrix_ops.h>
#include <sparse/reduction.h>
#include <sparse/sddmm.h>
#include <sparse/softmax.h>
#include <sparse/sparse_matrix.h>
#include <sparse/spmm.h>
#include <sparse/spspmm.h>
#include <torch/custom_class.h>
#include <torch/script.h>

namespace dgl {
namespace sparse {

/**
 * @brief Registers the `dgl_sparse` TorchScript library.
 *
 * `m.class_<SparseMatrix>("SparseMatrix")` exposes the C++ SparseMatrix as a
 * TorchScript custom class, and each `m.def` registers a free operator in the
 * `dgl_sparse` namespace.
 *
 * The string passed to every `def` is the name callers use to look the
 * method/operator up, and several deliberately differ from the C++ symbol
 * they bind (e.g. "val" binds SparseMatrix::value, "is_diag" binds
 * SparseMatrix::HasDiag, "smean" binds ReduceMean). Treat these strings as
 * public API: renaming one breaks every caller that refers to it by name.
 */
TORCH_LIBRARY(dgl_sparse, m) {
  // Methods callable on a SparseMatrix instance.
  m.class_<SparseMatrix>("SparseMatrix")
      .def("val", &SparseMatrix::value)
      .def("nnz", &SparseMatrix::nnz)
      .def("device", &SparseMatrix::device)
      .def("shape", &SparseMatrix::shape)
      // Per-format tensor accessors (COO / CSR / CSC).
      .def("coo", &SparseMatrix::COOTensors)
      .def("indices", &SparseMatrix::Indices)
      .def("csr", &SparseMatrix::CSRTensors)
      .def("csc", &SparseMatrix::CSCTensors)
      .def("transpose", &SparseMatrix::Transpose)
      .def("coalesce", &SparseMatrix::Coalesce)
      .def("has_duplicate", &SparseMatrix::HasDuplicate)
      .def("is_diag", &SparseMatrix::HasDiag)
      // Selection and sampling.
      .def("index_select", &SparseMatrix::IndexSelect)
      .def("range_select", &SparseMatrix::RangeSelect)
      .def("sample", &SparseMatrix::Sample);

  // Free operators in the `dgl_sparse` operator namespace.
  m.def("from_coo", &SparseMatrix::FromCOO)
      // Static SparseMatrix factories (`SparseMatrix::From*`).
      .def("from_csr", &SparseMatrix::FromCSR)
      .def("from_csc", &SparseMatrix::FromCSC)
      .def("from_diag", &SparseMatrix::FromDiag)
      // Binary operators between two sparse matrices (`SpSp*`).
      .def("spsp_add", &SpSpAdd)
      .def("spsp_mul", &SpSpMul)
      .def("spsp_div", &SpSpDiv)
      // Reductions: a generic `reduce` plus one operator per reduction.
      // NOTE(review): "sum" has no "s" prefix unlike smean/smin/smax/sprod;
      // kept as-is because these strings are the public operator names.
      .def("reduce", &Reduce)
      .def("sum", &ReduceSum)
      .def("smean", &ReduceMean)
      .def("smin", &ReduceMin)
      .def("smax", &ReduceMax)
      .def("sprod", &ReduceProd)
      .def("val_like", &SparseMatrix::ValLike)
      // Matrix products, softmax and compaction.
      .def("spmm", &SpMM)
      .def("sddmm", &SDDMM)
      .def("softmax", &Softmax)
      .def("spspmm", &SpSpMM)
      .def("compact", &Compact);
}

}  // namespace sparse
}  // namespace dgl