"src/api/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "5e5ed37aaaf726e9ed0ddd9987a9f9b148bb08e6"
Commit e56b9b9e authored by panning's avatar panning
Browse files

修复分类网络训练时tensor不连续报错

parent ae5cea86
...@@ -501,7 +501,7 @@ def accuracy(output, target, topk=(1,)): ...@@ -501,7 +501,7 @@ def accuracy(output, target, topk=(1,)):
res = [] res = []
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size)) res.append(correct_k.mul_(100.0 / batch_size))
return res return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment