"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "ec561daa26c58882c075131eb2da5a4fac792925"
Commit 92105bf1 authored by rusty1s

bugfixes

parent cb0e5f63
import torch
import cluster_cuda

dtype = torch.float
device = torch.device('cuda')


def grid_cluster(pos, size, start=None, end=None):
    # Default the grid boundaries to the per-dimension min/max of `pos`.
    start = pos.t().min(dim=1)[0] if start is None else start
    end = pos.t().max(dim=1)[0] if end is None else end
    return cluster_cuda.grid(pos, size, start, end)


pos = torch.tensor(
    [[1, 1], [3, 3], [5, 5], [7, 7]], dtype=dtype, device=device)
# `size` holds one voxel extent per dimension, so it needs exactly
# pos.size(1) entries (two here).
size = torch.tensor([2, 2], dtype=dtype, device=device)

# print('pos', pos.tolist())
# print('size', size.tolist())

cluster = grid_cluster(pos, size)
print('result', cluster.tolist(), cluster.type())
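For reference, a minimal CPU sketch of the voxel-grid clustering this script exercises. It assumes `cluster_cuda.grid` maps each point to the flat index of the voxel containing it; the exact flattening order of the CUDA kernel may differ, so treat this only as an illustration of the semantics, not as the kernel's implementation.

import torch


def grid_cluster_cpu(pos, size, start=None, end=None):
    # Assumed reference semantics: bucket each point into a regular grid
    # of voxels with per-dimension extents `size`, then flatten the
    # D-dimensional voxel coordinate into a single cluster index.
    start = pos.min(dim=0)[0] if start is None else start
    end = pos.max(dim=0)[0] if end is None else end
    num_voxels = ((end - start) / size).long() + 1  # voxels per dimension
    coords = ((pos - start) / size).long()          # voxel coordinate of each point
    cluster = torch.zeros(pos.size(0), dtype=torch.long)
    stride = 1
    for d in range(pos.size(1)):
        cluster += coords[:, d] * stride
        stride *= num_voxels[d].item()
    return cluster


# With the `pos` and `size` from the script above, the four points land in
# voxels (0,0), (1,1), (2,2), (3,3) of a 4x4 grid.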
@@ -33,7 +33,7 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
       cluster.data<int64_t>(),
       at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
       size.toType(pos.type()).data<scalar_t>(),
-      start..toType(pos.type()).data<scalar_t>(),
+      start.toType(pos.type()).data<scalar_t>(),
       end.toType(pos.type()).data<scalar_t>(), num_nodes);
   });
...