"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "780fd1435844b5ce672bfc40d0a83576e9e62824"
Commit 6bf96692 authored by rusty1s's avatar rusty1s
Browse files

grid fixes

parent 06df4d9b
...@@ -16,7 +16,7 @@ __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size, ...@@ -16,7 +16,7 @@ __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
if (thread_idx < numel) { if (thread_idx < numel) {
int64_t c = 0, k = 1; int64_t c = 0, k = 1;
for (int64_t d = 0; d < D; d++) { for (int64_t d = 0; d < D; d++) {
scalar_t p = pos.data[thread_idx * D + d] - start[d]; scalar_t p = pos[thread_idx * D + d] - start[d];
c += (int64_t)(p / size[d]) * k; c += (int64_t)(p / size[d]) * k;
k *= (int64_t)((end[d] - start[d]) / size[d]) + 1; k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
} }
...@@ -24,9 +24,9 @@ __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size, ...@@ -24,9 +24,9 @@ __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
} }
} }
torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size, torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
torch::optional<torch::Tensor> optional_start, torch::optional<torch::Tensor> optional_start,
torch::optional<torch::Tensor> optional_end) { torch::optional<torch::Tensor> optional_end) {
CHECK_CUDA(pos); CHECK_CUDA(pos);
CHECK_CUDA(size); CHECK_CUDA(size);
cudaSetDevice(pos.get_device()); cudaSetDevice(pos.get_device());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment