Unverified Commit 8df64670 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[examples] create dataloader before train epoch (#6636)

parent 659b9289
...@@ -537,14 +537,6 @@ def run( ...@@ -537,14 +537,6 @@ def run(
print("start to run...") print("start to run...")
category = "paper" category = "paper"
# Typically, the best Validation performance is obtained after
# the 1st or 2nd epoch. This is why the max epoch is set to 3.
for epoch in range(3):
num_train = len(train_set)
model.train()
total_loss = 0
data_loader = create_dataloader( data_loader = create_dataloader(
name, name,
g, g,
...@@ -556,6 +548,15 @@ def run( ...@@ -556,6 +548,15 @@ def run(
shuffle=True, shuffle=True,
num_workers=num_workers, num_workers=num_workers,
) )
# Typically, the best Validation performance is obtained after
# the 1st or 2nd epoch. This is why the max epoch is set to 3.
for epoch in range(3):
num_train = len(train_set)
model.train()
total_loss = 0
for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"): for data in tqdm(data_loader, desc=f"Training~Epoch {epoch:02d}"):
# Convert MiniBatch to DGL Blocks. # Convert MiniBatch to DGL Blocks.
blocks = [block.to(device) for block in data.blocks] blocks = [block.to(device) for block in data.blocks]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment