Unverified Commit 01fa24ee authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

adjust tensor creation (#127)

parent 516d988d
...@@ -24,7 +24,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, ...@@ -24,7 +24,7 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
auto out_ptr = deg.toType(torch::kFloat) * ratio; auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
auto out = torch::empty(out_ptr[-1].data_ptr<int64_t>()[0], ptr.options()); auto out = torch::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>(); auto ptr_data = ptr.data_ptr<int64_t>();
auto out_ptr_data = out_ptr.data_ptr<int64_t>(); auto out_ptr_data = out_ptr.data_ptr<int64_t>();
......
...@@ -35,7 +35,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size, ...@@ -35,7 +35,7 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
auto num_voxels = (end - start).true_divide(size).toType(torch::kLong) + 1; auto num_voxels = (end - start).true_divide(size).toType(torch::kLong) + 1;
num_voxels = num_voxels.cumprod(0); num_voxels = num_voxels.cumprod(0);
num_voxels = num_voxels =
torch::cat({torch::ones(1, num_voxels.options()), num_voxels}, 0); torch::cat({torch::ones({1}, num_voxels.options()), num_voxels}, 0);
num_voxels = num_voxels.narrow(0, 0, size.size(0)); num_voxels = num_voxels.narrow(0, 0, size.size(0));
auto out = pos.true_divide(size.view({1, -1})).toType(torch::kLong); auto out = pos.true_divide(size.view({1, -1})).toType(torch::kLong);
......
...@@ -80,14 +80,14 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, ...@@ -80,14 +80,14 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio; auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
out_ptr = torch::cat({torch::zeros(1, ptr.options()), out_ptr}, 0); out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);
torch::Tensor start; torch::Tensor start;
if (random_start) { if (random_start) {
start = torch::rand(batch_size, src.options()); start = torch::rand(batch_size, src.options());
start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong); start = (start * deg.toType(ratio.scalar_type())).toType(torch::kLong);
} else { } else {
start = torch::zeros(batch_size, ptr.options()); start = torch::zeros({batch_size}, ptr.options());
} }
auto dist = torch::full(src.size(0), 5e4, src.options()); auto dist = torch::full(src.size(0), 5e4, src.options());
...@@ -95,7 +95,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, ...@@ -95,7 +95,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto out_size = (int64_t *)malloc(sizeof(int64_t)); auto out_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t), cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
cudaMemcpyDeviceToHost); cudaMemcpyDeviceToHost);
auto out = torch::empty(out_size[0], out_ptr.options()); auto out = torch::empty({out_size[0]}, out_ptr.options());
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
auto scalar_type = src.scalar_type(); auto scalar_type = src.scalar_type();
......
...@@ -58,7 +58,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size, ...@@ -58,7 +58,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
auto start = optional_start.value(); auto start = optional_start.value();
auto end = optional_end.value(); auto end = optional_end.value();
auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong)); auto out = torch::empty({pos.size(0)}, pos.options().dtype(torch::kLong));
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] { AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, pos.scalar_type(), "_", [&] {
......
...@@ -115,7 +115,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, ...@@ -115,7 +115,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
cudaSetDevice(x.get_device()); cudaSetDevice(x.get_device());
auto row = torch::empty(y.size(0) * k, ptr_y.value().options()); auto row = torch::empty({y.size(0) * k}, ptr_y.value().options());
auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options()); auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS); dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment