Unverified Commit 3afa105b authored by Xinyu Yao, committed by GitHub

[GraphBolt][Doc] Update docs related to `seeds`. (#7351)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 6133cecd
@@ -79,11 +79,11 @@ can be used on heterogeneous graphs:
     {
         "user": gb.ItemSet(
             (torch.arange(0, 5), torch.arange(5, 10)),
-            names=("seed_nodes", "labels"),
+            names=("seeds", "labels"),
         ),
         "item": gb.ItemSet(
             (torch.arange(5, 10), torch.arange(10, 15)),
-            names=("seed_nodes", "labels"),
+            names=("seeds", "labels"),
         ),
     }
 )
...
@@ -30,9 +30,9 @@ edges (namely, node pairs) in the training set instead of the nodes.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_paris = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 2, (500,))
-    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+    train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     # Or equivalently:
@@ -83,9 +83,9 @@ You can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside with
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_paris = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 2, (500,))
-    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
+    train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
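Note (editor's sketch, not part of this commit): the ``partial`` built above is then
attached to the pipeline as a MiniBatch transform step, along the lines of:

    # Continues the snippet above; assumes GraphBolt's `transform` and
    # `DataLoader` datapipe stages.
    datapipe = datapipe.transform(exclude_seed_edges)
    dataloader = gb.DataLoader(datapipe)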
@@ -138,9 +138,9 @@ concatenating the incident node features and projecting it with a dense layer.
         super().__init__()
         self.W = nn.Linear(2 * in_features, num_classes)

-    def forward(self, node_pairs, x):
-        src_x = x[node_pairs[0]]
-        dst_x = x[node_pairs[1]]
+    def forward(self, seeds, x):
+        src_x = x[seeds[:, 0]]
+        dst_x = x[seeds[:, 1]]
         data = torch.cat([src_x, dst_x], 1)
         return self.W(data)
@@ -157,9 +157,9 @@ loader, as well as the input node features as follows:
             in_features, hidden_features, out_features)
         self.predictor = ScorePredictor(num_classes, out_features)

-    def forward(self, blocks, x, node_pairs):
+    def forward(self, blocks, x, seeds):
         x = self.gcn(blocks, x)
-        return self.predictor(node_pairs, x)
+        return self.predictor(seeds, x)

 DGL ensures that the nodes in the edge subgraph are the same as the
 output nodes of the last MFG in the generated list of MFGs.
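Note (editor's sketch in plain ``torch``, not part of this commit): because the edge
subgraph's nodes coincide with the output nodes of the last MFG, the compacted seed
indices can be applied directly to the representations produced by that block:

    import torch

    # Representations of the output nodes of the last MFG, shape [N_out, d].
    num_output_nodes, d = 6, 4
    x = torch.randn(num_output_nodes, d)

    # Compacted seeds are local indices into those output nodes, shape [E, 2].
    seeds = torch.tensor([[0, 1], [2, 3], [4, 5]])
    src_x = x[seeds[:, 0]]               # [E, d]
    dst_x = x[seeds[:, 1]]               # [E, d]
    scores = (src_x * dst_x).sum(dim=1)  # one score per seed pair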
@@ -182,7 +182,7 @@ their incident node representations.
     for data in dataloader:
         blocks = data.blocks
         x = data.edge_features("feat")
-        y_hat = model(data.blocks, x, data.positive_node_pairs)
+        y_hat = model(data.blocks, x, data.compacted_seeds)
         loss = F.cross_entropy(y_hat, data.labels)
         opt.zero_grad()
         loss.backward()
@@ -226,10 +226,10 @@ over the edge types.
         super().__init__()
         self.W = nn.Linear(2 * in_features, num_classes)

-    def forward(self, node_pairs, x):
+    def forward(self, seeds, x):
         scores = {}
-        for etype in node_pairs.keys():
-            src, dst = node_pairs[etype]
+        for etype in seeds.keys():
+            src, dst = seeds[etype].T
             data = torch.cat([x[etype][src], x[etype][dst]], 1)
             scores[etype] = self.W(data)
         return scores
@@ -242,9 +242,9 @@ over the edge types.
             in_features, hidden_features, out_features, etypes)
         self.pred = ScorePredictor(num_classes, out_features)

-    def forward(self, node_pairs, blocks, x):
+    def forward(self, seeds, blocks, x):
         x = self.rgcn(blocks, x)
-        return self.pred(node_pairs, x)
+        return self.pred(seeds, x)

Data loader definition is almost identical to that of a homogeneous graph. The
only difference is that the train_set is now an instance of
@@ -256,17 +256,17 @@ only difference is that the train_set is now an instance of
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     g = gb.SamplingGraph()
-    node_pairs = torch.arange(0, 1000).reshape(-1, 2)
+    seeds = torch.arange(0, 1000).reshape(-1, 2)
     labels = torch.randint(0, 3, (500,))
-    node_pairs_labels = {
+    seeds_labels = {
         "user:like:item": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (seeds, labels), names=("seeds", "labels")
         ),
         "user:follow:user": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (seeds, labels), names=("seeds", "labels")
         ),
     }
-    train_set = gb.ItemSetDict(node_pairs_labels)
+    train_set = gb.ItemSetDict(seeds_labels)
     datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
     datapipe = datapipe.fetch_feature(
@@ -316,7 +316,7 @@ dictionaries of node types and predictions here.
     for data in dataloader:
         blocks = data.blocks
         x = data.edge_features(("user:like:item", "feat"))
-        y_hat = model(data.blocks, x, data.positive_node_pairs)
+        y_hat = model(data.blocks, x, data.compacted_seeds)
         loss = F.cross_entropy(y_hat, data.labels)
         opt.zero_grad()
         loss.backward()
...
@@ -106,7 +106,7 @@ and combined as well.
                 hidden_x = self.dropout(hidden_x)
             # By design, our output nodes are contiguous.
             y[
-                data.seed_nodes[0] : data.seed_nodes[-1] + 1
+                data.seeds[0] : data.seeds[-1] + 1
             ] = hidden_x.to(device)
         feature = y
...
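Note (editor's sketch in plain ``torch``, not part of this commit): the slice
assignment above relies on the seed IDs being contiguous; with contiguous IDs it is
equivalent to indexed assignment:

    import torch

    y = torch.zeros(10, 4)
    seeds = torch.tensor([3, 4, 5, 6])      # contiguous IDs, by design
    hidden_x = torch.randn(len(seeds), 4)
    y[seeds[0] : seeds[-1] + 1] = hidden_x  # same effect as y[seeds] = hidden_x
    assert torch.equal(y[seeds], hidden_x)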
@@ -53,8 +53,8 @@ proportional to a power of degrees.
         self.weights = node_degrees ** 0.75
         self.k = k

-    def _sample_with_etype(node_pairs, etype=None):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype=None):
+        src, _ = seeds.T
         src = src.repeat_interleave(self.k)
         dst = self.weights.multinomial(len(src), replacement=True)
         return src, dst
@@ -95,7 +95,7 @@ Define a GraphSAGE model for minibatch training
 When a negative sampler is provided, the data loader will generate positive and
 negative node pairs for each minibatch besides the *Message Flow Graphs* (MFGs).
-Use `node_pairs_with_labels` to get compact node pairs with corresponding
+Use `compacted_seeds` and `labels` to get compact node pairs and corresponding
 labels.
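Note (editor's sketch in plain ``torch``, not part of this commit): "compact" means
the global node IDs in the seed pairs are remapped to consecutive local indices; the
sorted ordering below is illustrative, not the exact mapping GraphBolt produces:

    import torch

    seeds = torch.tensor([[100, 205],   # global node IDs of seed pairs
                          [100, 317]])
    unique_ids, compacted_seeds = seeds.unique(return_inverse=True)
    # unique_ids      -> tensor([100, 205, 317])
    # compacted_seeds -> tensor([[0, 1],
    #                            [0, 2]])  local indices into unique_ids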
@@ -116,7 +116,8 @@ above.
     start_epoch_time = time.time()
     for step, data in enumerate(dataloader):
         # Unpack MiniBatch.
-        compacted_pairs, labels = data.node_pairs_with_labels
+        compacted_seeds = data.compacted_seeds.T
+        labels = data.labels
         node_feature = data.node_features["feat"]
         # Convert sampled subgraphs to DGL blocks.
         blocks = data.blocks
@@ -124,7 +125,7 @@ above.
         # Get the embeddings of the input nodes.
         y = model(blocks, node_feature)
         logits = model.predictor(
-            y[compacted_pairs[0]] * y[compacted_pairs[1]]
+            y[compacted_seeds[0]] * y[compacted_seeds[1]]
         ).squeeze()
         # Compute loss.
@@ -217,8 +218,8 @@ If you want to give your own negative sampling function, just inherit from the
         }
         self.k = k

-    def _sample_with_etype(node_pairs, etype):
-        src, _ = node_pairs
+    def _sample_with_etype(self, seeds, etype):
+        src, _ = seeds.T
         src = src.repeat_interleave(self.k)
         dst = self.weights[etype].multinomial(len(src), replacement=True)
         return src, dst
@@ -241,7 +242,8 @@ loss on specific edge type.
     start_epoch_time = time.time()
     for step, data in enumerate(dataloader):
         # Unpack MiniBatch.
-        compacted_pairs, labels = data.node_pairs_with_labels
+        compacted_seeds = data.compacted_seeds
+        labels = data.labels
         node_features = {
             ntype: data.node_features[(ntype, "feat")]
             for ntype in data.blocks[0].srctypes
@@ -251,8 +253,8 @@ loss on specific edge type.
         # Get the embeddings of the input nodes.
         y = model(blocks, node_feature)
         logits = model.predictor(
-            y[category][compacted_pairs[category][0]]
-            * y[category][compacted_pairs[category][1]]
+            y[category][compacted_seeds[category][:, 0]]
+            * y[category][compacted_seeds[category][:, 1]]
         ).squeeze()
         # Compute loss.
...
@@ -201,9 +201,8 @@ such as ``num_classes`` and all these fields will be passed to the
   The ``name`` field is used to specify the name of the data. It is mandatory
   and used to specify the data fields of ``MiniBatch`` for sampling. It can
-  be either ``seed_nodes``, ``labels``, ``node_pairs``, ``negative_srcs`` or
-  ``negative_dsts``. If any other name is used, it will be added into the
-  ``MiniBatch`` data fields.
+  be either ``seeds``, ``labels`` or ``indexes``. If any other name is used,
+  it will be added into the ``MiniBatch`` data fields.
 - ``format``: ``string``
   The ``format`` field is used to specify the format of the data. It can be
...
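Note (editor's sketch, not part of this commit): the reserved names above match the
``names`` accepted by in-memory item sets, e.g.:

    import torch
    import dgl.graphbolt as gb

    seeds = torch.arange(0, 10)
    labels = torch.randint(0, 2, (10,))
    # "seeds" and "labels" are reserved names mapped onto MiniBatch fields.
    item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))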
@@ -61,12 +61,12 @@
    },
    "outputs": [],
    "source": [
-    "node_pairs = torch.tensor(\n",
+    "seeds = torch.tensor(\n",
     "    [[7, 0], [6, 0], [1, 3], [3, 3], [2, 4], [8, 4], [1, 4], [2, 4], [1, 5],\n",
     "     [9, 6], [0, 6], [8, 6], [7, 7], [7, 7], [4, 7], [6, 8], [5, 8], [9, 9],\n",
     "     [4, 9], [4, 9], [5, 9], [9, 9], [5, 9], [9, 9], [7, 9]]\n",
     ")\n",
-    "item_set = gb.ItemSet(node_pairs, names=\"node_pairs\")\n",
+    "item_set = gb.ItemSet(seeds, names=\"seeds\")\n",
     "print(list(item_set))"
    ]
   },
@@ -262,7 +262,7 @@
    "num_nodes = 10\n",
    "nodes = torch.arange(num_nodes)\n",
    "labels = torch.tensor([1, 2, 0, 2, 2, 0, 2, 2, 2, 2])\n",
-   "item_set = gb.ItemSet((nodes, labels), names=(\"seed_nodes\", \"labels\"))\n",
+   "item_set = gb.ItemSet((nodes, labels), names=(\"seeds\", \"labels\"))\n",
    "\n",
    "indptr = torch.tensor([0, 2, 2, 2, 4, 8, 9, 12, 15, 17, 25])\n",
    "indices = torch.tensor(\n",
@@ -311,4 +311,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 0
-}
\ No newline at end of file
+}