Unverified Commit 0348ad3d authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Move to_dgl out of data loader in examples. (#6705)

parent 7e12f973
......@@ -128,7 +128,6 @@ def create_dataloader(
)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
############################################################################
# [Note]:
......@@ -154,6 +153,7 @@ def evaluate(rank, model, dataloader, num_classes, device):
for step, data in (
tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
):
data = data.to_dgl()
blocks = data.blocks
x = data.node_features["feat"]
y.append(data.labels)
......@@ -206,6 +206,9 @@ def train(
if rank == 0
else enumerate(train_dataloader)
):
# Convert data to DGL format.
data = data.to_dgl()
# The input features are from the source nodes in the first
# layer's computation graph.
x = data.node_features["feat"]
......
......@@ -93,6 +93,7 @@ class SAGE(LightningModule):
)
def training_step(self, batch, batch_idx):
batch = batch.to_dgl()
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
......@@ -110,6 +111,7 @@ class SAGE(LightningModule):
return loss
def validation_step(self, batch, batch_idx):
batch = batch.to_dgl()
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
......@@ -158,7 +160,6 @@ class DataModule(LightningDataModule):
)
datapipe = sampler(self.graph, self.fanouts)
datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)
return dataloader
......@@ -214,7 +215,7 @@ if __name__ == "__main__":
args.num_workers,
)
in_size = dataset.feature.size("node", None, "feat")[0]
model = SAGE(in_size, 256, datamodule.num_classes).to(torch.double)
model = SAGE(in_size, 256, datamodule.num_classes)
# Train.
checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
......
......@@ -100,6 +100,7 @@ class SAGE(nn.Module):
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
......@@ -207,18 +208,6 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
if is_train:
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
############################################################################
# [Step-4]:
# datapipe.to_dgl()
# [Input]:
# 'datapipe': The previous datapipe object.
# [Output]:
# A DGLMiniBatch used for computing.
# [Role]:
# Convert a mini-batch to dgl-minibatch.
############################################################################
datapipe = datapipe.to_dgl()
############################################################################
# [Input]:
# 'device': The device to copy the data to.
......@@ -332,6 +321,9 @@ def train(args, model, graph, features, train_set):
total_loss = 0
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
# Convert data to DGL format.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
node_feature = data.node_features["feat"]
......
......@@ -126,18 +126,6 @@ def create_dataloader(
############################################################################
# [Step-4]:
# self.to_dgl()
# [Input]:
# 'datapipe': The previous datapipe object.
# [Output]:
# A DGLMiniBatch used for computing.
# [Role]:
# Convert a mini-batch to dgl-minibatch.
############################################################################
datapipe = datapipe.to_dgl()
############################################################################
# [Step-5]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
......@@ -147,7 +135,7 @@ def create_dataloader(
datapipe = datapipe.copy_to(device=device)
############################################################################
# [Step-6]:
# [Step-5]:
# gb.DataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
......@@ -214,6 +202,7 @@ class SAGE(nn.Module):
feature = feature.to(device)
for step, data in tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
......@@ -272,6 +261,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
)
for step, data in tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
......@@ -302,6 +292,9 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
model.train()
total_loss = 0
for step, data in enumerate(dataloader):
# Convert data to DGL format.
data = data.to_dgl()
# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
......
......@@ -47,9 +47,6 @@ def create_dataloader(dateset, device, is_train=True):
dataset.feature, node_feature_keys=["feat"]
)
# Convert the mini-batch to DGL format to train a DGL model.
datapipe = datapipe.to_dgl()
# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)
......@@ -101,6 +98,9 @@ def evaluate(model, dataset, device):
logits = []
labels = []
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, label = to_binary_link_dgl_computing_pack(data)
......@@ -140,6 +140,9 @@ def train(model, dataset, device):
# mini-batches.
########################################################################
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
......
......@@ -25,9 +25,6 @@ def create_dataloader(dateset, itemset, device):
dataset.feature, node_feature_keys=["feat"]
)
# Convert the mini-batch to DGL format to train a DGL model.
datapipe = datapipe.to_dgl()
# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)
......@@ -60,6 +57,7 @@ def evaluate(model, dataset, itemset, device):
dataloader = create_dataloader(dataset, itemset, device)
for step, data in enumerate(dataloader):
data = data.to_dgl()
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
......@@ -86,6 +84,9 @@ def train(model, dataset, device):
# mini-batches.
########################################################################
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# The features of sampled nodes.
x = data.node_features["feat"]
......
......@@ -124,9 +124,6 @@ def create_dataloader(
node_feature_keys["institution"] = ["feat"]
datapipe = datapipe.fetch_feature(features, node_feature_keys)
# Convert a mini-batch to dgl mini-batch for computing.
datapipe = datapipe.to_dgl()
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
......@@ -490,6 +487,9 @@ def evaluate(
y_true = list()
for data in tqdm(data_loader, desc="Inference"):
# Convert data to DGL format for computing.
data = data.to_dgl()
blocks = [block.to(device) for block in data.blocks]
node_features = extract_node_features(
name, blocks[0], data, node_embed, device
......@@ -558,6 +558,9 @@ def run(
total_loss = 0
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Convert MiniBatch to DGL Blocks.
blocks = [block.to(device) for block in data.blocks]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment