"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3a7e481611bc299416aaeed4207086d9ddca5852"
Unverified Commit 5a7e156f authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Example] Fix label dtype error in multigpu node classification. (#6376)

parent d912947b
...@@ -204,7 +204,7 @@ def train( ...@@ -204,7 +204,7 @@ def train(
train_dataloader train_dataloader
): ):
x = blocks[0].srcdata["feat"] x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"] y = blocks[-1].dstdata["label"].to(torch.int64)
y_hat = model(blocks, x) y_hat = model(blocks, x)
loss = F.cross_entropy(y_hat, y) loss = F.cross_entropy(y_hat, y)
opt.zero_grad() opt.zero_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment