Commit 2515ce6d authored by rusty1s
Browse files

set device

parent f9b00093
......@@ -23,6 +23,8 @@ __global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
}
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
cudaSetDevice(ind.get_device());
auto out = torch::empty(M + 1, ind.options());
auto ind_data = ind.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<int64_t>();
......@@ -46,6 +48,8 @@ __global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
}
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
cudaSetDevice(ptr.get_device());
auto out = torch::empty(E, ptr.options());
auto ptr_data = ptr.DATA_PTR<int64_t>();
auto out_data = out.DATA_PTR<int64_t>();
......
......@@ -40,6 +40,8 @@ __global__ void non_diag_mask_kernel(const int64_t *row_data,
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
cudaSetDevice(row.get_device());
int64_t E = row.size(0);
int64_t num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
......
......@@ -160,6 +160,7 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> value_opt, torch::Tensor mat,
std::string reduce) {
cudaSetDevice(rowptr.get_device());
AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
AT_ASSERTM(col.dim() == 1, "Input mismatch");
if (value_opt.has_value())
......@@ -252,6 +253,8 @@ torch::Tensor spmm_val_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
cudaSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
......
......@@ -48,6 +48,9 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> valueA, torch::Tensor rowptrB,
torch::Tensor colB, torch::optional<torch::Tensor> valueB,
int64_t M, int64_t N, int64_t K) {
cudaSetDevice(rowptrA.get_device());
cusparseMatDescr_t descr = 0;
cusparseCreateMatDescr(&descr);
auto handle = at::cuda::getCurrentCUDASparseHandle();
......
......@@ -19,6 +19,7 @@ __global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
cudaSetDevice(src.get_device());
at::Tensor perm;
std::tie(src, perm) = src.sort();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment