Commit 2693efc9 authored by rusty1s's avatar rusty1s
Browse files

fix segment coo indexing

parent 4c4a2e6c
...@@ -85,7 +85,7 @@ __global__ void segment_coo_broadcast_kernel( ...@@ -85,7 +85,7 @@ __global__ void segment_coo_broadcast_kernel(
int D = index_info.sizes[index_info.dims - 1]; int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D; int E_1 = E / D;
int E_2 = D + TB - (D % TB); int E_2 = (D - 1) + TB - ((D - 1) % TB);
int row_idx = blockIdx.x * blockDim.y + threadIdx.y; int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x; int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
...@@ -215,6 +215,12 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -215,6 +215,12 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
auto N = out.size(dim); auto N = out.size(dim);
auto avg_len = (float)E_2 / (float)N; auto avg_len = (float)E_2 / (float)N;
std::cout << "E " << E << std::endl;
std::cout << "E2 " << E_2 << std::endl;
std::cout << "E1 " << E_1 << std::endl;
std::cout << "K " << K << std::endl;
std::cout << "N " << N << std::endl;
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index); auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment