"tests/python/vscode:/vscode.git/clone" did not exist on "55af15d4a9736a530eb53faef4bca15d040090ca"
Unverified commit 9d417346, authored by Rhett Ying and committed by GitHub.
Browse files

[GraphBolt] update to_dgl() in examples (#6763)

parent a5e5f11a
......@@ -153,7 +153,6 @@ def evaluate(rank, model, dataloader, num_classes, device):
for step, data in (
tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
):
data = data.to_dgl()
blocks = data.blocks
x = data.node_features["feat"]
y.append(data.labels)
......@@ -206,9 +205,6 @@ def train(
if rank == 0
else enumerate(train_dataloader)
):
# Convert data to DGL format.
data = data.to_dgl()
# The input features are from the source nodes in the first
# layer's computation graph.
x = data.node_features["feat"]
......
......@@ -93,7 +93,6 @@ class SAGE(LightningModule):
)
def training_step(self, batch, batch_idx):
batch = batch.to_dgl()
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
......@@ -111,7 +110,6 @@ class SAGE(LightningModule):
return loss
def validation_step(self, batch, batch_idx):
batch = batch.to_dgl()
blocks = [block.to("cuda") for block in batch.blocks]
x = batch.node_features["feat"]
y = batch.labels.to("cuda")
......
......@@ -101,7 +101,6 @@ class SAGE(nn.Module):
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
......@@ -237,20 +236,6 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
return dataloader
def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
    """Convert the minibatch to a training pair and a label tensor."""
    # Endpoints of the sampled positive and negative edges.
    src_pos, dst_pos = data.positive_node_pairs
    src_neg, dst_neg = data.negative_node_pairs
    # Stack positives first, then negatives, keeping src/dst aligned.
    sources = torch.cat((src_pos, src_neg), dim=0)
    destinations = torch.cat((dst_pos, dst_neg), dim=0)
    # Label 1 for each positive pair, 0 for each negative, same order.
    labels = torch.cat(
        [torch.ones_like(src_pos), torch.zeros_like(src_neg)], dim=0
    )
    return ((sources, destinations), labels.float())
@torch.no_grad()
def compute_mrr(args, model, evaluator, node_emb, src, dst, neg_dst):
"""Compute the Mean Reciprocal Rank (MRR) for given source and destination
......@@ -324,11 +309,8 @@ def train(args, model, graph, features, train_set):
total_loss = 0
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
# Convert data to DGL format.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels
node_feature = data.node_features["feat"]
# Convert sampled subgraphs to DGL blocks.
blocks = data.blocks
......
......@@ -202,7 +202,6 @@ class SAGE(nn.Module):
feature = feature.to(device)
for step, data in tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
if not is_last_layer:
......@@ -261,7 +260,6 @@ def evaluate(args, model, graph, features, itemset, num_classes):
)
for step, data in tqdm(enumerate(dataloader)):
data = data.to_dgl()
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
......@@ -292,9 +290,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
model.train()
total_loss = 0
for step, data in enumerate(dataloader):
# Convert data to DGL format.
data = data.to_dgl()
# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
......
......@@ -76,20 +76,6 @@ class GraphSAGE(nn.Module):
return hidden_x
def to_binary_link_dgl_computing_pack(data: gb.MiniBatch):
    """Convert the minibatch to a training pair and a label tensor."""
    # Endpoints of the sampled positive and negative edges.
    src_pos, dst_pos = data.positive_node_pairs
    src_neg, dst_neg = data.negative_node_pairs
    # Stack positives first, then negatives, keeping src/dst aligned.
    sources = torch.cat((src_pos, src_neg), dim=0)
    destinations = torch.cat((dst_pos, dst_neg), dim=0)
    # Label 1 for each positive pair, 0 for each negative, same order.
    # NOTE: unlike the DGLMiniBatch variant, the labels here are returned
    # without a .float() cast — preserved exactly.
    labels = torch.cat(
        [torch.ones_like(src_pos), torch.zeros_like(src_neg)], dim=0
    )
    return ((sources, destinations), labels)
@torch.no_grad()
def evaluate(model, dataset, device):
model.eval()
......
......@@ -176,7 +176,7 @@ def rel_graph_embed(graph, embed_size):
for the "paper" node type.
"""
node_num = {}
node_type_to_id = graph.metadata.node_type_to_id
node_type_to_id = graph.node_type_to_id
node_type_offset = graph.node_type_offset
for ntype, ntype_id in node_type_to_id.items():
# Skip the "paper" node type.
......@@ -328,12 +328,12 @@ class EntityClassify(nn.Module):
# Generate and sort a list of unique edge types from the input graph.
# eg. ['writes', 'cites']
etypes = list(graph.metadata.edge_type_to_id.keys())
etypes = list(graph.edge_type_to_id.keys())
etypes = [gb.etype_str_to_tuple(etype)[1] for etype in etypes]
self.relation_names = etypes
self.relation_names.sort()
self.dropout = 0.5
ntypes = list(graph.metadata.node_type_to_id.keys())
ntypes = list(graph.node_type_to_id.keys())
self.layers = nn.ModuleList()
# First layer: transform input features to hidden features. Use ReLU
......@@ -487,9 +487,6 @@ def evaluate(
y_true = list()
for data in tqdm(data_loader, desc="Inference"):
# Convert data to DGL format for computing.
data = data.to_dgl()
blocks = [block.to(device) for block in data.blocks]
node_features = extract_node_features(
name, blocks[0], data, node_embed, device
......@@ -558,9 +555,6 @@ def run(
total_loss = 0
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Convert MiniBatch to DGL Blocks.
blocks = [block.to(device) for block in data.blocks]
......
......@@ -118,7 +118,6 @@ def create_dataloader(
)
datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, num_workers=0)
return dataloader
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.