Unverified Commit 7c059e86 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Sparse] Add sparse matrix C++ implementation (#4773)

* [Sparse] Add sparse matrix C++ implementation

* Add documentation

* Update

* Minor fix

* Move Python code to dgl/mock_sparse2

* Move headers to include

* lint

* Update

* Add dgl_sparse directory

* Move src code to dgl_sparse

* Add __init__.py in tests to avoid naming conflict

* Add dgl sparse so in Jenkinsfile

* Complete docstring & SparseMatrix basic op

* lint

* Disable win tests
parent df089424
@@ -315,6 +315,24 @@ if(BUILD_TORCH)
  add_dependencies(dgl tensoradapter_pytorch)
endif(BUILD_TORCH)
# Build the dgl_sparse extension out-of-process via its helper script so it
# can be configured against the PyTorch CMake package independently of the
# main dgl build.
if(BUILD_SPARSE)
  file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
  file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
  # TODO(zhenkun): MSVC support?
  file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/dgl_sparse/build.sh BUILD_SCRIPT)
  # NOTE(review): the native-path BINDIR computed above is unused below — the
  # BINDIR env var is set from CMAKE_CURRENT_BINARY_DIR directly; confirm
  # which form is intended once Windows is supported.
  # Configuration is forwarded to build.sh via environment variables;
  # TORCH_PYTHON_INTERPS selects which Python/PyTorch installs to build for.
  add_custom_target(
    dgl_sparse
    ${CMAKE_COMMAND} -E env
    CMAKE_COMMAND=${CMAKE_CMD}
    CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
    USE_CUDA=${USE_CUDA}
    BINDIR=${CMAKE_CURRENT_BINARY_DIR}
    bash ${BUILD_SCRIPT} ${TORCH_PYTHON_INTERPS}
    DEPENDS ${BUILD_SCRIPT}
    WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/dgl_sparse)
  # add_dependencies only orders the builds; dgl does not link dgl_sparse.
  add_dependencies(dgl dgl_sparse)
endif(BUILD_SPARSE)
# Installation rules
install(TARGETS dgl DESTINATION lib${LIB_SUFFIX})
......
#!/usr/bin/env groovy
dgl_linux_libs = 'build/libdgl.so, build/runUnitTests, python/dgl/_ffi/_cy3/core.cpython-*-x86_64-linux-gnu.so, build/tensoradapter/pytorch/*.so, build/dgl_sparse/*.so'
// Currently DGL on Windows is not working with Cython yet
dgl_win64_libs = "build\\dgl.dll, build\\runUnitTests.exe, build\\tensoradapter\\pytorch\\*.dll"
......
cmake_minimum_required(VERSION 3.5)
project(dgl_sparse C CXX)

# Find PyTorch cmake files and PyTorch version with the python interpreter
# $PYTHON_INTERP ("python3" or "python" if empty).
if(NOT PYTHON_INTERP)
  find_program(PYTHON_INTERP NAMES python3 python)
endif()
message(STATUS "Using Python interpreter: ${PYTHON_INTERP}")
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/find_cmake.py FIND_CMAKE_PY)
execute_process(
  COMMAND ${PYTHON_INTERP} ${FIND_CMAKE_PY}
  OUTPUT_VARIABLE TORCH_PREFIX_VER
  OUTPUT_STRIP_TRAILING_WHITESPACE)
message(STATUS "find_cmake.py output: ${TORCH_PREFIX_VER}")
# find_cmake.py prints "<torch-cmake-prefix>;<torch-version>", i.e. a
# two-element CMake list.
list(GET TORCH_PREFIX_VER 0 TORCH_PREFIX)
list(GET TORCH_PREFIX_VER 1 TORCH_VER)
message(STATUS "Configuring for PyTorch ${TORCH_VER}")

if(USE_CUDA)
  add_definitions(-DDGL_USE_CUDA)
endif()

set(Torch_DIR "${TORCH_PREFIX}/Torch")
message(STATUS "Setting directory to ${Torch_DIR}")
find_package(Torch REQUIRED)

# The library name embeds the PyTorch version so several builds can coexist
# (see load_dgl_sparse on the Python side).
set(LIB_DGL_SPARSE_NAME "dgl_sparse_pytorch_${TORCH_VER}")
set(SPARSE_DIR "${CMAKE_SOURCE_DIR}/src")
set(SPARSE_INCLUDE "${CMAKE_SOURCE_DIR}/include")
# BUGFIX: the glob previously had no pattern, so it matched the include
# *directory* itself, and that "file list" was then passed as an include
# path. Glob the actual headers (included as <sparse/*.h>) and use the
# include directory for the include path.
file(GLOB SPARSE_HEADERS ${SPARSE_INCLUDE}/sparse/*.h)
file(GLOB SPARSE_SRC ${SPARSE_DIR}/*.cc)
add_library(${LIB_DGL_SPARSE_NAME} SHARED ${SPARSE_SRC} ${SPARSE_HEADERS})
target_include_directories(
  ${LIB_DGL_SPARSE_NAME} PRIVATE ${SPARSE_DIR} ${SPARSE_INCLUDE})
target_link_libraries(${LIB_DGL_SPARSE_NAME} PRIVATE "${TORCH_LIBRARIES}")

# dmlc-core provides dmlc/logging.h (CHECK/CHECK_EQ).
# BUGFIX: GOOGLE_TEST must be set *before* add_subdirectory for it to affect
# the dmlc-core configuration; previously it was set afterwards.
set(GOOGLE_TEST 0) # Turn off dmlc-core test
add_subdirectory("${CMAKE_SOURCE_DIR}/../third_party/dmlc-core" "${CMAKE_SOURCE_DIR}/build/third_party/dmlc-core")
target_include_directories(${LIB_DGL_SPARSE_NAME} PRIVATE "${CMAKE_SOURCE_DIR}/../third_party/dmlc-core/include")
target_link_libraries(${LIB_DGL_SPARSE_NAME} PRIVATE dmlc)
#!/bin/bash
# Helper script to build dgl sparse libraries.
#
# Usage: build.sh [python-interpreter ...]
#   With no arguments, one build is made with the default interpreter.
#   Otherwise one build is made per interpreter (each may point at a
#   different PyTorch installation).
# Required environment (exported by the parent CMake custom target):
#   CMAKE_COMMAND, BINDIR, USE_CUDA, CUDA_TOOLKIT_ROOT_DIR
#   (TORCH_CUDA_ARCH_LIST is optional).
set -e
rm -rf build
mkdir -p build
# BUGFIX: quote variable expansions so paths containing spaces survive
# word splitting ("$BINDIR", "$(uname)", "$PYTHON_INTERP", "$@").
mkdir -p "$BINDIR/dgl_sparse"
cd build
if [ "$(uname)" = 'Darwin' ]; then
  CPSOURCE=*.dylib
else
  CPSOURCE=*.so
fi
# CMAKE_FLAGS and CPSOURCE are intentionally unquoted at use sites:
# CMAKE_FLAGS must word-split into separate -D arguments and CPSOURCE is a
# glob that expands to the built libraries.
CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DTORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST -DUSE_CUDA=$USE_CUDA"
if [ $# -eq 0 ]; then
  $CMAKE_COMMAND $CMAKE_FLAGS ..
  make -j
  cp -v $CPSOURCE "$BINDIR/dgl_sparse"
else
  for PYTHON_INTERP in "$@"; do
    # Wipe the build tree between interpreters so stale CMake caches from
    # a different PyTorch install cannot leak into the next build.
    rm -rf *
    $CMAKE_COMMAND $CMAKE_FLAGS -DPYTHON_INTERP="$PYTHON_INTERP" ..
    make -j
    cp -v $CPSOURCE "$BINDIR/dgl_sparse"
  done
fi
"""Print the Torch CMake prefix and version for the active PyTorch install."""
import os

import torch

# Newer PyTorch exposes the cmake prefix directly on torch.utils; fall back
# to the conventional <torch>/share/cmake location otherwise.
_fallback_prefix = os.path.join(
    os.path.dirname(torch.__file__), "share", "cmake"
)
cmake_prefix_path = getattr(torch.utils, "cmake_prefix_path", _fallback_prefix)
# Drop any local-build tag such as "+cu116" from the version string.
version = torch.__version__.split("+")[0]
print(";".join([cmake_prefix_path, version]))
/*!
 * Copyright (c) 2022 by Contributors
 * @file sparse/elementwise_op.h
 * @brief DGL C++ sparse elementwise operators.
 */
#ifndef SPARSE_ELEMENTWISE_OP_H_
#define SPARSE_ELEMENTWISE_OP_H_
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
namespace dgl {
namespace sparse {
// TODO(zhenkun): support addition of matrices with different sparsity.
/*!
 * @brief Adds two sparse matrices. Currently does not support two matrices
 * with different sparsity: the values are added positionally and the result
 * reuses the sparsity structure of `A`.
 *
 * @param A SparseMatrix
 * @param B SparseMatrix (must have the same dtype and shape as `A`)
 *
 * @return SparseMatrix whose values are `A.value() + B.value()`.
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B);
}  // namespace sparse
}  // namespace dgl
#endif  // SPARSE_ELEMENTWISE_OP_H_
/*!
* Copyright (c) 2022 by Contributors
* @file sparse/sparse_matrix.h
* @brief DGL C++ sparse matrix header
*/
#ifndef SPARSE_SPARSE_MATRIX_H_
#define SPARSE_SPARSE_MATRIX_H_
#include <torch/custom_class.h>
#include <torch/script.h>
#include <memory>
#include <vector>
namespace dgl {
namespace sparse {
/*! @brief Sparse formats a SparseMatrix may hold. */
enum SparseFormat { kCOO, kCSR, kCSC };
/*! @brief CSR sparse structure. Also reused to store CSC, in which case
 * `indptr` is indexed by column and `indices` holds row ids (see the
 * constructor checks in sparse_matrix.cc). */
struct CSR {
  // CSR format index pointer array of the matrix
  torch::Tensor indptr;
  // CSR format index array of the matrix
  torch::Tensor indices;
  // The element order of the sparse format. In the SparseMatrix, we have data
  // (value_) for each non-zero value. The order of non-zero values in (value_)
  // may differ from the order of non-zero entries in CSR. So we store
  // `value_indices` in CSR to indicate its relative non-zero value order to the
  // SparseMatrix. With `value_indices`, we can retrieve the correct value for
  // CSR, i.e., `value_[value_indices]`. If `value_indices` is not defined, this
  // CSR follows the same non-zero value order as the SparseMatrix.
  torch::optional<torch::Tensor> value_indices;
};
/*! @brief COO sparse structure: paired 1-D row/col index tensors. */
struct COO {
  // COO format row array of the matrix
  torch::Tensor row;
  // COO format column array of the matrix
  torch::Tensor col;
};
/*! @brief SparseMatrix bound to Python. Holds one value tensor (`value_`)
 * plus up to three sparsity representations (COO/CSR/CSC); formats not
 * supplied at construction are created lazily by the *Ptr() accessors. */
class SparseMatrix : public torch::CustomClassHolder {
 public:
  /*!
   * @brief General constructor to construct a sparse matrix for different
   * sparse formats. At least one of the sparse formats should be provided,
   * while others could be nullptrs.
   *
   * @param coo The COO format.
   * @param csr The CSR format.
   * @param csc The CSC format (a CSR structure whose indptr is indexed by
   * column).
   * @param value Value of the sparse matrix.
   * @param shape Shape of the sparse matrix (must be 2-dimensional).
   */
  SparseMatrix(
      const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);
  /*!
   * @brief Construct a SparseMatrix from a COO format.
   * @param coo The COO format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCOO(
      const std::shared_ptr<COO>& coo, torch::Tensor value,
      const std::vector<int64_t>& shape);
  /*!
   * @brief Construct a SparseMatrix from a CSR format.
   * @param csr The CSR format
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSR(
      const std::shared_ptr<CSR>& csr, torch::Tensor value,
      const std::vector<int64_t>& shape);
  /*!
   * @brief Construct a SparseMatrix from a CSC format.
   * @param csc The CSC format (stored in a CSR structure)
   * @param value Values of the sparse matrix
   * @param shape Shape of the sparse matrix
   *
   * @return SparseMatrix
   */
  static c10::intrusive_ptr<SparseMatrix> FromCSC(
      const std::shared_ptr<CSR>& csc, torch::Tensor value,
      const std::vector<int64_t>& shape);
  /*! @return Value of the sparse matrix. */
  inline torch::Tensor value() const { return value_; }
  /*! @return Shape of the sparse matrix. */
  inline const std::vector<int64_t>& shape() const { return shape_; }
  /*! @return Number of non-zero values (the length of `value_`). */
  inline int64_t nnz() const { return value_.size(0); }
  /*! @return Non-zero value data type. */
  inline caffe2::TypeMeta dtype() const { return value_.dtype(); }
  /*! @return Device of the sparse matrix (taken from the value tensor). */
  inline torch::Device device() const { return value_.device(); }
  /*! @return COO of the sparse matrix. The COO is created if not exists. */
  std::shared_ptr<COO> COOPtr();
  /*! @return CSR of the sparse matrix. The CSR is created if not exists. */
  std::shared_ptr<CSR> CSRPtr();
  /*! @return CSC of the sparse matrix. The CSC is created if not exists. */
  std::shared_ptr<CSR> CSCPtr();
  /*! @brief Check whether this sparse matrix has COO format. */
  inline bool HasCOO() const { return coo_ != nullptr; }
  /*! @brief Check whether this sparse matrix has CSR format. */
  inline bool HasCSR() const { return csr_ != nullptr; }
  /*! @brief Check whether this sparse matrix has CSC format. */
  inline bool HasCSC() const { return csc_ != nullptr; }
  /*! @return {row, col, value} tensors in the COO format. */
  std::vector<torch::Tensor> COOTensors();
  /*! @return {indptr, indices, value} tensors in the CSR format. */
  std::vector<torch::Tensor> CSRTensors();
  /*! @return {indptr, indices, value} tensors in the CSC format. */
  std::vector<torch::Tensor> CSCTensors();
 private:
  /*! @brief Create the COO format for the sparse matrix internally */
  void _CreateCOO();
  /*! @brief Create the CSR format for the sparse matrix internally */
  void _CreateCSR();
  /*! @brief Create the CSC format for the sparse matrix internally */
  void _CreateCSC();
  // COO/CSC/CSR pointers. Nullptr indicates non-existence.
  std::shared_ptr<COO> coo_;
  std::shared_ptr<CSR> csr_, csc_;
  // Value of the SparseMatrix
  torch::Tensor value_;
  // Shape of the SparseMatrix
  const std::vector<int64_t> shape_;
};
/*!
 * @brief Create a SparseMatrix from tensors in COO format.
 * @param row Row indices of the COO.
 * @param col Column indices of the COO.
 * @param value Values of the sparse matrix.
 * @param shape Shape of the sparse matrix.
 *
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
    torch::Tensor row, torch::Tensor col, torch::Tensor value,
    const std::vector<int64_t>& shape);
/*!
 * @brief Create a SparseMatrix from tensors in CSR format.
 * @param indptr Index pointer array of the CSR.
 * @param indices Indices array of the CSR.
 * @param value Values of the sparse matrix.
 * @param shape Shape of the sparse matrix.
 *
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape);
/*!
 * @brief Create a SparseMatrix from tensors in CSC format.
 * @param indptr Index pointer array of the CSC.
 * @param indices Indices array of the CSC.
 * @param value Values of the sparse matrix.
 * @param shape Shape of the sparse matrix.
 *
 * @return SparseMatrix
 */
c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape);
}  // namespace sparse
}  // namespace dgl
#endif  // SPARSE_SPARSE_MATRIX_H_
/*!
* Copyright (c) 2022 by Contributors
* @file elementwise_op.cc
* @brief DGL C++ sparse elementwise operator implementation
*/
#include <dmlc/logging.h>
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
#include <torch/script.h>
#include <memory>
#include "./utils.h"
namespace dgl {
namespace sparse {
/*!
 * @brief Add two sparse matrices positionally, reusing the sparsity of `A`.
 * See the TODO in elementwise_op.h: different sparsity patterns are not yet
 * supported.
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  // BUGFIX: validate dtype/shape *before* touching the values, so a
  // mismatched input produces the descriptive CHECK message instead of a
  // raw torch error from the addition below.
  ElementwiseOpSanityCheck(A, B);
  // Prefer a format either operand already has, to avoid a conversion.
  auto fmt = FindAnyExistingFormat(A, B);
  auto value = A->value() + B->value();
  if (fmt == SparseFormat::kCOO) {
    return SparseMatrix::FromCOO(A->COOPtr(), value, A->shape());
  } else if (fmt == SparseFormat::kCSR) {
    return SparseMatrix::FromCSR(A->CSRPtr(), value, A->shape());
  } else {
    return SparseMatrix::FromCSC(A->CSCPtr(), value, A->shape());
  }
}
} // namespace sparse
} // namespace dgl
/*!
* Copyright (c) 2022 by Contributors
* @file python_binding.cc
* @brief DGL sparse library Python binding
*/
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
#include <torch/script.h>
namespace dgl {
namespace sparse {
// Register the SparseMatrix class and the free functions with TorchScript
// under the `dgl_sparse` namespace, making them reachable from Python as
// torch.classes.dgl_sparse.SparseMatrix and torch.ops.dgl_sparse.*.
TORCH_LIBRARY(dgl_sparse, m) {
  m.class_<SparseMatrix>("SparseMatrix")
      .def("val", &SparseMatrix::value)
      .def("nnz", &SparseMatrix::nnz)
      .def("device", &SparseMatrix::device)
      .def("shape", &SparseMatrix::shape)
      .def("coo", &SparseMatrix::COOTensors)
      .def("csr", &SparseMatrix::CSRTensors)
      .def("csc", &SparseMatrix::CSCTensors);
  m.def("create_from_coo", &CreateFromCOO)
      .def("create_from_csr", &CreateFromCSR)
      .def("create_from_csc", &CreateFromCSC)
      .def("spsp_add", &SpSpAdd);
}
} // namespace sparse
} // namespace dgl
/*!
* Copyright (c) 2022 by Contributors
* @file sparse_matrix.cc
* @brief DGL C++ sparse matrix implementations
*/
#include <dmlc/logging.h>
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
namespace dgl {
namespace sparse {
// Construct a SparseMatrix from any non-empty subset of COO/CSR/CSC and
// validate that each provided format is consistent with `value` and `shape`.
SparseMatrix::SparseMatrix(
    const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape)
    : coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
  CHECK(coo != nullptr || csr != nullptr || csc != nullptr)
      << "At least one of CSR/COO/CSC is provided to construct a "
         "SparseMatrix";
  CHECK_EQ(shape.size(), 2)
      << "The shape of a sparse matrix should be 2-dimensional";
  // NOTE: Currently all the tensors of a SparseMatrix should on the same
  // device. Do we allow the graph structure and values are on different
  // devices?
  // COO: row/col are 1-D, equal-length, aligned with `value`, same device.
  if (coo != nullptr) {
    CHECK_EQ(coo->row.dim(), 1);
    CHECK_EQ(coo->col.dim(), 1);
    CHECK_EQ(coo->row.size(0), coo->col.size(0));
    CHECK_EQ(coo->row.size(0), value.size(0));
    CHECK_EQ(coo->row.device(), value.device());
    CHECK_EQ(coo->col.device(), value.device());
  }
  // CSR: indptr has one entry per row plus one; indices align with `value`.
  if (csr != nullptr) {
    CHECK_EQ(csr->indptr.dim(), 1);
    CHECK_EQ(csr->indices.dim(), 1);
    CHECK_EQ(csr->indptr.size(0), shape[0] + 1);
    CHECK_EQ(csr->indices.size(0), value.size(0));
    CHECK_EQ(csr->indptr.device(), value.device());
    CHECK_EQ(csr->indices.device(), value.device());
  }
  // CSC: same as CSR but indptr is indexed by column (shape[1] + 1).
  if (csc != nullptr) {
    CHECK_EQ(csc->indptr.dim(), 1);
    CHECK_EQ(csc->indices.dim(), 1);
    CHECK_EQ(csc->indptr.size(0), shape[1] + 1);
    CHECK_EQ(csc->indices.size(0), value.size(0));
    CHECK_EQ(csc->indptr.device(), value.device());
    CHECK_EQ(csc->indices.device(), value.device());
  }
}
// Build a SparseMatrix backed only by a COO; CSR/CSC are created lazily.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
    const std::shared_ptr<COO>& coo, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      coo, /*csr=*/nullptr, /*csc=*/nullptr, value, shape);
}
// Build a SparseMatrix backed only by a CSR; COO/CSC are created lazily.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSR(
    const std::shared_ptr<CSR>& csr, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, csr, /*csc=*/nullptr, value, shape);
}
// Build a SparseMatrix backed only by a CSC; COO/CSR are created lazily.
c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
    const std::shared_ptr<CSR>& csc, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return c10::make_intrusive<SparseMatrix>(
      /*coo=*/nullptr, /*csr=*/nullptr, csc, value, shape);
}
std::shared_ptr<COO> SparseMatrix::COOPtr() {
if (coo_ == nullptr) {
_CreateCOO();
}
return coo_;
}
std::shared_ptr<CSR> SparseMatrix::CSRPtr() {
if (csr_ == nullptr) {
_CreateCSR();
}
return csr_;
}
std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
if (csc_ == nullptr) {
_CreateCSC();
}
return csc_;
}
std::vector<torch::Tensor> SparseMatrix::COOTensors() {
auto coo = COOPtr();
auto val = value();
return {coo->row, coo->col, val};
}
std::vector<torch::Tensor> SparseMatrix::CSRTensors() {
auto csr = CSRPtr();
auto val = value();
if (csr->value_indices.has_value()) {
val = val[csr->value_indices.value()];
}
return {csr->indptr, csr->indices, val};
}
std::vector<torch::Tensor> SparseMatrix::CSCTensors() {
auto csc = CSCPtr();
auto val = value();
if (csc->value_indices.has_value()) {
val = val[csc->value_indices.value()];
}
return {csc->indptr, csc->indices, val};
}
// TODO(zhenkun): format conversion
// NOTE(review): these are currently stubs — the *Ptr() accessors call them
// when a format is missing, but the corresponding pointer will stay null
// until conversion is implemented.
void SparseMatrix::_CreateCOO() {}
void SparseMatrix::_CreateCSR() {}
void SparseMatrix::_CreateCSC() {}
// Wrap raw COO index tensors into a COO struct and delegate.
c10::intrusive_ptr<SparseMatrix> CreateFromCOO(
    torch::Tensor row, torch::Tensor col, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return SparseMatrix::FromCOO(
      std::make_shared<COO>(COO{row, col}), value, shape);
}
// Wrap raw CSR tensors into a CSR struct (no value reordering) and delegate.
c10::intrusive_ptr<SparseMatrix> CreateFromCSR(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return SparseMatrix::FromCSR(
      std::make_shared<CSR>(
          CSR{indptr, indices, torch::optional<torch::Tensor>()}),
      value, shape);
}
// Wrap raw CSC tensors into a CSR struct (no value reordering) and delegate.
c10::intrusive_ptr<SparseMatrix> CreateFromCSC(
    torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
    const std::vector<int64_t>& shape) {
  return SparseMatrix::FromCSC(
      std::make_shared<CSR>(
          CSR{indptr, indices, torch::optional<torch::Tensor>()}),
      value, shape);
}
} // namespace sparse
} // namespace dgl
/*!
* Copyright (c) 2022 by Contributors
* @file utils.h
* @brief DGL C++ sparse API utilities
*/
#ifndef DGL_SPARSE_UTILS_H_
#define DGL_SPARSE_UTILS_H_
#include <dmlc/logging.h>
#include <sparse/sparse_matrix.h>
namespace dgl {
namespace sparse {
/*! @brief Find a proper sparse format for two sparse matrices. It chooses
* COO if anyone of the sparse matrices has COO format. If none of them has
* COO, it tries CSR and CSC in the same manner. */
/*! @brief Find a proper sparse format for two sparse matrices. Prefers COO
 * if either matrix already has it, then CSR, then CSC — so the caller can
 * avoid a format conversion. */
inline static SparseFormat FindAnyExistingFormat(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  if (A->HasCOO() || B->HasCOO()) {
    return SparseFormat::kCOO;
  }
  if (A->HasCSR() || B->HasCSR()) {
    return SparseFormat::kCSR;
  }
  return SparseFormat::kCSC;
}
/*! @brief Check whether two matrices has the same dtype and shape for
* elementwise operators. */
/*! @brief Verify that two matrices share dtype and shape, as required by the
 * elementwise operators; aborts with a descriptive message otherwise. */
inline static void ElementwiseOpSanityCheck(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  const auto a_dtype = A->value().dtype();
  const auto b_dtype = B->value().dtype();
  const auto& a_shape = A->shape();
  const auto& b_shape = B->shape();
  CHECK(a_dtype == b_dtype)
      << "Elementwise operators do not support two sparse matrices with "
         "different dtypes. ("
      << a_dtype << " vs " << b_dtype << ")";
  CHECK(a_shape[0] == b_shape[0] && a_shape[1] == b_shape[1])
      << "Elementwise operator do not support two sparse matrices with "
         "different shapes. (["
      << a_shape[0] << ", " << a_shape[1] << "] vs [" << b_shape[0]
      << ", " << b_shape[1] << "])";
}
} // namespace sparse
} // namespace dgl
#endif // DGL_SPARSE_UTILS_H_
"""dgl sparse class."""
import sys
import os
import torch
from .._ffi import libinfo
from .sparse_matrix import *
from .diag_matrix import *
from .elementwise_op import *
def load_dgl_sparse():
    """Load the DGL C++ sparse library matching the installed PyTorch.

    The shared object is named ``libdgl_sparse_pytorch_<version>.so`` and
    lives in the ``dgl_sparse`` directory next to the main DGL library.

    Raises
    ------
    ImportError
        If the library cannot be found or loaded.
    """
    # Drop any local build tag such as "+cu116" from the torch version.
    version = torch.__version__.split("+", maxsplit=1)[0]
    basename = f"libdgl_sparse_pytorch_{version}.so"
    dirname = os.path.dirname(libinfo.find_lib_path()[0])
    path = os.path.join(dirname, "dgl_sparse", basename)
    try:
        torch.classes.load_library(path)
    except Exception as err:  # pylint: disable=W0703
        # BUGFIX: chain the original exception so the root cause (missing
        # file, ABI mismatch, ...) is not silently discarded.
        raise ImportError("Cannot load DGL C++ sparse library") from err
# TODO(zhenkun): support other platforms
# Eagerly load the C++ extension at import time (Linux-only for now).
if sys.platform.startswith("linux"):
    load_dgl_sparse()
"""DGL diagonal matrix module."""
from typing import Optional, Tuple
import torch
from .sparse_matrix import SparseMatrix, create_from_coo
class DiagMatrix:
    """Diagonal Matrix Class

    Parameters
    ----------
    val : torch.Tensor
        Diagonal of the matrix. It can take shape (N) or (N, D).
    shape : tuple[int, int], optional
        If not specified, it will be inferred from :attr:`val`, i.e.,
        (N, N). Otherwise, :attr:`len(val)` must be equal to :attr:`min(shape)`.

    Attributes
    ----------
    val : torch.Tensor
        Diagonal of the matrix.
    shape : tuple[int, int]
        Shape of the matrix.
    """

    def __init__(
        self, val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
    ):
        len_val = len(val)
        if shape is not None:
            # BUGFIX: the second half of this message was not an f-string, so
            # "{shape}" was printed literally, and a space was missing between
            # "len(val)" and "and".
            assert len_val == min(shape), (
                f"Expect len(val) to be min(shape), got {len_val} for len(val) "
                f"and {shape} for shape."
            )
        else:
            shape = (len_val, len_val)
        self.val = val
        self.shape = shape

    def __repr__(self):
        return f"DiagMatrix(val={self.val}, \nshape={self.shape})"

    def __call__(self, x: torch.Tensor):
        """Create a new diagonal matrix with the same shape as self
        but different values.

        Parameters
        ----------
        x : torch.Tensor
            Values of the diagonal matrix

        Returns
        -------
        DiagMatrix
            Diagonal matrix

        Examples
        --------
        >>> import torch
        >>> val = torch.ones(5)
        >>> mat = diag(val)
        >>> print(mat)
        DiagMatrix(val=tensor([1., 1., 1., 1., 1.]),
        shape=(5, 5))
        >>> val = torch.ones(5) + 1
        >>> mat = mat(val)
        >>> print(mat)
        DiagMatrix(val=tensor([2., 2., 2., 2., 2.]),
        shape=(5, 5))
        """
        return diag(x, self.shape)

    @property
    def nnz(self) -> int:
        """Return the number of non-zero values in the matrix

        Returns
        -------
        int
            The number of non-zero values in the matrix
        """
        return self.val.shape[0]

    @property
    def dtype(self) -> torch.dtype:
        """Return the data type of the matrix

        Returns
        -------
        torch.dtype
            Data type of the matrix
        """
        return self.val.dtype

    @property
    def device(self) -> torch.device:
        """Return the device of the matrix

        Returns
        -------
        torch.device
            Device of the matrix
        """
        return self.val.device

    def as_sparse(self) -> "SparseMatrix":
        """Convert the diagonal matrix into a sparse matrix object

        Returns
        -------
        SparseMatrix
            The converted sparse matrix object

        Example
        -------
        >>> import torch
        >>> val = torch.ones(5)
        >>> mat = diag(val)
        >>> sp_mat = mat.as_sparse()
        >>> print(sp_mat)
        SparseMatrix(indices=tensor([[0, 1, 2, 3, 4],
                [0, 1, 2, 3, 4]]),
        values=tensor([1., 1., 1., 1., 1.]),
        shape=(5, 5), nnz=5)
        """
        row = col = torch.arange(len(self.val)).to(self.device)
        return create_from_coo(row=row, col=col, val=self.val, shape=self.shape)

    def t(self):
        """Alias of :meth:`transpose()`"""
        return self.transpose()

    @property
    def T(self):  # pylint: disable=C0103
        """Alias of :meth:`transpose()`"""
        return self.transpose()

    def transpose(self):
        """Return the transpose of the matrix.

        Returns
        -------
        DiagMatrix
            The transpose of the matrix.

        Example
        --------
        >>> val = torch.arange(1, 5).float()
        >>> mat = diag(val, shape=(4, 5))
        >>> mat = mat.transpose()
        >>> print(mat)
        DiagMatrix(val=tensor([1., 2., 3., 4.]),
        shape=(5, 4))
        """
        # Reversing the shape tuple is enough: the diagonal values themselves
        # are unchanged by transposition.
        return DiagMatrix(self.val, self.shape[::-1])
def diag(
    val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
) -> DiagMatrix:
    """Create a diagonal matrix based on the diagonal values

    Parameters
    ----------
    val : torch.Tensor
        Diagonal of the matrix. It can take shape (N) or (N, D).
    shape : tuple[int, int], optional
        If not specified, it will be inferred from :attr:`val`, i.e.,
        (N, N). Otherwise, :attr:`len(val)` must be equal to :attr:`min(shape)`.

    Returns
    -------
    DiagMatrix
        Diagonal matrix

    Examples
    --------
    Case1: 5-by-5 diagonal matrix with scaler values on the diagonal

    >>> import torch
    >>> val = torch.ones(5)
    >>> mat = diag(val)
    >>> print(mat)
    DiagMatrix(val=tensor([1., 1., 1., 1., 1.]),
    shape=(5, 5))

    Case2: 5-by-10 diagonal matrix with scaler values on the diagonal

    >>> val = torch.ones(5)
    >>> mat = diag(val, shape=(5, 10))
    >>> print(mat)
    DiagMatrix(val=tensor([1., 1., 1., 1., 1.]),
    shape=(5, 10))

    Case3: 5-by-5 diagonal matrix with tensor values on the diagonal

    >>> val = torch.randn(5, 3)
    >>> mat = diag(val)
    >>> mat.shape
    (5, 5)
    >>> mat.nnz
    5
    """
    # NOTE(Mufei): this may not be needed if DiagMatrix is simple enough.
    # Shape inference and validation both happen in the constructor.
    return DiagMatrix(val, shape)
def identity(
    shape: Tuple[int, int],
    d: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> DiagMatrix:
    """Create a diagonal matrix with ones on the diagonal and zeros elsewhere

    Parameters
    ----------
    shape : tuple[int, int]
        Shape of the matrix.
    d : int, optional
        If None, the diagonal entries will be scaler 1. Otherwise, the diagonal
        entries will be a 1-valued tensor of shape (d).
    dtype : torch.dtype, optional
        The data type of the matrix
    device : torch.device, optional
        The device of the matrix

    Returns
    -------
    DiagMatrix
        Diagonal matrix

    Examples
    --------
    Case1: 3-by-3 matrix with scaler diagonal values

    [[1, 0, 0],
     [0, 1, 0],
     [0, 0, 1]]

    >>> mat = identity(shape=(3, 3))
    >>> print(mat)
    DiagMatrix(val=tensor([1., 1., 1.]),
    shape=(3, 3))

    Case2: 3-by-5 matrix with scaler diagonal values

    [[1, 0, 0, 0, 0],
     [0, 1, 0, 0, 0],
     [0, 0, 1, 0, 0]]

    >>> mat = identity(shape=(3, 5))
    >>> print(mat)
    DiagMatrix(val=tensor([1., 1., 1.]),
    shape=(3, 5))

    Case3: 3-by-3 matrix with tensor diagonal values

    >>> mat = identity(shape=(3, 3), d=2)
    >>> print(mat)
    DiagMatrix(val=tensor([[1., 1.],
    [1., 1.],
    [1., 1.]]),
    shape=(3, 3))
    """
    # The diagonal has min(shape) entries; each entry is a scalar (d is None)
    # or a length-d tensor.
    num_diag = min(shape)
    val_shape = (num_diag,) if d is None else (num_diag, d)
    return diag(torch.ones(val_shape, dtype=dtype, device=device), shape)
"""DGL elementwise operator module."""
from typing import Union
from .diag_matrix import DiagMatrix
from .elementwise_op_diag import diag_add
from .elementwise_op_sp import sp_add
from .sparse_matrix import SparseMatrix
__all__ = ["add"]
def add(
    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise addition.

    Dispatches to diagonal addition when both operands are diagonal, and to
    sparse addition otherwise.
    """
    both_diag = isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix)
    return diag_add(A, B) if both_diag else sp_add(A, B)
"""DGL elementwise operators for diagonal matrix module."""
from typing import Union
from .diag_matrix import DiagMatrix
__all__ = ["diag_add", "diag_sub", "diag_mul", "diag_div", "diag_power"]
def diag_add(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
    """Elementwise addition.

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix
        Diagonal matrix

    Returns
    -------
    DiagMatrix
        Diagonal matrix

    Examples
    --------
    >>> D1 = DiagMatrix(torch.arange(1, 4))
    >>> D2 = DiagMatrix(torch.arange(10, 13))
    >>> D1 + D2
    DiagMatrix(val=tensor([11, 13, 15]),
    shape=(3, 3))
    """
    assert (
        D1.shape == D2.shape
    ), "The shape of diagonal matrix D1 {} and D2 {} must match.".format(
        D1.shape, D2.shape
    )
    # BUGFIX: pass the shape explicitly — inferring it from the value length
    # would be wrong for rectangular diagonal matrices, e.g. shape (4, 5).
    return DiagMatrix(D1.val + D2.val, D1.shape)
def diag_sub(D1: DiagMatrix, D2: DiagMatrix) -> DiagMatrix:
    """Elementwise subtraction.

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix
        Diagonal matrix

    Returns
    -------
    DiagMatrix
        Diagonal matrix

    Examples
    --------
    >>> D1 = DiagMatrix(torch.arange(1, 4))
    >>> D2 = DiagMatrix(torch.arange(10, 13))
    >>> D1 - D2
    DiagMatrix(val=tensor([-9, -9, -9]),
    shape=(3, 3))
    """
    # BUGFIX: the assertion message previously concatenated "and" + "D2"
    # without a space ("andD2").
    assert (
        D1.shape == D2.shape
    ), "The shape of diagonal matrix D1 {} and D2 {} must match".format(
        D1.shape, D2.shape
    )
    # BUGFIX: pass the shape explicitly so rectangular shapes are preserved.
    return DiagMatrix(D1.val - D2.val, D1.shape)
def diag_mul(
    D1: Union[DiagMatrix, float], D2: Union[DiagMatrix, float]
) -> DiagMatrix:
    """Elementwise multiplication.

    Parameters
    ----------
    D1 : DiagMatrix or scalar
        Diagonal matrix or scalar value
    D2 : DiagMatrix or scalar
        Diagonal matrix or scalar value

    Returns
    -------
    DiagMatrix
        diagonal matrix

    Examples
    --------
    >>> D1 = DiagMatrix(torch.arange(1, 4))
    >>> D2 = DiagMatrix(torch.arange(10, 13))
    >>> D1 * D2
    DiagMatrix(val=tensor([10, 22, 36]),
    shape=(3, 3))
    >>> D1 * 2.5
    DiagMatrix(val=tensor([2.5000, 5.0000, 7.5000]),
    shape=(3, 3))
    >>> 2 * D1
    DiagMatrix(val=tensor([2, 4, 6]),
    shape=(3, 3))
    """
    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
        # BUGFIX: message previously read "andD2" (missing space).
        assert (
            D1.shape == D2.shape
        ), "The shape of diagonal matrix D1 {} and D2 {} must match".format(
            D1.shape, D2.shape
        )
        # BUGFIX: preserve the (possibly rectangular) shape of the operands.
        return DiagMatrix(D1.val * D2.val, D1.shape)
    # Scalar branch: D1 is always the DiagMatrix here because __rmul__ also
    # receives (self, other).
    return DiagMatrix(D1.val * D2, D1.shape)
def diag_div(D1: DiagMatrix, D2: Union[DiagMatrix, float]) -> DiagMatrix:
    """Elementwise division.

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix or scalar
        Diagonal matrix or scalar value

    Returns
    -------
    DiagMatrix
        diagonal matrix

    Examples
    --------
    >>> D1 = DiagMatrix(torch.arange(1, 4))
    >>> D2 = DiagMatrix(torch.arange(10, 13))
    >>> D1 / D2
    DiagMatrix(val=tensor([0.1000, 0.1818, 0.2500]),
    shape=(3, 3))
    >>> D1 / 2.5
    DiagMatrix(val=tensor([0.4000, 0.8000, 1.2000]),
    shape=(3, 3))
    """
    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
        # BUGFIX: message previously read "andD2" (missing space).
        assert (
            D1.shape == D2.shape
        ), "The shape of diagonal matrix D1 {} and D2 {} must match".format(
            D1.shape, D2.shape
        )
        # BUGFIX: preserve the (possibly rectangular) shape of the operands.
        return DiagMatrix(D1.val / D2.val, D1.shape)
    return DiagMatrix(D1.val / D2, D1.shape)
def diag_rdiv(D1: float, D2: DiagMatrix):
"""Elementwise division.
Parameters
----------
D1 : scalar
scalar value
D2 : DiagMatrix
Diagonal matrix
"""
raise RuntimeError(
"Elementwise subtraction between {} and {} is not "
"supported.".format(type(D1), type(D2))
)
def diag_power(D1: DiagMatrix, D2: float) -> DiagMatrix:
    """Elementwise power operation.

    Parameters
    ----------
    D1 : DiagMatrix
        Diagonal matrix
    D2 : DiagMatrix or scalar
        Diagonal matrix or scalar value.

    Returns
    -------
    DiagMatrix
        Diagonal matrix

    Examples
    --------
    >>> D1 = DiagMatrix(torch.arange(1, 4))
    >>> pow(D1, 2)
    DiagMatrix(val=tensor([1, 4, 9]),
    shape=(3, 3))
    """
    if isinstance(D1, DiagMatrix) and isinstance(D2, DiagMatrix):
        # BUGFIX: message previously read "andD2" (missing space).
        assert (
            D1.shape == D2.shape
        ), "The shape of diagonal matrix D1 {} and D2 {} must match".format(
            D1.shape, D2.shape
        )
        # BUGFIX: preserve the (possibly rectangular) shape of the operands.
        return DiagMatrix(pow(D1.val, D2.val), D1.shape)
    return DiagMatrix(pow(D1.val, D2), D1.shape)
def diag_rpower(D1: float, D2: DiagMatrix) -> DiagMatrix:
"""Elementwise power operator.
Parameters
----------
D1 : scalar
scalar value
D2 : DiagMatrix
Diagonal matrix
"""
raise RuntimeError(
"Elementwise subtraction between {} and {} is not "
"supported.".format(type(D1), type(D2))
)
# Wire the elementwise helpers into DiagMatrix's Python operators.
DiagMatrix.__add__ = diag_add
DiagMatrix.__radd__ = diag_add
# NOTE(review): reflected dunders receive (self, other), so __rsub__ here
# computes self - other rather than other - self — confirm this is intended
# (it only triggers when the left operand is not a DiagMatrix).
DiagMatrix.__sub__ = diag_sub
DiagMatrix.__rsub__ = diag_sub
DiagMatrix.__mul__ = diag_mul
DiagMatrix.__rmul__ = diag_mul
DiagMatrix.__truediv__ = diag_div
DiagMatrix.__rtruediv__ = diag_rdiv
DiagMatrix.__pow__ = diag_power
DiagMatrix.__rpow__ = diag_rpower
"""DGL elementwise operators for sparse matrix module."""
from typing import Union
import torch
from .diag_matrix import DiagMatrix
from .sparse_matrix import SparseMatrix
__all__ = ["sp_add"]
def spsp_add(A, B):
    """Invoke the C++ sparse library for sparse-sparse addition."""
    c_result = torch.ops.dgl_sparse.spsp_add(
        A.c_sparse_matrix, B.c_sparse_matrix
    )
    return SparseMatrix(c_result)
def sp_add(
    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> SparseMatrix:
    """Elementwise addition.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Sparse matrix or diagonal matrix
    B : SparseMatrix or DiagMatrix
        Sparse matrix or diagonal matrix

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case 1: Add two sparse matrices of same sparsity structure

    >>> rowA = torch.tensor([1, 0, 2])
    >>> colA = torch.tensor([0, 3, 2])
    >>> valA = torch.tensor([10, 20, 30])
    >>> A = SparseMatrix(rowA, colA, valA, shape=(3, 4))
    >>> A + A
    SparseMatrix(indices=tensor([[0, 1, 2],
            [3, 0, 2]]),
    values=tensor([40, 20, 60]),
    shape=(3, 4), nnz=3)
    >>> w = torch.arange(1, len(rowA)+1)
    >>> A + A(w)
    SparseMatrix(indices=tensor([[0, 1, 2],
            [3, 0, 2]]),
    values=tensor([21, 12, 33]),
    shape=(3, 4), nnz=3)

    Case 2: Add two sparse matrices of different sparsity structure

    >>> rowB = torch.tensor([1, 2, 0, 2, 1])
    >>> colB = torch.tensor([0, 2, 1, 3, 3])
    >>> valB = torch.tensor([1, 2, 3, 4, 5])
    >>> B = SparseMatrix(rowB, colB, valB, shape=(3 ,4))
    >>> A + B
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2, 2],
            [1, 3, 0, 3, 2, 3]]),
    values=tensor([ 3, 20, 11,  5, 32,  4]),
    shape=(3, 4), nnz=6)

    Case 3: Add sparse matrix and diagonal matrix

    >>> D = diag(torch.arange(2, 5), shape=A.shape)
    >>> A + D
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 2],
            [0, 3, 0, 1, 2]]),
    values=tensor([ 2, 20, 10,  3, 34]),
    shape=(3, 4), nnz=5)
    """
    # BUGFIX: the signature accepts a DiagMatrix for either operand, but only
    # B was converted before — DiagMatrix + SparseMatrix raised RuntimeError.
    A = A.as_sparse() if isinstance(A, DiagMatrix) else A
    B = B.as_sparse() if isinstance(B, DiagMatrix) else B
    if isinstance(A, SparseMatrix) and isinstance(B, SparseMatrix):
        return spsp_add(A, B)
    raise RuntimeError(
        "Elementwise addition between {} and {} is not "
        "supported.".format(type(A), type(B))
    )
# Route Python `+` on SparseMatrix through sp_add; addition is commutative,
# so the same function serves both __add__ and __radd__.
SparseMatrix.__add__ = sp_add
SparseMatrix.__radd__ = sp_add
"""DGL sparse matrix module."""
from typing import Optional, Tuple
import torch
class SparseMatrix:
    r"""Class for sparse matrix.

    Thin Python wrapper around a C++ sparse matrix object; all storage
    and queries are delegated to ``c_sparse_matrix``.
    """

    def __init__(self, c_sparse_matrix: torch.ScriptObject):
        # Opaque TorchScript object provided by the C++ dgl_sparse library.
        self.c_sparse_matrix = c_sparse_matrix

    @property
    def val(self) -> torch.Tensor:
        """Get the values of the nonzero elements.

        Returns
        -------
        torch.Tensor
            Values of the nonzero elements
        """
        return self.c_sparse_matrix.val()

    @property
    def shape(self) -> Tuple[int]:
        """Shape of the sparse matrix.

        Returns
        -------
        Tuple[int]
            The shape of the matrix
        """
        return tuple(self.c_sparse_matrix.shape())

    @property
    def nnz(self) -> int:
        """The number of nonzero elements of the sparse matrix.

        Returns
        -------
        int
            The number of nonzero elements of the matrix
        """
        return self.c_sparse_matrix.nnz()

    @property
    def dtype(self) -> torch.dtype:
        """Data type of the values of the sparse matrix.

        Returns
        -------
        torch.dtype
            Data type of the values of the matrix
        """
        # FIXME: find a proper way to pass dtype from C++ to Python
        return self.c_sparse_matrix.val().dtype

    @property
    def device(self) -> torch.device:
        """Device of the sparse matrix.

        Returns
        -------
        torch.device
            Device of the matrix
        """
        return self.c_sparse_matrix.device()

    def indices(
        self, fmt: str, return_shuffle=False
    ) -> Tuple[torch.Tensor, ...]:
        """Get the indices of the nonzero elements.

        Parameters
        ----------
        fmt : str
            Sparse matrix storage format. Can be COO or CSR or CSC.
            Only "COO" is currently implemented.
        return_shuffle: bool
            If true, return an extra array of the nonzero value IDs
            (not yet implemented).

        Returns
        -------
        tensor
            Indices of the nonzero elements, stacked as a (2, nnz)
            tensor of [row, col] for the COO format.
        """
        if fmt == "COO" and not return_shuffle:
            row, col, _ = self.coo()
            return torch.stack([row, col])
        else:
            raise NotImplementedError

    def __repr__(self):
        return f'SparseMatrix(indices={self.indices("COO")}, \nvalues={self.val}, \
\nshape={self.shape}, nnz={self.nnz})'

    def coo(self) -> Tuple[torch.Tensor, ...]:
        """Get the coordinate (COO) representation of the sparse matrix.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            A tuple of tensors containing row, column coordinates and values.
        """
        return self.c_sparse_matrix.coo()

    def csr(self) -> Tuple[torch.Tensor, ...]:
        """Get the compressed sparse row (CSR) representation of the
        sparse matrix.

        Returns
        -------
        Tuple[torch.Tensor, ...]
            The tuple returned by the C++ object; presumably
            (indptr, indices, values) — TODO(review): confirm against
            the C++ csr() API.
        """
        return self.c_sparse_matrix.csr()

    def csc(self) -> Tuple[torch.Tensor, ...]:
        """Get the compressed sparse column (CSC) representation of the
        sparse matrix.

        Returns
        -------
        Tuple[torch.Tensor, ...]
            The tuple returned by the C++ object; presumably
            (indptr, indices, values) — TODO(review): confirm against
            the C++ csc() API.
        """
        return self.c_sparse_matrix.csc()
def create_from_coo(
    row: torch.Tensor,
    col: torch.Tensor,
    val: Optional[torch.Tensor] = None,
    shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
    """Create a sparse matrix from row and column coordinates.

    Parameters
    ----------
    row : tensor
        The row indices of shape (nnz).
    col : tensor
        The column indices of shape (nnz).
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of shape (nnz)
        filled by 1.
    shape : tuple[int, int], optional
        If not specified, it will be inferred from :attr:`row` and :attr:`col`, i.e.,
        (row.max() + 1, col.max() + 1). Otherwise, :attr:`shape` should be no smaller
        than this.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix with row and column indices without values.

    >>> src = torch.tensor([1, 1, 2])
    >>> dst = torch.tensor([2, 4, 3])
    >>> A = create_from_coo(src, dst)
    >>> print(A)
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 4, 3]]),
    values=tensor([1., 1., 1.]),
    shape=(3, 5), nnz=3)
    >>> # Specify shape
    >>> A = create_from_coo(src, dst, shape=(5, 5))
    >>> print(A)
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 4, 3]]),
    values=tensor([1., 1., 1.]),
    shape=(5, 5), nnz=3)

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3]])
    >>> A = create_from_coo(src, dst, val)
    SparseMatrix(indices=tensor([[1, 1, 2],
            [2, 4, 3]]),
    values=tensor([[1, 1],
            [2, 2],
            [3, 3]]),
    shape=(3, 5), nnz=3)
    """
    if shape is None:
        # Infer the shape from the largest indices. Note: torch.max
        # raises on empty tensors, so an explicit shape is required
        # for a matrix with zero nonzeros.
        shape = (torch.max(row).item() + 1, torch.max(col).item() + 1)
    if val is None:
        # Default values are all ones, created on the same device as the
        # indices so GPU index tensors are not paired with CPU values.
        val = torch.ones(row.shape[0], device=row.device)
    return SparseMatrix(
        torch.ops.dgl_sparse.create_from_coo(row, col, val, shape)
    )
# FIXME: The docstring cannot print A because we cannot print
# the indices of CSR/CSC
def create_from_csr(
    indptr: torch.Tensor,
    indices: torch.Tensor,
    val: Optional[torch.Tensor] = None,
    shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
    """Create a sparse matrix from CSR indices.

    For row i of the sparse matrix

    - the column indices of the nonzero entries are stored in ``indices[indptr[i]: indptr[i+1]]``
    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``

    Parameters
    ----------
    indptr : tensor
        Pointer to the column indices of shape (N + 1), where N is the number of rows.
    indices : tensor
        The column indices of shape (nnz).
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of shape (nnz)
        filled by 1.
    shape : tuple[int, int], optional
        If not specified, it will be inferred from :attr:`indptr` and :attr:`indices`, i.e.,
        (len(indptr) - 1, indices.max() + 1). Otherwise, :attr:`shape` should be no smaller
        than this.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix without values

    [[0, 1, 0],
     [0, 0, 1],
     [1, 1, 1]]

    >>> indptr = torch.tensor([0, 1, 2, 5])
    >>> indices = torch.tensor([1, 2, 0, 1, 2])
    >>> A = create_from_csr(indptr, indices)
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([1., 1., 1., 1., 1.]),
    shape=(3, 3), nnz=5)
    >>> # Specify shape
    >>> A = create_from_csr(indptr, indices, shape=(5, 3))
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([1., 1., 1., 1., 1.]),
    shape=(5, 3), nnz=5)

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    >>> A = create_from_csr(indptr, indices, val)
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([[1, 1],
            [2, 2],
            [3, 3],
            [4, 4],
            [5, 5]]),
    shape=(3, 3), nnz=5)
    """
    if shape is None:
        # Use .item() so the shape tuple holds Python ints, consistent
        # with create_from_coo (previously a 0-dim tensor was stored).
        shape = (indptr.shape[0] - 1, torch.max(indices).item() + 1)
    if val is None:
        # Default values are all ones, created on the same device as the
        # indices so GPU index tensors are not paired with CPU values.
        val = torch.ones(indices.shape[0], device=indices.device)
    return SparseMatrix(
        torch.ops.dgl_sparse.create_from_csr(indptr, indices, val, shape)
    )
# FIXME: The docstring cannot print A because we cannot print
# the indices of CSR/CSC
def create_from_csc(
    indptr: torch.Tensor,
    indices: torch.Tensor,
    val: Optional[torch.Tensor] = None,
    shape: Optional[Tuple[int, int]] = None,
) -> SparseMatrix:
    """Create a sparse matrix from CSC indices.

    For column i of the sparse matrix

    - the row indices of the nonzero entries are stored in ``indices[indptr[i]: indptr[i+1]]``
    - the corresponding values are stored in ``val[indptr[i]: indptr[i+1]]``

    Parameters
    ----------
    indptr : tensor
        Pointer to the row indices of shape N + 1, where N is the number of columns.
    indices : tensor
        The row indices of shape nnz.
    val : tensor, optional
        The values of shape (nnz) or (nnz, D). If None, it will be a tensor of shape (nnz)
        filled by 1.
    shape : tuple[int, int], optional
        If not specified, it will be inferred from :attr:`indptr` and :attr:`indices`, i.e.,
        (indices.max() + 1, len(indptr) - 1). Otherwise, :attr:`shape` should be no smaller
        than this.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------
    Case1: Sparse matrix without values

    [[0, 1, 0],
     [0, 0, 1],
     [1, 1, 1]]

    >>> indptr = torch.tensor([0, 1, 3, 5])
    >>> indices = torch.tensor([2, 0, 2, 1, 2])
    >>> A = create_from_csc(indptr, indices)
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([1., 1., 1., 1., 1.]),
    shape=(3, 3), nnz=5)
    >>> # Specify shape
    >>> A = create_from_csc(indptr, indices, shape=(5, 3))
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([1., 1., 1., 1., 1.]),
    shape=(5, 3), nnz=5)

    Case2: Sparse matrix with scalar/vector values. Following example is with
    vector data.

    >>> val = torch.tensor([[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
    >>> A = create_from_csc(indptr, indices, val)
    >>> print(A)
    SparseMatrix(indices=tensor([[0, 1, 2, 2, 2],
            [1, 2, 0, 1, 2]]),
    values=tensor([[2, 2],
            [4, 4],
            [1, 1],
            [3, 3],
            [5, 5]]),
    shape=(3, 3), nnz=5)
    """
    if shape is None:
        # Use .item() so the shape tuple holds Python ints, consistent
        # with create_from_coo (previously a 0-dim tensor was stored).
        shape = (torch.max(indices).item() + 1, indptr.shape[0] - 1)
    if val is None:
        # Default values are all ones, created on the same device as the
        # indices so GPU index tensors are not paired with CPU values.
        val = torch.ones(indices.shape[0], device=indices.device)
    return SparseMatrix(
        torch.ops.dgl_sparse.create_from_csc(indptr, indices, val, shape)
    )
...@@ -58,9 +58,22 @@ def get_ta_lib_pattern(): ...@@ -58,9 +58,22 @@ def get_ta_lib_pattern():
return ta_lib_pattern return ta_lib_pattern
def get_dgl_sparse_pattern():
    """Return the glob pattern matching the dgl_sparse shared library
    for the current platform.

    Raises
    ------
    NotImplementedError
        If the platform is not Linux, macOS, or Windows.
    """
    platform_patterns = (
        ("linux", "libdgl_sparse_*.so"),
        ("darwin", "libdgl_sparse_*.dylib"),
        ("win", "dgl_sparse_*.dll"),
    )
    for prefix, pattern in platform_patterns:
        if sys.platform.startswith(prefix):
            return pattern
    raise NotImplementedError("Unsupported system: %s" % sys.platform)
LIBS, VERSION = get_lib_path() LIBS, VERSION = get_lib_path()
BACKENDS = ["pytorch"] BACKENDS = ["pytorch"]
TA_LIB_PATTERN = get_ta_lib_pattern() TA_LIB_PATTERN = get_ta_lib_pattern()
SPARSE_LIB_PATTERN = get_dgl_sparse_pattern()
def cleanup(): def cleanup():
...@@ -87,6 +100,17 @@ def cleanup(): ...@@ -87,6 +100,17 @@ def cleanup():
except BaseException: except BaseException:
pass pass
if backend == "pytorch":
for sparse_path in glob.glob(
os.path.join(
CURRENT_DIR, "dgl", "dgl_sparse", SPARSE_LIB_PATTERN
)
):
try:
os.remove(sparse_path)
except BaseException:
pass
def config_cython(): def config_cython():
"""Try to configure cython and return cython configuration""" """Try to configure cython and return cython configuration"""
...@@ -176,6 +200,23 @@ if wheel_include_libs: ...@@ -176,6 +200,23 @@ if wheel_include_libs:
fo.write( fo.write(
"include dgl/tensoradapter/%s/%s\n" % (backend, ta_name) "include dgl/tensoradapter/%s/%s\n" % (backend, ta_name)
) )
if backend == 'pytorch':
for sparse_path in glob.glob(
os.path.join(dir_, "dgl_sparse", SPARSE_LIB_PATTERN)
):
sparse_name = os.path.basename(sparse_path)
os.makedirs(
os.path.join(CURRENT_DIR, "dgl", "dgl_sparse"),
exist_ok=True,
)
shutil.copy(
os.path.join(dir_, "dgl_sparse", sparse_name),
os.path.join(CURRENT_DIR, "dgl", "dgl_sparse"),
)
fo.write(
"include dgl/dgl_sparse/%s\n" % sparse_name
)
setup_kwargs = {"include_package_data": True} setup_kwargs = {"include_package_data": True}
...@@ -199,6 +240,19 @@ if include_libs: ...@@ -199,6 +240,19 @@ if include_libs:
), ),
) )
) )
if backend == 'pytorch':
data_files.append(
(
"dgl/dgl_sparse",
glob.glob(
os.path.join(
os.path.dirname(os.path.relpath(path, CURRENT_DIR)),
"dgl_sparse",
SPARSE_LIB_PATTERN,
)
),
)
)
setup_kwargs = {"include_package_data": True, "data_files": data_files} setup_kwargs = {"include_package_data": True, "data_files": data_files}
setup( setup(
......
""" DGL mock_sparse tests"""
""" DGL mock_sparse2 tests"""
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment