Unverified Commit 5a7e156f authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Example] Fix label dtype error in multigpu node classification. (#6376)

parent d912947b
......@@ -204,7 +204,7 @@ def train(
train_dataloader
):
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y = blocks[-1].dstdata["label"].to(torch.int64)
y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y)
opt.zero_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment