Unverified Commit 44a9faad authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] add converter of etype and update in dataset (#6230)

parent b642c5d8
...@@ -5,7 +5,7 @@ import sys ...@@ -5,7 +5,7 @@ import sys
import torch import torch
from .._ffi import libinfo from .._ffi import libinfo
from .copy_to import * from .base import *
from .data_block import * from .data_block import *
from .data_format import * from .data_format import *
from .dataloader import * from .dataloader import *
......
"""Graph Bolt CUDA-related Data Pipelines""" """Base types and utilities for Graph Bolt."""
from torchdata.datapipes.iter import IterDataPipe from torchdata.datapipes.iter import IterDataPipe
from ..utils import recursive_apply from ..utils import recursive_apply
__all__ = [
"CANONICAL_ETYPE_DELIMITER",
"etype_str_to_tuple",
"etype_tuple_to_str",
"CopyTo",
]
CANONICAL_ETYPE_DELIMITER = ":"
def etype_tuple_to_str(c_etype):
"""Convert canonical etype from tuple to string.
Examples
--------
>>> c_etype = ("user", "like", "item")
>>> c_etype_str = _etype_tuple_to_str(c_etype)
>>> print(c_etype_str)
"user:like:item"
"""
assert isinstance(c_etype, tuple) and len(c_etype) == 3, (
"Passed-in canonical etype should be in format of (str, str, str). "
f"But got {c_etype}."
)
return CANONICAL_ETYPE_DELIMITER.join(c_etype)
def etype_str_to_tuple(c_etype):
"""Convert canonical etype from tuple to string.
Examples
--------
>>> c_etype_str = "user:like:item"
>>> c_etype = _etype_str_to_tuple(c_etype_str)
>>> print(c_etype)
("user", "like", "item")
"""
ret = tuple(c_etype.split(CANONICAL_ETYPE_DELIMITER))
assert len(ret) == 3, (
"Passed-in canonical etype should be in format of 'str:str:str'. "
f"But got {c_etype}."
)
return ret
def _to(x, device): def _to(x, device):
return x.to(device) if hasattr(x, "to") else x return x.to(device) if hasattr(x, "to") else x
......
...@@ -11,6 +11,7 @@ import yaml ...@@ -11,6 +11,7 @@ import yaml
import dgl import dgl
from ..base import etype_str_to_tuple
from ..dataset import Dataset, Task from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict from ..itemset import ItemSet, ItemSetDict
from ..utils import read_data, save_data from ..utils import read_data, save_data
...@@ -126,7 +127,7 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str: ...@@ -126,7 +127,7 @@ def preprocess_ondisk_dataset(dataset_dir: str) -> str:
) )
src = torch.tensor(edge_data["src"]) src = torch.tensor(edge_data["src"])
dst = torch.tensor(edge_data["dst"]) dst = torch.tensor(edge_data["dst"])
data_dict[tuple(edge_info["type"].split(":"))] = (src, dst) data_dict[etype_str_to_tuple(edge_info["type"])] = (src, dst)
# Construct the heterograph. # Construct the heterograph.
g = dgl.heterograph(data_dict, num_nodes_dict) g = dgl.heterograph(data_dict, num_nodes_dict)
......
import re
import unittest
import backend as F
import dgl.graphbolt as gb
import pytest
import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
dp = gb.MinibatchSampler(torch.randn(20), 4)
dp = gb.CopyTo(dp, "cuda")
for data in dp:
assert data.device.type == "cuda"
def test_etype_tuple_to_str():
"""Convert etype from tuple to string."""
# Test for expected input.
c_etype = ("user", "like", "item")
c_etype_str = gb.etype_tuple_to_str(c_etype)
assert c_etype_str == "user:like:item"
# Test for unexpected input: not a tuple.
c_etype = "user:like:item"
with pytest.raises(
AssertionError,
match=re.escape(
"Passed-in canonical etype should be in format of (str, str, str). "
"But got user:like:item."
),
):
_ = gb.etype_tuple_to_str(c_etype)
# Test for unexpected input: tuple with wrong length.
c_etype = ("user", "like")
with pytest.raises(
AssertionError,
match=re.escape(
"Passed-in canonical etype should be in format of (str, str, str). "
"But got ('user', 'like')."
),
):
_ = gb.etype_tuple_to_str(c_etype)
def test_etype_str_to_tuple():
"""Convert etype from string to tuple."""
# Test for expected input.
c_etype_str = "user:like:item"
c_etype = gb.etype_str_to_tuple(c_etype_str)
assert c_etype == ("user", "like", "item")
# Test for unexpected input: string with wrong format.
c_etype_str = "user:like"
with pytest.raises(
AssertionError,
match=re.escape(
"Passed-in canonical etype should be in format of 'str:str:str'. "
"But got user:like."
),
):
_ = gb.etype_str_to_tuple(c_etype_str)
import unittest
import backend as F
import dgl.graphbolt
import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
dp = dgl.graphbolt.MinibatchSampler(torch.randn(20), 4)
dp = dgl.graphbolt.CopyTo(dp, "cuda")
for data in dp:
assert data.device.type == "cuda"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment