Unverified Commit 107b4347 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[example] update docstring (#6902)

parent 566910d8
......@@ -136,7 +136,11 @@ def create_dataloader(
# Move the mini-batch to the appropriate device.
# `device`:
# The device to move the mini-batch to.
# [TODO] Moving `MiniBatch` to GPU is not supported yet.
# [Rui] Usually, we move the mini-batch to target device in the datapipe.
# However, in this example, we leaves the mini-batch on CPU and move it to
# GPU after blocks are created. This is because this example is busy on
# GPU due to embedding layer. And block creation on CPU could be overlapped
# with optimization operation on GPU and it results in better performance.
device = torch.device("cpu")
datapipe = datapipe.copy_to(device)
......@@ -443,6 +447,7 @@ def evaluate(
y_true = list()
for data in tqdm(data_loader, desc="Inference"):
# Convert MiniBatch to DGL Blocks and move them to the target device.
blocks = [block.to(device) for block in data.blocks]
node_features = extract_node_features(
name, blocks[0], data, node_embed, device
......@@ -508,7 +513,8 @@ def train(
total_loss = 0
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch + 1:02d}"):
# Convert MiniBatch to DGL Blocks.
# Convert MiniBatch to DGL Blocks and move them to the target
# device.
blocks = [block.to(device) for block in data.blocks]
# Fetch the number of seed nodes in the batch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment