Unverified commit 88833e6f authored by Ramon Zhou, committed by GitHub

[Misc] Delete all the DGLMiniBatch from docs and comments. (#6805)

parent a9656e2c
......@@ -111,8 +111,6 @@ features. It is the basic unit for training a GNN model.
:template: graphbolt_classtemplate.rst
MiniBatch
-DGLMiniBatch
-DGLMiniBatchConverter
NegativeSampler
......
......@@ -95,22 +95,8 @@ Define a GraphSAGE model for minibatch training
When a negative sampler is provided, the data loader will generate positive and
negative node pairs for each minibatch besides the *Message Flow Graphs* (MFGs).
-Let's define a utility function to compact node pairs as follows:
-.. code:: python
-    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
-        """Convert the minibatch to a training pair and a label tensor."""
-        pos_src, pos_dst = data.positive_node_pairs
-        neg_src, neg_dst = data.negative_node_pairs
-        node_pairs = (
-            torch.cat((pos_src, neg_src), dim=0),
-            torch.cat((pos_dst, neg_dst), dim=0),
-        )
-        pos_label = torch.ones_like(pos_src)
-        neg_label = torch.zeros_like(neg_src)
-        labels = torch.cat([pos_label, neg_label], dim=0)
-        return (node_pairs, labels.float())
+Use `node_pairs_with_labels` to get compact node pairs with corresponding
+labels.
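For intuition, the packing that `node_pairs_with_labels` is expected to perform can be sketched in plain Python. This is an illustrative sketch only: lists stand in for torch tensors, and the free function below is our own name, not part of the GraphBolt API.

```python
# Illustrative sketch: plain lists stand in for torch tensors, and this
# free function mimics what the MiniBatch property is assumed to return.
def pack_node_pairs_with_labels(positive_node_pairs, negative_node_pairs):
    """Concatenate positive and negative pairs and build a 0/1 label list."""
    pos_src, pos_dst = positive_node_pairs
    neg_src, neg_dst = negative_node_pairs
    # Positive pairs first, then negative pairs, on both the src and dst side.
    node_pairs = (pos_src + neg_src, pos_dst + neg_dst)
    # Label 1.0 for each positive pair, 0.0 for each negative pair.
    labels = [1.0] * len(pos_src) + [0.0] * len(neg_src)
    return node_pairs, labels

pairs, labels = pack_node_pairs_with_labels(([0, 1], [2, 3]), ([0, 0], [7, 9]))
# pairs  -> ([0, 1, 0, 0], [2, 3, 7, 9])
# labels -> [1.0, 1.0, 0.0, 0.0]
```

The compacted pairs can then be fed to a binary predictor with `labels` as the target, which is exactly how the training loop below consumes them.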
Training loop
......@@ -130,7 +116,7 @@ above.
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
    # Unpack MiniBatch.
-   compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
+   compacted_pairs, labels = data.node_pairs_with_labels
    node_feature = data.node_features["feat"]
    # Convert sampled subgraphs to DGL blocks.
    blocks = data.blocks
......@@ -240,26 +226,9 @@ If you want to give your own negative sampling function, just inherit from the
datapipe = datapipe.customized_sample_negative(5, node_degrees)
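As a rough sketch of what such a degree-aware sampler might compute internally, the following plain-Python function draws negative destinations with probability proportional to node degree. It mirrors the five-negatives-per-edge, degree-weighted idea behind the `customized_sample_negative(5, node_degrees)` call above, but none of these names come from the GraphBolt interface.

```python
import random

# Hedged sketch, not the GraphBolt API: draw `negative_ratio` negative
# destination nodes per positive source, weighted by node degree.
def sample_negative_dsts(pos_src, negative_ratio, node_degrees, rng=None):
    """Return (neg_src, neg_dst) with degree-weighted negative dsts."""
    rng = rng or random.Random(0)  # seeded for reproducibility
    nodes = list(range(len(node_degrees)))
    neg_src, neg_dst = [], []
    for src in pos_src:
        # Repeat each source once per requested negative sample.
        neg_src.extend([src] * negative_ratio)
        # High-degree nodes are proportionally more likely to be drawn.
        neg_dst.extend(rng.choices(nodes, weights=node_degrees, k=negative_ratio))
    return neg_src, neg_dst

neg_src, neg_dst = sample_negative_dsts([0, 3], 5, [1, 2, 4, 1])
```

Degree-weighted sampling is a common alternative to uniform sampling because it biases negatives toward popular nodes, which tend to be harder examples.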
-For heterogeneous graphs, node pairs are grouped by edge types.
-.. code:: python
-    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch, etype):
-        """Convert the minibatch to a training pair and a label tensor."""
-        pos_src, pos_dst = data.positive_node_pairs[etype]
-        neg_src, neg_dst = data.negative_node_pairs[etype]
-        node_pairs = (
-            torch.cat((pos_src, neg_src), dim=0),
-            torch.cat((pos_dst, neg_dst), dim=0),
-        )
-        pos_label = torch.ones_like(pos_src)
-        neg_label = torch.zeros_like(neg_src)
-        labels = torch.cat([pos_label, neg_label], dim=0)
-        return (node_pairs, labels.float())
-The training loop is again almost the same as that on homogeneous graph,
-except for computing loss on specific edge type.
+For heterogeneous graphs, node pairs are grouped by edge types. The training
+loop is again almost the same as on a homogeneous graph, except for computing
+the loss on a specific edge type.
.. code:: python
......@@ -272,7 +241,7 @@ except for computing loss on specific edge type.
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
    # Unpack MiniBatch.
-   compacted_pairs, labels = to_binary_link_dgl_computing_pack(data, category)
+   compacted_pairs, labels = data.node_pairs_with_labels
    node_features = {
        ntype: data.node_features[(ntype, "feat")]
        for ntype in data.blocks[0].srctypes
......@@ -282,11 +251,12 @@ except for computing loss on specific edge type.
    # Get the embeddings of the input nodes.
    y = model(blocks, node_feature)
    logits = model.predictor(
-       y[category][compacted_pairs[0]] * y[category][compacted_pairs[1]]
+       y[category][compacted_pairs[category][0]]
+       * y[category][compacted_pairs[category][1]]
    ).squeeze()
    # Compute loss.
-   loss = F.binary_cross_entropy_with_logits(logits, labels)
+   loss = F.binary_cross_entropy_with_logits(logits, labels[category])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
......
......@@ -11,8 +11,7 @@ generate a minibatch, including:
* Sample neighbors for each seed from graph.
* Exclude seed edges from the sampled subgraphs.
* Fetch node and edge features for the sampled subgraphs.
-* Convert the sampled subgraphs to DGLMiniBatches.
-* Copy the DGLMiniBatches to the target device.
+* Copy the MiniBatches to the target device.
.. code:: python
......
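The staged pipeline listed above can be sketched in plain Python, with each stage modeled as a function over a dict that stands in for a MiniBatch and the stages chained the way the real torchdata datapipes are chained. All names and the toy data here are illustrative, not the GraphBolt API.

```python
from functools import reduce

# Hedged sketch of the minibatch pipeline stages: a dict stands in for a
# MiniBatch, and each stage transforms it and passes it along.

def sample_neighbors(batch):
    # Pretend to sample a fixed fanout of two neighbors per seed node.
    batch["subgraphs"] = {seed: [seed + 1, seed + 2] for seed in batch["seeds"]}
    return batch

def exclude_seed_edges(batch):
    # Drop sampled edges that point back to a seed node.
    batch["subgraphs"] = {
        seed: [n for n in nbrs if n not in batch["seeds"]]
        for seed, nbrs in batch["subgraphs"].items()
    }
    return batch

def fetch_features(batch):
    # Attach a dummy feature vector for every node in the sampled subgraphs.
    nodes = set(batch["seeds"]).union(*batch["subgraphs"].values())
    batch["node_features"] = {n: [float(n)] for n in nodes}
    return batch

def copy_to_device(batch):
    # Stand-in for the device copy; just record the target device.
    batch["device"] = "cuda:0"
    return batch

stages = [sample_neighbors, exclude_seed_edges, fetch_features, copy_to_device]
minibatch = reduce(lambda b, stage: stage(b), stages, {"seeds": [0, 1]})
```

Chaining stage functions this way mirrors how `datapipe.sample_neighbor(...)`-style calls compose in the real dataloader: each call wraps the previous pipe, and the minibatch flows through every stage in order.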
......@@ -107,7 +107,7 @@ class FeatureFetcher(MiniBatchTransformer):
if original_edge_ids is None:
    continue
if is_heterogeneous:
-   # Convert edge type to string for DGLMiniBatch.
+   # Convert edge type to string.
    original_edge_ids = {
        etype_tuple_to_str(key)
        if isinstance(key, tuple)
......
......@@ -9,11 +9,6 @@ from torchdata.datapipes.iter import Mapper
from . import gb_test_utils
-class MiniBatchType(Enum):
-    MiniBatch = 1
-    DGLMiniBatch = 2
def test_FeatureFetcher_invoke():
    # Prepare graph and required datapipes.
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
......