Unverified commit c8bc5588 authored by Mufei Li, committed by GitHub

[Sparse] Remove SparseMatrix.indices and add device/dtype conversions (#5108)



* Update

* Update

* Update

* lint

* lint

* Update

* CI

* Update diag_matrix.py

* Update sparse_matrix.py

* Update sparse_matrix.py
Co-authored-by: Ubuntu <ubuntu@ip-172-31-36-188.ap-northeast-1.compute.internal>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent c55ab2d1
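
In short: this commit removes the format-specific SparseMatrix.indices accessor and gives both SparseMatrix and DiagMatrix a torch.Tensor-style conversion family (to, cuda, cpu, float, double, int, long). A minimal usage sketch in the doctest style used throughout the diff, assuming the from_coo and diag constructors that appear below:

>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.double()                              # values cast to torch.float64
>>> A.to(device="cuda:0", dtype=torch.int)  # device and dtype in one call (requires a CUDA build)
>>> diag(torch.ones(2)).long()              # DiagMatrix gets the same methods
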
@@ -97,6 +97,13 @@ Attributes and methods
SparseMatrix.dense
SparseMatrix.t
SparseMatrix.T
SparseMatrix.to
SparseMatrix.cuda
SparseMatrix.cpu
SparseMatrix.float
SparseMatrix.double
SparseMatrix.int
SparseMatrix.long
SparseMatrix.transpose
SparseMatrix.reduce
SparseMatrix.sum
@@ -139,6 +146,13 @@ Attributes and methods
DiagMatrix.t
DiagMatrix.T
DiagMatrix.transpose
DiagMatrix.to
DiagMatrix.cuda
DiagMatrix.cpu
DiagMatrix.float
DiagMatrix.double
DiagMatrix.int
DiagMatrix.long
DiagMatrix.neg
DiagMatrix.inv
DiagMatrix.__matmul__
"""DGL diagonal matrix module."""
# pylint: disable= invalid-name
from typing import Optional, Tuple
import torch
@@ -166,6 +167,164 @@ class DiagMatrix:
"""
return DiagMatrix(self.val, self.shape[::-1])
def to(self, device=None, dtype=None):
"""Perform matrix dtype and/or device conversion. If the target device
and dtype are already in use, the original matrix will be returned.
Parameters
----------
device : torch.device, optional
The target device of the matrix if given, otherwise the current
device will be used
dtype : torch.dtype, optional
The target data type of the matrix values if given, otherwise the
current data type will be used
Returns
-------
DiagMatrix
The result matrix
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.to(device='cuda:0', dtype=torch.int32)
DiagMatrix(values=tensor([1, 1], device='cuda:0', dtype=torch.int32),
size=(2, 2))
"""
if device is None:
device = self.device
if dtype is None:
dtype = self.dtype
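# No conversion requested: return the original matrix instead of a copy.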
if device == self.device and dtype == self.dtype:
return self
return diag(self.val.to(device=device, dtype=dtype), self.shape)
def cuda(self):
"""Move the matrix to GPU. If the matrix is already on GPU, the
original matrix will be returned. If multiple GPU devices exist,
'cuda:0' will be selected.
Returns
-------
DiagMatrix
The matrix on GPU
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.cuda()
DiagMatrix(values=tensor([1., 1.], device='cuda:0'),
size=(2, 2))
"""
return self.to(device="cuda")
def cpu(self):
"""Move the matrix to CPU. If the matrix is already on CPU, the
original matrix will be returned.
Returns
-------
DiagMatrix
The matrix on CPU
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.cpu()
DiagMatrix(values=tensor([1., 1.]),
size=(2, 2))
"""
return self.to(device="cpu")
def float(self):
"""Convert the matrix values to float data type. If the matrix already
uses float data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with float values
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.float()
DiagMatrix(values=tensor([1., 1.]),
size=(2, 2))
"""
return self.to(dtype=torch.float)
def double(self):
"""Convert the matrix values to double data type. If the matrix already
uses double data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with double values
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.double()
DiagMatrix(values=tensor([1., 1.], dtype=torch.float64),
size=(2, 2))
"""
return self.to(dtype=torch.double)
def int(self):
"""Convert the matrix values to int data type. If the matrix already
uses int data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with int values
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.int()
DiagMatrix(values=tensor([1, 1], dtype=torch.int32),
size=(2, 2))
"""
return self.to(dtype=torch.int)
def long(self):
"""Convert the matrix values to long data type. If the matrix already
uses long data type, the original matrix will be returned.
Returns
-------
DiagMatrix
The matrix with long values
Example
--------
>>> val = torch.ones(2)
>>> mat = diag(val)
>>> mat.long()
DiagMatrix(values=tensor([1, 1]),
size=(2, 2))
"""
return self.to(dtype=torch.long)
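# Note (sketch, not part of the diff): each shorthand above returns a
# DiagMatrix, so the conversions compose, e.g.
#     diag(torch.ones(2)).double().cpu()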
def diag(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
@@ -287,7 +446,8 @@ def identity(
def _diag_matrix_str(spmat: DiagMatrix) -> str:
"""Internal function for converting a diagonal matrix to string representation."""
"""Internal function for converting a diagonal matrix to string
representation."""
values_str = str(spmat.val)
meta_str = f"size={spmat.shape}"
if spmat.val.dim() > 1:
"""DGL sparse matrix module."""
# pylint: disable= invalid-name
from typing import Optional, Tuple
import torch
@@ -88,29 +89,6 @@ class SparseMatrix:
"""
return self.coo()[1]
def indices(
self, fmt: str, return_shuffle=False
) -> Tuple[torch.Tensor, ...]:
"""Get the indices of the nonzero elements.
Parameters
----------
fmt : str
Sparse matrix storage format. Can be 'COO', 'CSR' or 'CSC'.
return_shuffle : bool
If True, return an extra array of the nonzero value IDs
Returns
-------
tensor
Indices of the nonzero elements
"""
if fmt == "COO" and not return_shuffle:
row, col = self.coo()
return torch.stack([row, col])
else:
raise NotImplementedError
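# Migration note (not part of the diff): callers of the removed
# indices("COO") accessor can rebuild the same (2, nnz) index tensor from
# coo(), exactly as this commit does in _sparse_matrix_str and the tests:
#     row, col = A.coo()
#     idx = torch.stack([row, col])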
def __repr__(self):
return _sparse_matrix_str(self)
@@ -194,6 +172,195 @@ class SparseMatrix:
"""
return SparseMatrix(self.c_sparse_matrix.transpose())
def to(self, device=None, dtype=None):
"""Perform matrix dtype and/or device conversion. If the target device
and dtype are already in use, the original matrix will be returned.
Parameters
----------
device : torch.device, optional
The target device of the matrix if provided, otherwise the current
device will be used
dtype : torch.dtype, optional
The target data type of the matrix values if given, otherwise the
current data type will be used
Returns
-------
SparseMatrix
The result matrix
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.to(device='cuda:0', dtype=torch.int32)
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]], device='cuda:0'),
values=tensor([1, 1, 1], device='cuda:0',
dtype=torch.int32),
size=(3, 4), nnz=3)
"""
if device is None:
device = self.device
if dtype is None:
dtype = self.dtype
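# Target device and dtype already in use: return the matrix unchanged.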
if device == self.device and dtype == self.dtype:
return self
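# Same device: only cast the values and reuse the sparsity structure.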
elif device == self.device:
return val_like(self, self.val.to(dtype=dtype))
else:
# TODO(#5119): Find a better moving strategy instead of always
# converting to COO format.
row, col = self.coo()
row = row.to(device=device)
col = col.to(device=device)
val = self.val.to(device=device, dtype=dtype)
return from_coo(row, col, val, self.shape)
def cuda(self):
"""Move the matrix to GPU. If the matrix is already on GPU, the
original matrix will be returned. If multiple GPU devices exist,
'cuda:0' will be selected.
Returns
-------
SparseMatrix
The matrix on GPU
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.cuda()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]], device='cuda:0'),
values=tensor([1., 1., 1.], device='cuda:0'),
size=(3, 4), nnz=3)
"""
return self.to(device="cuda")
def cpu(self):
"""Move the matrix to CPU. If the matrix is already on CPU, the
original matrix will be returned.
Returns
-------
SparseMatrix
The matrix on CPU
Example
--------
>>> row = torch.tensor([1, 1, 2]).to('cuda')
>>> col = torch.tensor([1, 2, 0]).to('cuda')
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.cpu()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]]),
values=tensor([1., 1., 1.]),
size=(3, 4), nnz=3)
"""
return self.to(device="cpu")
def float(self):
"""Convert the matrix values to float data type. If the matrix already
uses float data type, the original matrix will be returned.
Returns
-------
SparseMatrix
The matrix with float values
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> val = torch.ones(len(row)).long()
>>> A = from_coo(row, col, val, shape=(3, 4))
>>> A.float()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]]),
values=tensor([1., 1., 1.]),
size=(3, 4), nnz=3)
"""
return self.to(dtype=torch.float)
def double(self):
"""Convert the matrix values to double data type. If the matrix already
uses double data type, the original matrix will be returned.
Returns
-------
SparseMatrix
The matrix with double values
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.double()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]]),
values=tensor([1., 1., 1.], dtype=torch.float64),
size=(3, 4), nnz=3)
"""
return self.to(dtype=torch.double)
def int(self):
"""Convert the matrix values to int data type. If the matrix already
uses int data type, the original matrix will be returned.
Returns
-------
SparseMatrix
The matrix with int values
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.int()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]]),
values=tensor([1, 1, 1], dtype=torch.int32),
size=(3, 4), nnz=3)
"""
return self.to(dtype=torch.int)
def long(self):
"""Convert the matrix values to long data type. If the matrix already
uses long data type, the original matrix will be returned.
Returns
-------
SparseMatrix
The matrix with long values
Example
--------
>>> row = torch.tensor([1, 1, 2])
>>> col = torch.tensor([1, 2, 0])
>>> A = from_coo(row, col, shape=(3, 4))
>>> A.long()
SparseMatrix(indices=tensor([[1, 1, 2],
[1, 2, 0]]),
values=tensor([1, 1, 1]),
size=(3, 4), nnz=3)
"""
return self.to(dtype=torch.long)
def coalesce(self):
"""Return a coalesced sparse matrix.
@@ -521,8 +688,9 @@ def val_like(mat: SparseMatrix, val: torch.Tensor) -> SparseMatrix:
def _sparse_matrix_str(spmat: SparseMatrix) -> str:
"""Internal function for converting a sparse matrix to string representation."""
indices_str = str(spmat.indices("COO"))
"""Internal function for converting a sparse matrix to string
representation."""
indices_str = str(torch.stack(spmat.coo()))
values_str = str(spmat.val)
meta_str = f"size={spmat.shape}, nnz={spmat.nnz}"
if spmat.val.dim() > 1:
import sys
import unittest
import backend as F
import pytest
@@ -85,3 +86,47 @@ def test_print():
val = torch.randn(3, 2).to(ctx)
A = diag(val)
print(A)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Device conversions don't need to be tested on CPU.",
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_to_device(device):
val = torch.randn(3)
mat_shape = (3, 4)
mat = diag(val, mat_shape)
target_val = mat.val.to(device)
mat2 = mat.to(device=device)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
mat2 = getattr(mat, device)()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
@pytest.mark.parametrize(
"dtype", [torch.float, torch.double, torch.int, torch.long]
)
def test_to_dtype(dtype):
val = torch.randn(3)
mat_shape = (3, 4)
mat = diag(val, mat_shape)
target_val = mat.val.to(dtype=dtype)
mat2 = mat.to(dtype=dtype)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
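# Also exercise the shorthand conversion methods (mat.float(), ...) by name.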
func_name = {
torch.float: "float",
torch.double: "double",
torch.int: "int",
torch.long: "long",
}
mat2 = getattr(mat, func_name[dtype])()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
@@ -13,7 +13,7 @@ if not sys.platform.startswith("linux"):
def all_close_sparse(A, B):
assert torch.allclose(A.indices(), B.indices())
assert torch.allclose(torch.stack(A.coo()), torch.stack(B.coo()))
assert torch.allclose(A.values(), B.values())
assert A.shape == B.shape
import sys
import unittest
import backend as F
import pytest
@@ -431,3 +432,55 @@ def test_print():
val = torch.randn(3, 2).to(ctx)
A = from_coo(row, col, val)
print(A)
@unittest.skipIf(
F._default_context_str == "cpu",
reason="Device conversions don't need to be tested on CPU.",
)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_to_device(device):
row = torch.tensor([1, 1, 2])
col = torch.tensor([1, 2, 0])
mat = from_coo(row, col, shape=(3, 4))
target_row = row.to(device)
target_col = col.to(device)
target_val = mat.val.to(device)
mat2 = mat.to(device=device)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.row, target_row)
assert torch.allclose(mat2.col, target_col)
assert torch.allclose(mat2.val, target_val)
mat2 = getattr(mat, device)()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.row, target_row)
assert torch.allclose(mat2.col, target_col)
assert torch.allclose(mat2.val, target_val)
@pytest.mark.parametrize(
"dtype", [torch.float, torch.double, torch.int, torch.long]
)
def test_to_dtype(dtype):
row = torch.tensor([1, 1, 2])
col = torch.tensor([1, 2, 0])
mat = from_coo(row, col, shape=(3, 4))
target_val = mat.val.to(dtype=dtype)
mat2 = mat.to(dtype=dtype)
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)
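# Repeat the checks via the named shorthand methods (float/double/int/long).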
func_name = {
torch.float: "float",
torch.double: "double",
torch.int: "int",
torch.long: "long",
}
mat2 = getattr(mat, func_name[dtype])()
assert mat2.shape == mat.shape
assert torch.allclose(mat2.val, target_val)