Commit ab4d22e0 authored by rusty1s's avatar rusty1s
Browse files

cuda fix

parent 26c0b866
...@@ -9,18 +9,18 @@ env: ...@@ -9,18 +9,18 @@ env:
global: global:
- CUDA_HOME=/usr/local/cuda - CUDA_HOME=/usr/local/cuda
jobs: jobs:
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
# - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
jobs: jobs:
exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs. exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
......
...@@ -98,7 +98,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size, ...@@ -98,7 +98,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
CHECK_CUDA(pseudo); CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size); CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline); CHECK_CUDA(is_open_spline);
// cudaSetDevice(pseudo.get_device()); cudaSetDevice(pseudo.get_device());
CHECK_INPUT(kernel_size.dim() == 1); CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel()); CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
...@@ -116,16 +116,16 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size, ...@@ -116,16 +116,16 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>(); auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] { AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>(); auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] { AT_DISPATCH_DEGREE_TYPES(degree, [&] {
// spline_basis_fw_kernel<scalar_t, DEGREE> spline_basis_fw_kernel<scalar_t, DEGREE>
// <<<BLOCKS(basis.numel()), THREADS, 0, stream>>>( <<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
// pseudo_data, kernel_size_data, is_open_spline_data, basis_data, pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
// weight_index_data, E, D, S, basis.numel()); weight_index_data, E, D, S, basis.numel());
}); });
}); });
...@@ -180,7 +180,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis, ...@@ -180,7 +180,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
CHECK_CUDA(pseudo); CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size); CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline); CHECK_CUDA(is_open_spline);
// cudaSetDevice(grad_basis.get_device()); cudaSetDevice(grad_basis.get_device());
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0)); CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1); CHECK_INPUT(kernel_size.dim() == 1);
...@@ -197,18 +197,18 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis, ...@@ -197,18 +197,18 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
auto kernel_size_data = kernel_size.data_ptr<int64_t>(); auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>(); auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] { AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>(); auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>(); auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>(); auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] { AT_DISPATCH_DEGREE_TYPES(degree, [&] {
// spline_basis_bw_kernel<scalar_t, DEGREE> spline_basis_bw_kernel<scalar_t, DEGREE>
// <<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>( <<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
// grad_basis_data, pseudo_data, kernel_size_data, grad_basis_data, pseudo_data, kernel_size_data,
// is_open_spline_data, grad_pseudo_data, E, D, S, is_open_spline_data, grad_pseudo_data, E, D, S,
// grad_pseudo.numel()); grad_pseudo.numel());
}); });
}); });
......
...@@ -41,7 +41,7 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight, ...@@ -41,7 +41,7 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
// cudaSetDevice(x.get_device()); cudaSetDevice(x.get_device());
CHECK_INPUT(x.size(1) == weight.size(1)); CHECK_INPUT(x.size(1) == weight.size(1));
...@@ -54,17 +54,17 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight, ...@@ -54,17 +54,17 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
// spline_weighting_fw_kernel<scalar_t> spline_weighting_fw_kernel<scalar_t>
// <<<BLOCKS(out.numel()), THREADS, 0, stream>>>( <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
// x_data, weight_data, basis_data, weight_index_data, out_data, E, x_data, weight_data, basis_data, weight_index_data, out_data, E,
// M_in, M_out, S, out.numel()); M_in, M_out, S, out.numel());
}); });
return out; return out;
...@@ -106,7 +106,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -106,7 +106,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
// cudaSetDevice(grad_out.get_device()); cudaSetDevice(grad_out.get_device());
CHECK_INPUT(grad_out.size(1) == weight.size(2)); CHECK_INPUT(grad_out.size(1) == weight.size(2));
...@@ -120,17 +120,17 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -120,17 +120,17 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] { AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>(); auto grad_x_data = grad_x.data_ptr<scalar_t>();
// spline_weighting_bw_x_kernel<scalar_t> spline_weighting_bw_x_kernel<scalar_t>
// <<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>( <<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
// grad_out_data, weight_data, basis_data, weight_index_data, grad_out_data, weight_data, basis_data, weight_index_data,
// grad_x_data, E, M_in, M_out, S, grad_x.numel()); grad_x_data, E, M_in, M_out, S, grad_x.numel());
}); });
return grad_x; return grad_x;
...@@ -169,7 +169,7 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out, ...@@ -169,7 +169,7 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
CHECK_CUDA(x); CHECK_CUDA(x);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
// cudaSetDevice(grad_out.get_device()); cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0); auto E = grad_out.size(0);
auto M_in = x.size(1); auto M_in = x.size(1);
...@@ -180,17 +180,17 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out, ...@@ -180,17 +180,17 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>(); auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
// spline_weighting_bw_weight_kernel<scalar_t> spline_weighting_bw_weight_kernel<scalar_t>
// <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>( <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
// grad_out_data, x_data, basis_data, weight_index_data, grad_out_data, x_data, basis_data, weight_index_data,
// grad_weight_data, E, M_in, M_out, S, grad_out.numel()); grad_weight_data, E, M_in, M_out, S, grad_out.numel());
}); });
return grad_weight; return grad_weight;
...@@ -230,7 +230,7 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out, ...@@ -230,7 +230,7 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
CHECK_CUDA(x); CHECK_CUDA(x);
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
// cudaSetDevice(grad_out.get_device()); cudaSetDevice(grad_out.get_device());
CHECK_INPUT(x.size(1) == weight.size(1)); CHECK_INPUT(x.size(1) == weight.size(1));
CHECK_INPUT(grad_out.size(1) == weight.size(2)); CHECK_INPUT(grad_out.size(1) == weight.size(2));
...@@ -244,17 +244,17 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out, ...@@ -244,17 +244,17 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
// auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto grad_basis_data = grad_basis.data_ptr<scalar_t>(); auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
// spline_weighting_bw_basis_kernel<scalar_t> spline_weighting_bw_basis_kernel<scalar_t>
// <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>( <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
// grad_out_data, x_data, weight_data, weight_index_data, grad_out_data, x_data, weight_data, weight_index_data,
// grad_basis_data, E, M_in, M_out, S, grad_out.numel()); grad_basis_data, E, M_in, M_out, S, grad_out.numel());
}); });
return grad_basis; return grad_basis;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment