Unverified Commit 2593c925 authored by Andrei Ivanov's avatar Andrei Ivanov Committed by GitHub
Browse files

Improving sparse tests. (#6168)


Co-authored-by: default avatarHongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 5a417414
import sys
import warnings
import backend as F
import pytest
......@@ -19,6 +19,12 @@ from .utils import (
)
def _torch_sparse_mm(torch_A1, torch_A2):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return torch.sparse.mm(torch_A1, torch_A2)
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(2, 7), (5, 2)])
@pytest.mark.parametrize("nnz", [1, 10])
......@@ -98,7 +104,7 @@ def test_spspmm(create_func1, create_func2, shape_n_m, shape_k, nnz1, nnz2):
torch_A1 = sparse_matrix_to_torch_sparse(A1)
torch_A2 = sparse_matrix_to_torch_sparse(A2)
torch_A3 = torch.sparse.mm(torch_A1, torch_A2)
torch_A3 = _torch_sparse_mm(torch_A1, torch_A2)
torch_A3_grad = sparse_matrix_to_torch_sparse(A3, grad)
torch_A3.backward(torch_A3_grad)
......@@ -161,7 +167,7 @@ def test_sparse_diag_mm(create_func, sparse_shape, nnz):
torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D)
torch_B = torch.sparse.mm(torch_A, torch_D)
torch_B = _torch_sparse_mm(torch_A, torch_D)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad)
......@@ -194,7 +200,7 @@ def test_diag_sparse_mm(create_func, sparse_shape, nnz):
torch_A = sparse_matrix_to_torch_sparse(A)
torch_D = sparse_matrix_to_torch_sparse(D)
torch_B = torch.sparse.mm(torch_D, torch_A)
torch_B = _torch_sparse_mm(torch_D, torch_A)
torch_B_grad = sparse_matrix_to_torch_sparse(B, grad)
torch_B.backward(torch_B_grad)
......
import sys
import unittest
import warnings
import backend as F
import pytest
......@@ -19,6 +19,12 @@ from dgl.sparse import (
)
def _torch_sparse_csr_tensor(indptr, indices, val, torch_sparse_shape):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
return torch.sparse_csr_tensor(indptr, indices, val, torch_sparse_shape)
@pytest.mark.parametrize("dense_dim", [None, 4])
@pytest.mark.parametrize("row", [(0, 0, 1, 2), (0, 1, 2, 4)])
@pytest.mark.parametrize("col", [(0, 1, 2, 2), (1, 3, 3, 4)])
......@@ -580,7 +586,7 @@ def test_torch_sparse_csr_conversion(indptr, indices, shape):
torch_sparse_shape = shape
val_shape = (indices.shape[0],)
val = torch.randn(val_shape).to(dev)
torch_sparse_csr = torch.sparse_csr_tensor(
torch_sparse_csr = _torch_sparse_csr_tensor(
indptr, indices, val, torch_sparse_shape
)
spmat = from_torch_sparse(torch_sparse_csr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment