Unverified commit 0348ad3d, authored by Ramon Zhou, committed by GitHub

[GraphBolt] Move to_dgl out of data loader in examples. (#6705)

parent 7e12f973
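
In short: `to_dgl()` is dropped as a datapipe stage and is instead called once per mini-batch inside the training and evaluation loops. A minimal sketch of the new loader construction, assuming the usual GraphBolt example setup (`itemset`, `graph`, `features`, `device`, the fanout, and the helper name are placeholders taken from the examples below):

```python
import dgl.graphbolt as gb


def create_dataloader(itemset, graph, features, device):
    # Sample seed items in mini-batches, then their neighbors, then fetch
    # the input features for the sampled subgraphs.
    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(graph, [10, 10])
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    # No `datapipe = datapipe.to_dgl()` stage anymore: the loader now yields
    # GraphBolt MiniBatch objects, and conversion happens in the loop.
    datapipe = datapipe.copy_to(device)
    return gb.DataLoader(datapipe)
```

Each mini-batch is then converted at the top of the loop body, as in the sketch after the diff.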
@@ -128,7 +128,6 @@ def create_dataloader(
     )
     datapipe = datapipe.sample_neighbor(graph, args.fanout)
     datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
-    datapipe = datapipe.to_dgl()
     ############################################################################
     # [Note]:
@@ -154,6 +153,7 @@ def evaluate(rank, model, dataloader, num_classes, device):
     for step, data in (
         tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
     ):
+        data = data.to_dgl()
         blocks = data.blocks
         x = data.node_features["feat"]
         y.append(data.labels)
@@ -206,6 +206,9 @@ def train(
         if rank == 0
         else enumerate(train_dataloader)
     ):
+        # Convert data to DGL format.
+        data = data.to_dgl()
+
         # The input features are from the source nodes in the first
         # layer's computation graph.
         x = data.node_features["feat"]
...
@@ -93,6 +93,7 @@ class SAGE(LightningModule):
         )

     def training_step(self, batch, batch_idx):
+        batch = batch.to_dgl()
         blocks = [block.to("cuda") for block in batch.blocks]
         x = batch.node_features["feat"]
         y = batch.labels.to("cuda")
@@ -110,6 +111,7 @@ class SAGE(LightningModule):
         return loss

     def validation_step(self, batch, batch_idx):
+        batch = batch.to_dgl()
         blocks = [block.to("cuda") for block in batch.blocks]
         x = batch.node_features["feat"]
         y = batch.labels.to("cuda")
@@ -158,7 +160,6 @@ class DataModule(LightningDataModule):
         )
         datapipe = sampler(self.graph, self.fanouts)
         datapipe = datapipe.fetch_feature(self.feature_store, ["feat"])
-        datapipe = datapipe.to_dgl()
         dataloader = gb.DataLoader(datapipe, num_workers=self.num_workers)
         return dataloader
@@ -214,7 +215,7 @@ if __name__ == "__main__":
         args.num_workers,
     )
     in_size = dataset.feature.size("node", None, "feat")[0]
-    model = SAGE(in_size, 256, datamodule.num_classes).to(torch.double)
+    model = SAGE(in_size, 256, datamodule.num_classes)
     # Train.
     checkpoint_callback = ModelCheckpoint(monitor="val_acc", mode="max")
...
@@ -100,6 +100,7 @@ class SAGE(nn.Module):
             )
             feature = feature.to(device)
             for step, data in tqdm.tqdm(enumerate(dataloader)):
+                data = data.to_dgl()
                 x = feature[data.input_nodes]
                 hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                 if not is_last_layer:
@@ -207,18 +208,6 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
     if is_train:
         datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

-    ############################################################################
-    # [Step-4]:
-    # datapipe.to_dgl()
-    # [Input]:
-    # 'datapipe': The previous datapipe object.
-    # [Output]:
-    # A DGLMiniBatch used for computing.
-    # [Role]:
-    # Convert a mini-batch to dgl-minibatch.
-    ############################################################################
-    datapipe = datapipe.to_dgl()
-
     ############################################################################
     # [Input]:
     # 'device': The device to copy the data to.
@@ -332,6 +321,9 @@ def train(args, model, graph, features, train_set):
         total_loss = 0
         start_epoch_time = time.time()
         for step, data in enumerate(dataloader):
+            # Convert data to DGL format.
+            data = data.to_dgl()
+
             # Unpack MiniBatch.
             compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
             node_feature = data.node_features["feat"]
...
@@ -126,18 +126,6 @@ def create_dataloader(
     ############################################################################
     # [Step-4]:
-    # self.to_dgl()
-    # [Input]:
-    # 'datapipe': The previous datapipe object.
-    # [Output]:
-    # A DGLMiniBatch used for computing.
-    # [Role]:
-    # Convert a mini-batch to dgl-minibatch.
-    ############################################################################
-    datapipe = datapipe.to_dgl()
-
-    ############################################################################
-    # [Step-5]:
     # self.copy_to()
     # [Input]:
     # 'device': The device to copy the data to.
@@ -147,7 +135,7 @@ def create_dataloader(
     datapipe = datapipe.copy_to(device=device)

     ############################################################################
-    # [Step-6]:
+    # [Step-5]:
     # gb.DataLoader()
     # [Input]:
     # 'datapipe': The datapipe object to be used for data loading.
@@ -214,6 +202,7 @@ class SAGE(nn.Module):
             feature = feature.to(device)
             for step, data in tqdm(enumerate(dataloader)):
+                data = data.to_dgl()
                 x = feature[data.input_nodes]
                 hidden_x = layer(data.blocks[0], x)  # len(blocks) = 1
                 if not is_last_layer:
@@ -272,6 +261,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
     )
     for step, data in tqdm(enumerate(dataloader)):
+        data = data.to_dgl()
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))
@@ -302,6 +292,9 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
         model.train()
         total_loss = 0
         for step, data in enumerate(dataloader):
+            # Convert data to DGL format.
+            data = data.to_dgl()
+
             # The input features from the source nodes in the first layer's
             # computation graph.
             x = data.node_features["feat"]
...
@@ -47,9 +47,6 @@ def create_dataloader(dateset, device, is_train=True):
         dataset.feature, node_feature_keys=["feat"]
     )

-    # Convert the mini-batch to DGL format to train a DGL model.
-    datapipe = datapipe.to_dgl()
-
     # Copy the mini-batch to the designated device for training.
     datapipe = datapipe.copy_to(device)
@@ -101,6 +98,9 @@ def evaluate(model, dataset, device):
     logits = []
     labels = []
     for step, data in enumerate(dataloader):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
+
         # Unpack MiniBatch.
         compacted_pairs, label = to_binary_link_dgl_computing_pack(data)
@@ -140,6 +140,9 @@ def train(model, dataset, device):
     # mini-batches.
     ########################################################################
     for step, data in enumerate(dataloader):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
+
         # Unpack MiniBatch.
         compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
...
@@ -25,9 +25,6 @@ def create_dataloader(dateset, itemset, device):
         dataset.feature, node_feature_keys=["feat"]
     )

-    # Convert the mini-batch to DGL format to train a DGL model.
-    datapipe = datapipe.to_dgl()
-
     # Copy the mini-batch to the designated device for training.
     datapipe = datapipe.copy_to(device)
@@ -60,6 +57,7 @@ def evaluate(model, dataset, itemset, device):
     dataloader = create_dataloader(dataset, itemset, device)
     for step, data in enumerate(dataloader):
+        data = data.to_dgl()
         x = data.node_features["feat"]
         y.append(data.labels)
         y_hats.append(model(data.blocks, x))
@@ -86,6 +84,9 @@ def train(model, dataset, device):
     # mini-batches.
     ########################################################################
     for step, data in enumerate(dataloader):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
+
         # The features of sampled nodes.
         x = data.node_features["feat"]
...
@@ -124,9 +124,6 @@ def create_dataloader(
     node_feature_keys["institution"] = ["feat"]
     datapipe = datapipe.fetch_feature(features, node_feature_keys)

-    # Convert a mini-batch to dgl mini-batch for computing.
-    datapipe = datapipe.to_dgl()
-
     # Move the mini-batch to the appropriate device.
     # `device`:
     # The device to move the mini-batch to.
@@ -490,6 +487,9 @@ def evaluate(
     y_true = list()
     for data in tqdm(data_loader, desc="Inference"):
+        # Convert data to DGL format for computing.
+        data = data.to_dgl()
+
         blocks = [block.to(device) for block in data.blocks]
         node_features = extract_node_features(
             name, blocks[0], data, node_embed, device
@@ -558,6 +558,9 @@ def run(
         total_loss = 0
         for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
+            # Convert data to DGL format for computing.
+            data = data.to_dgl()
+
             # Convert MiniBatch to DGL Blocks.
             blocks = [block.to(device) for block in data.blocks]
...
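
And the consumer side, as a sketch of the loop shape used throughout the hunks above (the `model`, `opt`, and `train_one_epoch` names are hypothetical; `data.node_features["feat"]`, `data.labels`, and `data.blocks` are used exactly as in the node classification examples):

```python
import torch.nn.functional as F


def train_one_epoch(model, dataloader, opt):
    model.train()
    total_loss = 0.0
    for step, data in enumerate(dataloader):
        # Convert the GraphBolt MiniBatch to DGL format before use.
        data = data.to_dgl()
        x = data.node_features["feat"]
        y = data.labels
        y_hat = model(data.blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
    return total_loss / (step + 1)
```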