Unverified commit d3dd8e37 authored by peizhou001, committed by GitHub

[GraphBolt] Add to_dgl datapipe wrapper (#6390)

parent d8101fe4
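
This commit registers to_dgl as a functional datapipe stage so that the conversion from a GraphBolt MiniBatch to a DGLMiniBatch happens inside the data pipeline instead of via to_dgl_blocks() in each training step. A minimal sketch of the resulting pipeline, assuming graph, feature_store, and train_set are already constructed as in the examples below (batch size and fanouts are illustrative):

import dgl.graphbolt as gb
import torch

datapipe = gb.ItemSampler(train_set, batch_size=1024)
datapipe = gb.NeighborSampler(
    datapipe, graph, fanouts=[torch.LongTensor([10]) for _ in range(2)]
)
datapipe = datapipe.fetch_feature(feature_store, node_feature_keys=["feat"])
# The stage added by this commit: emit DGLMiniBatch objects downstream.
datapipe = datapipe.to_dgl()
dataloader = gb.SingleProcessDataLoader(datapipe)
for minibatch in dataloader:
    blocks = minibatch.blocks  # already DGL blocks, no to_dgl_blocks() call
    x = minibatch.node_features["feat"]
    y = minibatch.labels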
......@@ -93,9 +93,8 @@ class SAGE(LightningModule):
)
def training_step(self, batch, batch_idx):
# TODO: Move this to the data pipeline as a stage.
blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
x = blocks[0].srcdata["feat"]
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
y_hat = self(blocks, x)
loss = F.cross_entropy(y_hat, y)
......@@ -111,8 +110,8 @@ class SAGE(LightningModule):
return loss
def validation_step(self, batch, batch_idx):
blocks = [block.to("cuda") for block in batch.to_dgl_blocks()]
x = blocks[0].srcdata["feat"]
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
y_hat = self(blocks, x)
self.val_acc(torch.argmax(y_hat, 1), y)
......@@ -160,6 +159,7 @@ class DataModule(LightningDataModule):
)
datapipe = sampler(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
datapipe = datapipe.to_dgl()
dataloader = gb.MultiProcessDataLoader(
datapipe, num_workers=self.num_workers
)
......
......@@ -167,6 +167,18 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
############################################################################
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
############################################################################
# [Step-4]:
# datapipe.to_dgl()
# [Input]:
# 'datapipe': The previous datapipe object.
# [Output]:
# A DGLMiniBatch used for computing.
# [Role]:
# Convert a MiniBatch to a DGLMiniBatch.
############################################################################
datapipe = datapipe.to_dgl()
############################################################################
# [Input]:
# 'device': The device to copy the data to.
......@@ -193,19 +205,10 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
return dataloader
# TODO[Keli]: Remove this helper function later.
def to_binary_link_dgl_computing_pack(data: gb.MiniBatch):
"""Convert the minibatch to a training pair and a label tensor."""
- batch_size = data.compacted_node_pairs[0].shape[0]
- neg_ratio = data.compacted_negative_dsts.shape[0] // batch_size
- pos_src, pos_dst = data.compacted_node_pairs
- if data.compacted_negative_srcs is None:
-     neg_src = pos_src.repeat_interleave(neg_ratio, dim=0)
- else:
-     neg_src = data.compacted_negative_srcs
- neg_dst = data.compacted_negative_dsts
+ pos_src, pos_dst = data.positive_node_pairs
+ neg_src, neg_dst = data.negative_node_pairs
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
......@@ -234,7 +237,7 @@ def evaluate(args, graph, features, itemset, model):
# Unpack MiniBatch.
compacted_pairs, _ = to_binary_link_dgl_computing_pack(data)
node_feature = data.node_features["feat"].float()
- blocks = data.to_dgl_blocks()
+ blocks = data.blocks
# Get the embeddings of the input nodes.
y = model(blocks, node_feature)
......@@ -272,7 +275,7 @@ def train(args, graph, features, train_set, valid_set, model):
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
node_feature = data.node_features["feat"].float()
# Convert sampled subgraphs to DGL blocks.
- blocks = data.to_dgl_blocks()
+ blocks = data.blocks
# Get the embeddings of the input nodes.
y = model(blocks, node_feature)
......
......@@ -140,6 +140,18 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
############################################################################
# [Step-4]:
# datapipe.to_dgl()
# [Input]:
# 'datapipe': The previous datapipe object.
# [Output]:
# A DGLMiniBatch used for computing.
# [Role]:
# Convert a MiniBatch to a DGLMiniBatch.
############################################################################
datapipe = datapipe.to_dgl()
############################################################################
# [Step-5]:
# gb.MultiProcessDataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
......@@ -167,11 +179,9 @@ def evaluate(args, model, graph, features, itemset, num_classes):
)
for step, data in tqdm.tqdm(enumerate(dataloader)):
- blocks = data.to_dgl_blocks()
x = data.node_features["feat"]
y.append(data.labels)
- y_hats.append(model(blocks, x))
+ y_hats.append(model(data.blocks, x))
res = MF.accuracy(
torch.cat(y_hats),
......@@ -201,9 +211,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
# in the last layer's computation graph.
y = data.labels
- # TODO[Mingbang]: Move the to_dgl_blocks() to a datapipe stage later
# The predicted labels.
- y_hat = model(data.to_dgl_blocks(), x)
+ y_hat = model(data.blocks, x)
# Compute loss.
loss = F.cross_entropy(y_hat, y)
......
......@@ -123,6 +123,9 @@ def create_dataloader(
node_feature_keys["institution"] = ["feat"]
datapipe = datapipe.fetch_feature(features, node_feature_keys)
# Convert the mini-batch to a DGL mini-batch for computing.
datapipe = datapipe.to_dgl()
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
......@@ -435,7 +438,7 @@ def extract_node_features(name, block, data, node_embed, device):
)
else:
node_features = {
- ntype: block.srcnodes[ntype].data["feat"].to(device)
+ ntype: data.node_features[(ntype, "feat")]
for ntype in block.srctypes
}
# Original feature data are stored in float16 while model weights are
......@@ -495,7 +498,7 @@ def evaluate(
y_true = list()
for data in tqdm(data_loader, desc="Inference"):
- blocks = [block.to(device) for block in data.to_dgl_blocks()]
+ blocks = [block.to(device) for block in data.blocks]
node_features = extract_node_features(
name, blocks[0], data, node_embed, device
)
......@@ -563,10 +566,10 @@ def run(
)
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Fetch the number of seed nodes in the batch.
- num_seeds = data.seed_nodes[category].shape[0]
+ num_seeds = data.output_nodes[category].shape[0]
# Convert MiniBatch to DGL Blocks.
- blocks = [block.to(device) for block in data.to_dgl_blocks()]
+ blocks = [block.to(device) for block in data.blocks]
# Extract the node features from embedding layer or raw features.
node_features = extract_node_features(
......
......@@ -35,3 +35,21 @@ class MiniBatchTransformer(Mapper):
minibatch, MiniBatch
), "The transformer output should be an instance of MiniBatch"
return minibatch
@functional_datapipe("to_dgl")
class DGLMiniBatchConverter(Mapper):
"""Convert a graphbolt mini-batch to a dgl mini-batch."""
def __init__(
self,
datapipe,
):
"""
Initlization for a subgraph transformer.
Parameters
----------
datapipe : DataPipe
The datapipe.
"""
super().__init__(datapipe, MiniBatch.to_dgl)
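
Because the class is registered with @functional_datapipe("to_dgl"), the stage can be appended either by constructing the converter directly or through the functional form. A short sketch, assuming feature_fetcher is an existing upstream datapipe as in the test below:

# Both lines build the same stage; pick one.
datapipe = gb.DGLMiniBatchConverter(feature_fetcher)
datapipe = feature_fetcher.to_dgl()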
import dgl.graphbolt as gb
import gb_test_utils
import torch
def test_dgl_minibatch_converter():
    N = 32
    B = 4
    itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(200, 0.15)
    features = {}
    keys = [("node", None, "a"), ("node", None, "b")]
    features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))
    features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))
    feature_store = gb.BasicFeatureStore(features)
    item_sampler = gb.ItemSampler(itemset, batch_size=B)
    subgraph_sampler = gb.NeighborSampler(
        item_sampler,
        graph,
        fanouts=[torch.LongTensor([2]) for _ in range(2)],
    )
    feature_fetcher = gb.FeatureFetcher(
        subgraph_sampler,
        feature_store,
        ["a"],
    )
    dgl_converter = gb.DGLMiniBatchConverter(feature_fetcher)
    dataloader = gb.SingleProcessDataLoader(dgl_converter)
    assert len(list(dataloader)) == N // B
    minibatch = next(iter(dataloader))
    assert isinstance(minibatch, gb.DGLMiniBatch)