"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "340244a60fb4fdd758117ed8eff8e3fb9aab0b3a"
Commit c877cda6 authored by Thorsten Kurth
Browse files

making disco helpers GPU ready

parent 00064117
...@@ -104,6 +104,16 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke ...@@ -104,6 +104,16 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
CHECK_INPUT_TENSOR(col_idx); CHECK_INPUT_TENSOR(col_idx);
CHECK_INPUT_TENSOR(val); CHECK_INPUT_TENSOR(val);
// get the input device and make sure all tensors are on the same device
auto device = ker_idx.device();
TORCH_INTERNAL_ASSERT(device.type() == row_idx.device().type() && (device.type() == col_idx.device().type()) && (device.type() == val.device().type()));
// move to cpu
ker_idx = ker_idx.to(torch::kCPU);
row_idx = row_idx.to(torch::kCPU);
col_idx = col_idx.to(torch::kCPU);
val = val.to(torch::kCPU);
int64_t nnz = val.size(0); int64_t nnz = val.size(0);
int64_t *ker_h = ker_idx.data_ptr<int64_t>(); int64_t *ker_h = ker_idx.data_ptr<int64_t>();
int64_t *row_h = row_idx.data_ptr<int64_t>(); int64_t *row_h = row_idx.data_ptr<int64_t>();
...@@ -117,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke ...@@ -117,13 +127,19 @@ torch::Tensor preprocess_psi(const int64_t K, const int64_t Ho, torch::Tensor ke
})); }));
// create output tensor // create output tensor
auto options = torch::TensorOptions().dtype(row_idx.dtype()); auto roff_idx = torch::empty({nrows + 1}, row_idx.options());
auto roff_idx = torch::empty({nrows + 1}, options);
int64_t *roff_out_h = roff_idx.data_ptr<int64_t>(); int64_t *roff_out_h = roff_idx.data_ptr<int64_t>();
for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; } for (int64_t i = 0; i < (nrows + 1); i++) { roff_out_h[i] = roff_h[i]; }
delete[] roff_h; delete[] roff_h;
// move to original device
ker_idx = ker_idx.to(device);
row_idx = row_idx.to(device);
col_idx = col_idx.to(device);
val = val.to(device);
roff_idx = roff_idx.to(device);
return roff_idx; return roff_idx;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment