"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f5bba284f8cb471fe1db06c2fd9bdc228038e425"
Unverified Commit 27f6561a authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Graphbolt] Improve basic feature store test. (#6209)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ae49309c
...@@ -4,23 +4,76 @@ import torch ...@@ -4,23 +4,76 @@ import torch
from dgl import graphbolt as gb from dgl import graphbolt as gb
def test_basic_feature_store(): def test_basic_feature_store_homo():
a = torch.tensor([3, 2, 1]) a = torch.tensor([3, 2, 1])
b = torch.tensor([2, 5, 3]) b = torch.tensor([2, 5, 3])
features = {}
features[("node", None, "a")] = gb.TorchBasedFeature(a)
features[("node", None, "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
# Test read the entire feature.
assert torch.equal(
feature_store.read("node", None, "a"), torch.tensor([3, 2, 1])
)
assert torch.equal(
feature_store.read("node", None, "b"), torch.tensor([2, 5, 3])
)
# Test read with ids.
assert torch.equal(
feature_store.read("node", None, "a", torch.tensor([0, 1])),
torch.tensor([3, 2]),
)
def test_basic_feature_store_hetero():
a = torch.tensor([3, 2, 1])
b = torch.tensor([2, 5, 3])
c = torch.tensor([6, 8, 9])
features = {} features = {}
features[("node", "paper", "a")] = gb.TorchBasedFeature(a) features[("node", "paper", "a")] = gb.TorchBasedFeature(a)
features[("edge", "paper:cites:paper", "b")] = gb.TorchBasedFeature(b) features[("node", "author", "b")] = gb.TorchBasedFeature(b)
features[("edge", "paper:cites:paper", "c")] = gb.TorchBasedFeature(c)
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
# Test read the entire feature.
assert torch.equal( assert torch.equal(
feature_store.read("node", "paper", "a"), torch.tensor([3, 2, 1]) feature_store.read("node", "paper", "a"), torch.tensor([3, 2, 1])
) )
assert torch.equal( assert torch.equal(
feature_store.read("edge", "paper:cites:paper", "b"), feature_store.read("node", "author", "b"), torch.tensor([2, 5, 3])
torch.tensor([2, 5, 3]), )
assert torch.equal(
feature_store.read("edge", "paper:cites:paper", "c"),
torch.tensor([6, 8, 9]),
) )
# Test read with ids.
assert torch.equal( assert torch.equal(
feature_store.read("node", "paper", "a", torch.tensor([0, 1])), feature_store.read("node", "paper", "a", torch.tensor([0, 1])),
torch.tensor([3, 2]), torch.tensor([3, 2]),
) )
def test_basic_feature_store_errors():
a = torch.tensor([3, 2, 1])
b = torch.tensor([2, 5, 3])
features = {}
features[("node", "paper", "a")] = gb.TorchBasedFeature(a)
features[("node", "author", "b")] = gb.TorchBasedFeature(b)
feature_store = gb.BasicFeatureStore(features)
# Test error when key does not exist.
with pytest.raises(KeyError):
feature_store.read("node", "paper", "b")
# Test error when at least one id is out of bound.
with pytest.raises(IndexError):
feature_store.read("node", "paper", "a", torch.tensor([0, 3]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment