Unverified Commit 107b4347 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[example] update docstring (#6902)

parent 566910d8
...@@ -136,7 +136,11 @@ def create_dataloader( ...@@ -136,7 +136,11 @@ def create_dataloader(
# Move the mini-batch to the appropriate device. # Move the mini-batch to the appropriate device.
# `device`: # `device`:
# The device to move the mini-batch to. # The device to move the mini-batch to.
# [TODO] Moving `MiniBatch` to GPU is not supported yet. # [Rui] Usually, we move the mini-batch to target device in the datapipe.
# However, in this example, we leaves the mini-batch on CPU and move it to
# GPU after blocks are created. This is because this example is busy on
# GPU due to embedding layer. And block creation on CPU could be overlapped
# with optimization operation on GPU and it results in better performance.
device = torch.device("cpu") device = torch.device("cpu")
datapipe = datapipe.copy_to(device) datapipe = datapipe.copy_to(device)
...@@ -443,6 +447,7 @@ def evaluate( ...@@ -443,6 +447,7 @@ def evaluate(
y_true = list() y_true = list()
for data in tqdm(data_loader, desc="Inference"): for data in tqdm(data_loader, desc="Inference"):
# Convert MiniBatch to DGL Blocks and move them to the target device.
blocks = [block.to(device) for block in data.blocks] blocks = [block.to(device) for block in data.blocks]
node_features = extract_node_features( node_features = extract_node_features(
name, blocks[0], data, node_embed, device name, blocks[0], data, node_embed, device
...@@ -508,7 +513,8 @@ def train( ...@@ -508,7 +513,8 @@ def train(
total_loss = 0 total_loss = 0
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch + 1:02d}"): for data in tqdm(data_loader, desc=f"Training~Epoch {epoch + 1:02d}"):
# Convert MiniBatch to DGL Blocks. # Convert MiniBatch to DGL Blocks and move them to the target
# device.
blocks = [block.to(device) for block in data.blocks] blocks = [block.to(device) for block in data.blocks]
# Fetch the number of seed nodes in the batch. # Fetch the number of seed nodes in the batch.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment