OpenDAS / dgl / Commit 9f444cd7

[doc] update user guide 6.6 (#6656)

Unverified commit, authored Dec 01, 2023 by Rhett Ying; committed via GitHub on Dec 01, 2023.
Parent: 3ea930d0
Showing 1 changed file with 65 additions and 47 deletions.

docs/source/guide/minibatch-inference.rst (+65, -47)
...
@@ -39,63 +39,81 @@ Implementing Offline Inference
 Consider the two-layer GCN we have mentioned in Section 6.1
 :ref:`guide-minibatch-node-classification-model`. The way
 to implement offline inference still involves using
-:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`, but sampling for
-only one layer at a time. Note that offline inference is implemented as
-a method of the GNN module because the computation on one layer depends
-on how messages are aggregated and combined as well.
+:class:`~dgl.graphbolt.NeighborSampler`, but sampling for
+only one layer at a time.
 
 .. code:: python
 
-    class StochasticTwoLayerGCN(nn.Module):
-        def __init__(self, in_features, hidden_features, out_features):
+    datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=True)
+    datapipe = datapipe.sample_neighbor(g, [-1]) # 1 layers.
+    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
+    datapipe = datapipe.to_dgl()
+    datapipe = datapipe.copy_to(device)
+    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
+
+Note that offline inference is implemented as a method of the GNN module
+because the computation on one layer depends on how messages are aggregated
+and combined as well.
+
+.. code:: python
+
+    class SAGE(nn.Module):
+        def __init__(self, in_size, hidden_size, out_size):
             super().__init__()
-            self.hidden_features = hidden_features
-            self.out_features = out_features
-            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
-            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)
-            self.n_layers = 2
+            self.layers = nn.ModuleList()
+            # Three-layer GraphSAGE-mean.
+            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
+            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
+            self.layers.append(dglnn.SAGEConv(hidden_size, out_size, "mean"))
+            self.dropout = nn.Dropout(0.5)
+            self.hidden_size = hidden_size
+            self.out_size = out_size
 
         def forward(self, blocks, x):
-            x_dst = x[:blocks[0].number_of_dst_nodes()]
-            x = F.relu(self.conv1(blocks[0], (x, x_dst)))
-            x_dst = x[:blocks[1].number_of_dst_nodes()]
-            x = F.relu(self.conv2(blocks[1], (x, x_dst)))
-            return x
-
-        def inference(self, g, x, batch_size, device):
+            hidden_x = x
+            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
+                hidden_x = layer(block, hidden_x)
+                is_last_layer = layer_idx == len(self.layers) - 1
+                if not is_last_layer:
+                    hidden_x = F.relu(hidden_x)
+                    hidden_x = self.dropout(hidden_x)
+            return hidden_x
+
+        def inference(self, graph, features, dataloader, device):
             """
             Offline inference with this module
             """
+            feature = features.read("node", None, "feat")
+
             # Compute representations layer by layer
-            for l, layer in enumerate([self.conv1, self.conv2]):
-                y = torch.zeros(g.num_nodes(),
-                                self.hidden_features
-                                if l != self.n_layers - 1
-                                else self.out_features)
-                sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
-                dataloader = dgl.dataloading.NodeDataLoader(
-                    g, torch.arange(g.num_nodes()), sampler,
-                    batch_size=batch_size,
-                    shuffle=True,
-                    drop_last=False)
-
-                # Within a layer, iterate over nodes in batches
-                for input_nodes, output_nodes, blocks in dataloader:
-                    block = blocks[0]
-
-                    # Copy the features of necessary input nodes to GPU
-                    h = x[input_nodes].to(device)
-                    # Compute output. Note that this computation is the same
-                    # but only for a single layer.
-                    h_dst = h[:block.number_of_dst_nodes()]
-                    h = F.relu(layer(block, (h, h_dst)))
-                    # Copy to output back to CPU.
-                    y[output_nodes] = h.cpu()
-
-                x = y
+            for layer_idx, layer in enumerate(self.layers):
+                is_last_layer = layer_idx == len(self.layers) - 1
+                y = torch.empty(
+                    graph.total_num_nodes,
+                    self.out_size if is_last_layer else self.hidden_size,
+                    dtype=torch.float32,
+                    device=buffer_device,
+                    pin_memory=pin_memory,
+                )
+                feature = feature.to(device)
+
+                for step, data in tqdm(enumerate(dataloader)):
+                    x = feature[data.input_nodes]
+                    hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
+                    if not is_last_layer:
+                        hidden_x = F.relu(hidden_x)
+                        hidden_x = self.dropout(hidden_x)
+                    # By design, our output nodes are contiguous.
+                    y[
+                        data.output_nodes[0] : data.output_nodes[-1] + 1
+                    ] = hidden_x.to(device)
+                feature = y
 
             return y
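
For orientation, the following is a minimal sketch of how the pieces added above could be wired together to run offline inference. It is an illustration, not text from the changed file: the dataset handle, the model sizes, and the buffer_device / pin_memory settings (which the inference() method above reads as free variables) are assumptions, and the usual guide imports (torch.nn as nn, dgl.nn as dglnn, torch.nn.functional as F) are taken to be in scope.

.. code:: python

    import torch
    import dgl.graphbolt as gb
    from tqdm import tqdm  # inference() above iterates the dataloader with tqdm

    # Assumed: a graphbolt dataset has already been loaded elsewhere.
    graph = dataset.graph        # hypothetical handle to the sampling graph
    features = dataset.feature   # hypothetical handle to the feature store
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    buffer_device = torch.device("cpu")   # inference() buffers per-layer outputs here
    pin_memory = buffer_device != device  # pin host buffers only when a GPU is used

    # Full-neighbor, single-layer dataloader over every node, as in the added snippet.
    # shuffle=False keeps each batch's output nodes a contiguous range, which the
    # "output nodes are contiguous" slice assignment in inference() relies on.
    all_nodes_set = gb.ItemSet(torch.arange(graph.total_num_nodes), names="seed_nodes")
    datapipe = gb.ItemSampler(all_nodes_set, batch_size=1024, shuffle=False)
    datapipe = datapipe.sample_neighbor(graph, [-1])  # one layer, all neighbors
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    # Example sizes only; in_size must match the dataset's "feat" dimensionality.
    model = SAGE(in_size=100, hidden_size=256, out_size=47).to(device)
    model.eval()
    with torch.no_grad():
        node_repr = model.inference(graph, features, dataloader, device)  # (N, out_size)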
 Note that for the purpose of computing evaluation metric on the
 validation set for model selection we usually don’t have to compute
 exact offline inference. The reason is that we need to compute the
...
@@ -105,7 +123,7 @@ of unlabeled data. Neighborhood sampling will work fine for model
 selection and validation.
 
 One can see
-`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling.py>`__
+`GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/node_classification.py>`__
 and
-`RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__
+`RGCN <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/rgcn/hetero_rgcn.py>`__
 for examples of offline inference.
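
To make the hunk's point about approximate validation concrete, below is a rough sketch that evaluates on the validation set with sampled neighborhoods, reusing the same forward pass as training instead of the layer-wise offline inference above. valid_set, g, feature, model, device, and the [10, 10, 10] fanout are assumed names and values for illustration, not code taken from the linked examples.

.. code:: python

    import torch
    import dgl.graphbolt as gb

    # Validation dataloader: same pipeline shape as training, but over the
    # validation item set and with sampled (not full) neighborhoods.
    datapipe = gb.ItemSampler(valid_set, batch_size=1024, shuffle=False)
    datapipe = datapipe.sample_neighbor(g, [10, 10, 10])  # one fanout per SAGE layer
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    valid_dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data in valid_dataloader:
            x = data.node_features["feat"]
            y = data.labels
            y_hat = model(data.blocks, x)  # sampled-neighborhood forward pass
            correct += (y_hat.argmax(dim=1) == y).sum().item()
            total += y.numel()
    print(f"approx. validation accuracy: {correct / total:.4f}")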