Unverified commit 346197c4 authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Add `gb.index_select` and fix example inferencing. (#7051)

parent f4989867
@@ -187,6 +187,7 @@ Utilities
     etype_tuple_to_str
     isin
     seed
+    index_select
     expand_indptr
     add_reverse_edges
     exclude_seed_edges
...
@@ -79,14 +79,10 @@ class SAGE(nn.Module):
             hidden_x = F.relu(hidden_x)
         return hidden_x

-    def inference(self, graph, features, dataloader, device):
+    def inference(self, graph, features, dataloader, storage_device):
         """Conduct layer-wise inference to get all the node embeddings."""
-        feature = features.read("node", None, "feat")
-
-        buffer_device = torch.device("cpu")
-        # Enable pin_memory for faster CPU to GPU data transfer if the
-        # model is running on a GPU.
-        pin_memory = buffer_device != device
+        pin_memory = storage_device == "pinned"
+        buffer_device = torch.device("cpu" if pin_memory else storage_device)

         print("Start node embedding inference.")
         for layer_idx, layer in enumerate(self.layers):
@@ -99,17 +95,17 @@ class SAGE(nn.Module):
                 device=buffer_device,
                 pin_memory=pin_memory,
             )
-            feature = feature.to(device)
-            for step, data in tqdm.tqdm(enumerate(dataloader)):
-                x = feature[data.input_nodes]
-                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
+            for data in tqdm.tqdm(dataloader):
+                # len(blocks) = 1
+                hidden_x = layer(data.blocks[0], data.node_features["feat"])
                 if not is_last_layer:
                     hidden_x = F.relu(hidden_x)
                 # By design, our seed nodes are contiguous.
                 y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
                     buffer_device, non_blocking=True
                 )
-            feature = y
+            if not is_last_layer:
+                features.update("node", None, "feat", y)

         return y
@@ -185,7 +181,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # [Role]:
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
-    datapipe = datapipe.sample_neighbor(graph, args.fanout)
+    datapipe = datapipe.sample_neighbor(
+        graph, args.fanout if is_train else [-1]
+    )

     ############################################################################
     # [Input]:
@@ -213,12 +211,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     # A FeatureFetcher object to fetch node features.
     # [Role]:
     # Initialize a feature fetcher for fetching features of the sampled
-    # subgraphs. This step is skipped in evaluation/inference because features
-    # are updated as a whole during it, thus storing features in minibatch is
-    # unnecessary.
+    # subgraphs.
     ############################################################################
-    if is_train:
-        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
+    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

     ############################################################################
     # [Input]:
@@ -286,15 +281,12 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
     model.eval()
     evaluator = Evaluator(name="ogbl-citation2")

-    # Since we need to use all neghborhoods for evaluation, we set the fanout
-    # to -1.
-    args.fanout = [-1]
     dataloader = create_dataloader(
         args, graph, features, all_nodes_set, is_train=False
     )

     # Compute node embeddings for the entire graph.
-    node_emb = model.inference(graph, features, dataloader, args.device)
+    node_emb = model.inference(graph, features, dataloader, args.storage_device)
     results = []

     # Loop over both validation and test sets.
...
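The evaluation path above no longer mutates args.fanout in place; create_dataloader now picks the fanout itself based on is_train. The convention is that a fanout of -1 keeps every neighbor rather than sampling a fixed number per layer. A minimal sketch of that convention (the [10, 10] training fanout is illustrative, not taken from this diff):

# Sketch of the fanout convention used by sample_neighbor above.
train_fanout = [10, 10]  # illustrative: two layers, 10 sampled neighbors each
eval_fanout = [-1]       # -1 disables sampling: the full neighborhood is kept
datapipe = datapipe.sample_neighbor(
    graph, train_fanout if is_train else eval_fanout
)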
@@ -131,11 +131,9 @@ def create_dataloader(
     # A FeatureFetcher object to fetch node features.
     # [Role]:
     # Initialize a feature fetcher for fetching features of the sampled
-    # subgraphs. This step is skipped in inference because features are updated
-    # as a whole during it, thus storing features in minibatch is unnecessary.
+    # subgraphs.
     ############################################################################
-    if job != "infer":
-        datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
+    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

     ############################################################################
     # [Step-5]:
@@ -194,14 +192,10 @@ class SAGE(nn.Module):
                 hidden_x = self.dropout(hidden_x)
         return hidden_x

-    def inference(self, graph, features, dataloader, device):
+    def inference(self, graph, features, dataloader, storage_device):
         """Conduct layer-wise inference to get all the node embeddings."""
-        feature = features.read("node", None, "feat")
-
-        buffer_device = torch.device("cpu")
-        # Enable pin_memory for faster CPU to GPU data transfer if the
-        # model is running on a GPU.
-        pin_memory = buffer_device != device
+        pin_memory = storage_device == "pinned"
+        buffer_device = torch.device("cpu" if pin_memory else storage_device)

         for layer_idx, layer in enumerate(self.layers):
             is_last_layer = layer_idx == len(self.layers) - 1
@@ -213,11 +207,9 @@ class SAGE(nn.Module):
                 device=buffer_device,
                 pin_memory=pin_memory,
             )
-            feature = feature.to(device)
-
-            for step, data in tqdm(enumerate(dataloader)):
-                x = feature[data.input_nodes]
-                hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
+            for data in tqdm(dataloader):
+                # len(blocks) = 1
+                hidden_x = layer(data.blocks[0], data.node_features["feat"])
                 if not is_last_layer:
                     hidden_x = F.relu(hidden_x)
                     hidden_x = self.dropout(hidden_x)
@@ -225,7 +217,8 @@ class SAGE(nn.Module):
                 y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
                     buffer_device
                 )
-            feature = y
+            if not is_last_layer:
+                features.update("node", None, "feat", y)

         return y
@@ -245,7 +238,7 @@ def layerwise_infer(
         num_workers=args.num_workers,
         job="infer",
     )
-    pred = model.inference(graph, features, dataloader, args.device)
+    pred = model.inference(graph, features, dataloader, args.storage_device)
     pred = pred[test_set._items[0]]
     label = test_set._items[1].to(pred.device)
...
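Both example files now share the same layer-wise inference pattern: each layer's pass reads its inputs through the dataloader's FeatureFetcher, writes the layer output into a preallocated buffer on the storage device, and publishes that buffer back to the feature store so the next layer reads it as "feat". A condensed sketch of the loop, where num_nodes and out_size are illustrative stand-ins for the shapes the examples actually compute:

# Condensed sketch of the layer-wise inference refactor; num_nodes and
# out_size stand in for the examples' actual shapes.
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)
for layer_idx, layer in enumerate(self.layers):
    is_last_layer = layer_idx == len(self.layers) - 1
    # One output row per node, allocated once per layer on the buffer device.
    y = torch.empty(
        num_nodes, out_size, device=buffer_device, pin_memory=pin_memory
    )
    for data in dataloader:
        # The FeatureFetcher already attached this minibatch's input features.
        hidden_x = layer(data.blocks[0], data.node_features["feat"])
        # Seed nodes are contiguous by design, so a slice assignment suffices.
        y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
            buffer_device
        )
    if not is_last_layer:
        # The next layer reads this layer's output as the "feat" feature.
        features.update("node", None, "feat", y)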
@@ -15,6 +15,7 @@ __all__ = [
     "etype_tuple_to_str",
     "CopyTo",
     "isin",
+    "index_select",
     "expand_indptr",
     "CSCFormatBase",
     "seed",
@@ -102,6 +103,33 @@ def expand_indptr(indptr, dtype=None, node_ids=None, output_size=None):
     )


+def index_select(tensor, index):
+    """Returns a new tensor which indexes the input tensor along dimension 0
+    using the entries in index.
+
+    The returned tensor has the same number of dimensions as the original
+    tensor (tensor). The first dimension has the same size as the length of
+    index; other dimensions have the same size as in the original tensor.
+
+    When tensor is a pinned tensor and index.is_cuda is True, the operation
+    runs on the CUDA device and the returned tensor will also be on CUDA.
+
+    Parameters
+    ----------
+    tensor : torch.Tensor
+        The input tensor.
+    index : torch.Tensor
+        The 1-D tensor containing the indices to index.
+
+    Returns
+    -------
+    torch.Tensor
+        The indexed input tensor, equivalent to tensor[index].
+    """
+    assert index.dim() == 1, "Index should be 1D tensor."
+    return torch.ops.graphbolt.index_select(tensor, index)
+
+
 def etype_tuple_to_str(c_etype):
     """Convert canonical etype from tuple to string.
...
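The Python wrapper above is what exposes the pinned-memory fast path: when tensor lives in pinned host memory and index is on a CUDA device, the gather runs on the GPU and returns a CUDA tensor without copying the whole matrix over. A small usage sketch, assuming a CUDA-enabled build:

import torch
import dgl.graphbolt as gb

# Feature matrix kept in pinned (page-locked) host memory.
features = torch.randn(1000, 128).pin_memory()
# Indices already on the GPU, e.g. a sampled minibatch's input nodes.
index = torch.tensor([1, 7, 42], device="cuda")

# The gather runs on the CUDA device and the result stays there, so the
# full matrix never needs an explicit .to("cuda") copy.
gpu_rows = gb.index_select(features, index)
assert gpu_rows.is_cuda and gpu_rows.shape == (3, 128)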
@@ -7,6 +7,7 @@ from typing import Dict, List
 import numpy as np
 import torch

+from ..base import index_select
 from ..feature_store import Feature
 from .basic_feature_store import BasicFeatureStore
 from .ondisk_metadata import OnDiskFeatureData
@@ -117,7 +118,7 @@ class TorchBasedFeature(Feature):
             if self._tensor.is_pinned():
                 return self._tensor.cuda()
             return self._tensor
-        return torch.ops.graphbolt.index_select(self._tensor, ids)
+        return index_select(self._tensor, ids)

     def size(self):
         """Get the size of the feature.
@@ -144,11 +145,6 @@ class TorchBasedFeature(Feature):
             updated.
         """
         if ids is None:
-            assert self.size() == value.size()[1:], (
-                f"ids is None, so the entire feature will be updated. "
-                f"But the size of the feature is {self.size()}, "
-                f"while the size of the value is {value.size()[1:]}."
-            )
             self._tensor = value
         else:
             assert ids.shape[0] == value.shape[0], (
...
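Routing TorchBasedFeature.read through the same helper means a feature store backed by a pinned tensor can serve GPU minibatches directly. A hedged sketch of that path (shapes are illustrative; TorchBasedFeature wraps an in-memory torch tensor):

import torch
import dgl.graphbolt as gb

# Back a feature with a pinned tensor so that read() with CUDA ids takes
# the index_select fast path from the diff above.
feat = gb.TorchBasedFeature(torch.randn(1000, 128).pin_memory())
ids = torch.tensor([3, 5, 8], device="cuda")
rows = feat.read(ids)  # gathered on and returned from the CUDA device
# read() with no ids returns the whole tensor, copied to CUDA when pinned.
all_rows = feat.read()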
@@ -250,6 +250,34 @@ def test_isin_non_1D_dim():
         gb.isin(elements, test_elements)


+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.bool,
+        torch.uint8,
+        torch.int8,
+        torch.int16,
+        torch.int32,
+        torch.int64,
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+    ],
+)
+@pytest.mark.parametrize("idtype", [torch.int32, torch.int64])
+@pytest.mark.parametrize("pinned", [False, True])
+def test_index_select(dtype, idtype, pinned):
+    if F._default_context_str != "gpu" and pinned:
+        pytest.skip("Pinned tests are available only on GPU.")
+    tensor = torch.tensor([[2, 3], [5, 5], [20, 13]], dtype=dtype)
+    tensor = tensor.pin_memory() if pinned else tensor.to(F.ctx())
+    index = torch.tensor([0, 2], dtype=idtype, device=F.ctx())
+    gb_result = gb.index_select(tensor, index)
+    torch_result = tensor.to(F.ctx())[index.long()]
+    assert torch.equal(torch_result, gb_result)
+
+
 def torch_expand_indptr(indptr, dtype, nodes=None):
     if nodes is None:
         nodes = torch.arange(len(indptr) - 1, dtype=dtype, device=indptr.device)
...