Unverified Commit eafcfe0a authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Fix tensor creation (#222)

* adjust tensor creation

* update
parent 7a6c9ab4
...@@ -13,7 +13,7 @@ torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M, ...@@ -13,7 +13,7 @@ torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
auto row_data = row.data_ptr<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool)); auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>(); auto mask_data = mask.data_ptr<bool>();
int64_t r, c; int64_t r, c;
......
...@@ -44,7 +44,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -44,7 +44,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
vwgt = optional_node_weight.value().data_ptr<int64_t>(); vwgt = optional_node_weight.value().data_ptr<int64_t>();
int64_t objval = -1; int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty({nvtxs}, rowptr.options());
auto part_data = part.data_ptr<int64_t>(); auto part_data = part.data_ptr<int64_t>();
if (recursive) { if (recursive) {
...@@ -99,7 +99,7 @@ mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -99,7 +99,7 @@ mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
mtmetis_pid_type nparts = num_parts; mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1; mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options()); auto part = torch::empty({nvtxs}, rowptr.options());
mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>(); mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();
double *opts = mtmetis_init_options(); double *opts = mtmetis_init_options();
......
...@@ -64,7 +64,7 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -64,7 +64,7 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
std::unordered_map<int64_t, int64_t> n_id_map; std::unordered_map<int64_t, int64_t> n_id_map;
std::unordered_map<int64_t, int64_t>::iterator it; std::unordered_map<int64_t, int64_t>::iterator it;
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options()); auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>(); auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0; out_rowptr_data[0] = 0;
...@@ -76,12 +76,12 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col, ...@@ -76,12 +76,12 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
out_rowptr_data[i + 1] = offset; out_rowptr_data[i + 1] = offset;
} }
auto out_col = torch::empty(offset, col.options()); auto out_col = torch::empty({offset}, col.options());
auto out_col_data = out_col.data_ptr<int64_t>(); auto out_col_data = out_col.data_ptr<int64_t>();
torch::optional<torch::Tensor> out_value = torch::nullopt; torch::optional<torch::Tensor> out_value = torch::nullopt;
if (optional_value.has_value()) { if (optional_value.has_value()) {
out_value = torch::empty(offset, optional_value.value().options()); out_value = torch::empty({offset}, optional_value.value().options());
AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] { AT_DISPATCH_ALL_TYPES(optional_value.value().scalar_type(), "relabel", [&] {
auto value_data = optional_value.value().data_ptr<scalar_t>(); auto value_data = optional_value.value().data_ptr<scalar_t>();
......
...@@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -19,7 +19,7 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto idx_data = idx.data_ptr<int64_t>(); auto idx_data = idx.data_ptr<int64_t>();
auto out_rowptr = torch::empty(idx.numel() + 1, rowptr.options()); auto out_rowptr = torch::empty({idx.numel() + 1}, rowptr.options());
auto out_rowptr_data = out_rowptr.data_ptr<int64_t>(); auto out_rowptr_data = out_rowptr.data_ptr<int64_t>();
out_rowptr_data[0] = 0; out_rowptr_data[0] = 0;
...@@ -117,9 +117,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ...@@ -117,9 +117,9 @@ sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone(); auto out_n_id = torch::from_blob(n_ids.data(), {N}, col.options()).clone();
int64_t E = out_rowptr_data[idx.numel()]; int64_t E = out_rowptr_data[idx.numel()];
auto out_col = torch::empty(E, col.options()); auto out_col = torch::empty({E}, col.options());
auto out_col_data = out_col.data_ptr<int64_t>(); auto out_col_data = out_col.data_ptr<int64_t>();
auto out_e_id = torch::empty(E, col.options()); auto out_e_id = torch::empty({E}, col.options());
auto out_e_id_data = out_e_id.data_ptr<int64_t>(); auto out_e_id_data = out_e_id.data_ptr<int64_t>();
i = 0; i = 0;
......
...@@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr, ...@@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
auto K = mat.size(-1); auto K = mat.size(-1);
auto B = mat.numel() / (N * K); auto B = mat.numel() / (N * K);
auto out = torch::zeros(row.numel(), grad.options()); auto out = torch::zeros({row.numel()}, grad.options());
auto row_data = row.data_ptr<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
......
...@@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
if (!optional_valueA.has_value() && optional_valueB.has_value()) if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA = optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options()); torch::ones({colA.numel()}, optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value()) if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB = optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options()); torch::ones({colB.numel()}, optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float; auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value()) if (optional_valueA.has_value())
......
...@@ -61,7 +61,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false, ...@@ -61,7 +61,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
return torch::multinomial(weight.value(), num_samples, replace); return torch::multinomial(weight.value(), num_samples, replace);
if (replace) { if (replace) {
const auto out = torch::empty(num_samples, at::kLong); const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>(); auto *out_data = out.data_ptr<int64_t>();
for (int64_t i = 0; i < num_samples; i++) { for (int64_t i = 0; i < num_samples; i++) {
out_data[i] = uniform_randint(population); out_data[i] = uniform_randint(population);
...@@ -72,7 +72,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false, ...@@ -72,7 +72,7 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
// Sample without replacement via Robert Floyd algorithm: // Sample without replacement via Robert Floyd algorithm:
// https://www.nowherenearithaca.com/2013/05/ // https://www.nowherenearithaca.com/2013/05/
// robert-floyds-tiny-and-beautiful.html // robert-floyds-tiny-and-beautiful.html
const auto out = torch::empty(num_samples, at::kLong); const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>(); auto *out_data = out.data_ptr<int64_t>();
std::unordered_set<int64_t> samples; std::unordered_set<int64_t> samples;
for (int64_t i = population - num_samples; i < population; i++) { for (int64_t i = population - num_samples; i < population; i++) {
......
...@@ -27,7 +27,7 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) { ...@@ -27,7 +27,7 @@ torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
CHECK_CUDA(ind); CHECK_CUDA(ind);
cudaSetDevice(ind.get_device()); cudaSetDevice(ind.get_device());
auto out = torch::empty(M + 1, ind.options()); auto out = torch::empty({M + 1}, ind.options());
if (ind.numel() == 0) if (ind.numel() == 0)
return out.zero_(); return out.zero_();
...@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) { ...@@ -57,7 +57,7 @@ torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
CHECK_CUDA(ptr); CHECK_CUDA(ptr);
cudaSetDevice(ptr.get_device()); cudaSetDevice(ptr.get_device());
auto out = torch::empty(E, ptr.options()); auto out = torch::empty({E}, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>(); auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>(); auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
......
...@@ -51,7 +51,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col, ...@@ -51,7 +51,7 @@ torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
auto row_data = row.data_ptr<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>(); auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool)); auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>(); auto mask_data = mask.data_ptr<bool>();
if (E == 0) if (E == 0)
......
...@@ -213,7 +213,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr, ...@@ -213,7 +213,7 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
auto B = mat.numel() / (N * K); auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS); auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros(row.numel(), grad.options()); auto out = torch::zeros({row.numel()}, grad.options());
auto row_data = row.data_ptr<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
......
...@@ -59,11 +59,11 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -59,11 +59,11 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
if (!optional_valueA.has_value() && optional_valueB.has_value()) if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA = optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options()); torch::ones({colA.numel()}, optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value()) if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB = optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options()); torch::ones({colB.numel()}, optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float; auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value()) if (optional_valueA.has_value())
...@@ -108,7 +108,7 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -108,7 +108,7 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
cudaMalloc(&buffer, bufferSize); cudaMalloc(&buffer, bufferSize);
// Step 3: Compute CSR row pointer. // Step 3: Compute CSR row pointer.
rowptrC = torch::empty(M + 1, rowptrA.options()); rowptrC = torch::empty({M + 1}, rowptrA.options());
auto rowptrC_data = rowptrC.data_ptr<int>(); auto rowptrC_data = rowptrC.data_ptr<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data, cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data, colA_data, descr, colB.numel(), rowptrB_data,
...@@ -116,11 +116,11 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -116,11 +116,11 @@ spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
nnzTotalDevHostPtr, info, buffer); nnzTotalDevHostPtr, info, buffer);
// Step 4: Compute CSR entries. // Step 4: Compute CSR entries.
colC = torch::empty(nnzC, rowptrC.options()); colC = torch::empty({nnzC}, rowptrC.options());
auto colC_data = colC.data_ptr<int>(); auto colC_data = colC.data_ptr<int>();
if (optional_valueA.has_value()) if (optional_valueA.has_value())
optional_valueC = torch::empty(nnzC, optional_valueA.value().options()); optional_valueC = torch::empty({nnzC}, optional_valueA.value().options());
scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL; scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) { if (optional_valueA.has_value()) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment