Unverified Commit 4bc256b1 authored by Andrzej Kotłowski's avatar Andrzej Kotłowski Committed by GitHub
Browse files

[Test] Run test_spmm also on floats (#5930)

parent ac53c1fa
...@@ -11,8 +11,12 @@ from dgl.ops import gather_mm, gsddmm, gspmm, segment_reduce ...@@ -11,8 +11,12 @@ from dgl.ops import gather_mm, gsddmm, gspmm, segment_reduce
from utils import parametrize_idtype from utils import parametrize_idtype
from utils.graph_cases import get_cases from utils.graph_cases import get_cases
random.seed(42) # Set seeds to make tests fully reproducible.
np.random.seed(42) SEED = 12345 # random.randint(1, 99999)
random.seed(SEED)
np.random.seed(SEED)
dgl.seed(SEED)
F.seed(SEED)
udf_msg = { udf_msg = {
"add": lambda edges: {"m": edges.src["x"] + edges.data["w"]}, "add": lambda edges: {"m": edges.src["x"] + edges.data["w"]},
...@@ -111,13 +115,18 @@ sddmm_shapes = [ ...@@ -111,13 +115,18 @@ sddmm_shapes = [
) )
@pytest.mark.parametrize("reducer", ["sum", "min", "max"]) @pytest.mark.parametrize("reducer", ["sum", "min", "max"])
@parametrize_idtype @parametrize_idtype
def test_spmm(idtype, g, shp, msg, reducer): @pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_spmm(idtype, dtype, g, shp, msg, reducer):
g = g.astype(idtype).to(F.ctx()) g = g.astype(idtype).to(F.ctx())
print(g) print(g)
print(g.idtype) print(g.idtype)
hu = F.tensor(np.random.rand(*((g.number_of_src_nodes(),) + shp[0])) + 1) hu = F.tensor(
he = F.tensor(np.random.rand(*((g.num_edges(),) + shp[1])) + 1) np.random.rand(*((g.number_of_src_nodes(),) + shp[0])).astype(dtype) + 1
)
he = F.tensor(
np.random.rand(*((g.num_edges(),) + shp[1])).astype(dtype) + 1
)
print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he))) print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))
g.srcdata["x"] = F.attach_grad(F.clone(hu)) g.srcdata["x"] = F.attach_grad(F.clone(hu))
......
...@@ -25,6 +25,10 @@ from dgl.distributed import ( ...@@ -25,6 +25,10 @@ from dgl.distributed import (
from dgl.distributed.optim import SparseAdagrad, SparseAdam from dgl.distributed.optim import SparseAdagrad, SparseAdam
from scipy import sparse as spsp from scipy import sparse as spsp
# Set seeds to make tests fully reproducible.
SEED = 12345 # random.randint(1, 99999)
F.seed(SEED)
def create_random_graph(n): def create_random_graph(n):
arr = ( arr = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment