Unverified Commit 8df64670 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[examples] create dataloader before train epoch (#6636)

parent 659b9289
...@@ -537,6 +537,18 @@ def run( ...@@ -537,6 +537,18 @@ def run(
print("start to run...") print("start to run...")
category = "paper" category = "paper"
data_loader = create_dataloader(
name,
g,
features,
train_set,
device,
batch_size=1024,
fanouts=[25, 10],
shuffle=True,
num_workers=num_workers,
)
# Typically, the best Validation performance is obtained after # Typically, the best Validation performance is obtained after
# the 1st or 2nd epoch. This is why the max epoch is set to 3. # the 1st or 2nd epoch. This is why the max epoch is set to 3.
for epoch in range(3): for epoch in range(3):
...@@ -545,17 +557,6 @@ def run( ...@@ -545,17 +557,6 @@ def run(
total_loss = 0 total_loss = 0
data_loader = create_dataloader(
name,
g,
features,
train_set,
device,
batch_size=1024,
fanouts=[25, 10],
shuffle=True,
num_workers=num_workers,
)
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"): for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Convert MiniBatch to DGL Blocks. # Convert MiniBatch to DGL Blocks.
blocks = [block.to(device) for block in data.blocks] blocks = [block.to(device) for block in data.blocks]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment