Commit b2cddcbd authored by rusty1s's avatar rusty1s
Browse files

test cuda

parent 8f734f72
...@@ -2,25 +2,25 @@ language: shell ...@@ -2,25 +2,25 @@ language: shell
os: os:
- linux - linux
- osx # - osx
- windows # - windows
env: env:
global: global:
- CUDA_HOME=/usr/local/cuda - CUDA_HOME=/usr/local/cuda
jobs: jobs:
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92 - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
- TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101 # - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
jobs: jobs:
exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs. exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
......
...@@ -98,7 +98,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size, ...@@ -98,7 +98,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
CHECK_CUDA(pseudo); CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size); CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline); CHECK_CUDA(is_open_spline);
cudaSetDevice(pseudo.get_device()); // cudaSetDevice(pseudo.get_device());
CHECK_INPUT(kernel_size.dim() == 1); CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel()); CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
...@@ -116,16 +116,16 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size, ...@@ -116,16 +116,16 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>(); auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] { AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>(); auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] { AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_fw_kernel<scalar_t, DEGREE> // spline_basis_fw_kernel<scalar_t, DEGREE>
<<<BLOCKS(basis.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
pseudo_data, kernel_size_data, is_open_spline_data, basis_data, // pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
weight_index_data, E, D, S, basis.numel()); // weight_index_data, E, D, S, basis.numel());
}); });
}); });
...@@ -180,7 +180,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis, ...@@ -180,7 +180,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
CHECK_CUDA(pseudo); CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size); CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline); CHECK_CUDA(is_open_spline);
cudaSetDevice(grad_basis.get_device()); // cudaSetDevice(grad_basis.get_device());
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0)); CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1); CHECK_INPUT(kernel_size.dim() == 1);
...@@ -197,18 +197,18 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis, ...@@ -197,18 +197,18 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
auto kernel_size_data = kernel_size.data_ptr<int64_t>(); auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>(); auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] { AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>(); auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>(); auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>(); auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] { AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_bw_kernel<scalar_t, DEGREE> // spline_basis_bw_kernel<scalar_t, DEGREE>
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
grad_basis_data, pseudo_data, kernel_size_data, // grad_basis_data, pseudo_data, kernel_size_data,
is_open_spline_data, grad_pseudo_data, E, D, S, // is_open_spline_data, grad_pseudo_data, E, D, S,
grad_pseudo.numel()); // grad_pseudo.numel());
}); });
}); });
......
...@@ -41,7 +41,7 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight, ...@@ -41,7 +41,7 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
cudaSetDevice(x.get_device()); // cudaSetDevice(x.get_device());
CHECK_INPUT(x.size(1) == weight.size(1)); CHECK_INPUT(x.size(1) == weight.size(1));
...@@ -54,17 +54,17 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight, ...@@ -54,17 +54,17 @@ torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
spline_weighting_fw_kernel<scalar_t> // spline_weighting_fw_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
x_data, weight_data, basis_data, weight_index_data, out_data, E, // x_data, weight_data, basis_data, weight_index_data, out_data, E,
M_in, M_out, S, out.numel()); // M_in, M_out, S, out.numel());
}); });
return out; return out;
...@@ -106,7 +106,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -106,7 +106,7 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device()); // cudaSetDevice(grad_out.get_device());
CHECK_INPUT(grad_out.size(1) == weight.size(2)); CHECK_INPUT(grad_out.size(1) == weight.size(2));
...@@ -120,17 +120,17 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out, ...@@ -120,17 +120,17 @@ torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] { AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>(); auto grad_x_data = grad_x.data_ptr<scalar_t>();
spline_weighting_bw_x_kernel<scalar_t> // spline_weighting_bw_x_kernel<scalar_t>
<<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
grad_out_data, weight_data, basis_data, weight_index_data, // grad_out_data, weight_data, basis_data, weight_index_data,
grad_x_data, E, M_in, M_out, S, grad_x.numel()); // grad_x_data, E, M_in, M_out, S, grad_x.numel());
}); });
return grad_x; return grad_x;
...@@ -169,7 +169,7 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out, ...@@ -169,7 +169,7 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
CHECK_CUDA(x); CHECK_CUDA(x);
CHECK_CUDA(basis); CHECK_CUDA(basis);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device()); // cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0); auto E = grad_out.size(0);
auto M_in = x.size(1); auto M_in = x.size(1);
...@@ -180,17 +180,17 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out, ...@@ -180,17 +180,17 @@ torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>(); auto basis_data = basis.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>(); auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
spline_weighting_bw_weight_kernel<scalar_t> // spline_weighting_bw_weight_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, basis_data, weight_index_data, // grad_out_data, x_data, basis_data, weight_index_data,
grad_weight_data, E, M_in, M_out, S, grad_out.numel()); // grad_weight_data, E, M_in, M_out, S, grad_out.numel());
}); });
return grad_weight; return grad_weight;
...@@ -230,7 +230,7 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out, ...@@ -230,7 +230,7 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
CHECK_CUDA(x); CHECK_CUDA(x);
CHECK_CUDA(weight); CHECK_CUDA(weight);
CHECK_CUDA(weight_index); CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device()); // cudaSetDevice(grad_out.get_device());
CHECK_INPUT(x.size(1) == weight.size(1)); CHECK_INPUT(x.size(1) == weight.size(1));
CHECK_INPUT(grad_out.size(1) == weight.size(2)); CHECK_INPUT(grad_out.size(1) == weight.size(2));
...@@ -244,17 +244,17 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out, ...@@ -244,17 +244,17 @@ torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
auto weight_index_data = weight_index.data_ptr<int64_t>(); auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); // auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] { AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>(); auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>(); auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>(); auto weight_data = weight.data_ptr<scalar_t>();
auto grad_basis_data = grad_basis.data_ptr<scalar_t>(); auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
spline_weighting_bw_basis_kernel<scalar_t> // spline_weighting_bw_basis_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>( // <<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, weight_data, weight_index_data, // grad_out_data, x_data, weight_data, weight_index_data,
grad_basis_data, E, M_in, M_out, S, grad_out.numel()); // grad_basis_data, E, M_in, M_out, S, grad_out.numel());
}); });
return grad_basis; return grad_basis;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment