Commit 2eba313c authored by rusty1s's avatar rusty1s
Browse files

fixed spspmm for cpu

parent 57852a66
...@@ -48,80 +48,56 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -48,80 +48,56 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
auto rowptrB_data = rowptrB.data_ptr<int64_t>(); auto rowptrB_data = rowptrB.data_ptr<int64_t>();
auto colB_data = colB.data_ptr<int64_t>(); auto colB_data = colB.data_ptr<int64_t>();
// Pass 1: Compute CSR row pointer.
auto rowptrC = torch::empty_like(rowptrA); auto rowptrC = torch::empty_like(rowptrA);
auto rowptrC_data = rowptrC.data_ptr<int64_t>(); auto rowptrC_data = rowptrC.data_ptr<int64_t>();
rowptrC_data[0] = 0; rowptrC_data[0] = 0;
std::vector<int64_t> mask(K, -1); torch::Tensor colC;
int64_t nnz = 0, row_nnz, rowA_start, rowA_end, rowB_start, rowB_end, cA, cB;
for (auto n = 0; n < rowptrA.numel() - 1; n++) {
row_nnz = 0;
for (auto eA = rowptrA_data[n]; eA < rowptrA_data[n + 1]; eA++) {
cA = colA_data[eA];
for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
cB = colB_data[eB];
if (mask[cB] != n) {
mask[cB] = n;
row_nnz++;
}
}
}
nnz += row_nnz;
rowptrC_data[n + 1] = nnz;
}
// Pass 2: Compute CSR entries.
auto colC = torch::empty(nnz, rowptrC.options());
auto colC_data = colC.data_ptr<int64_t>();
torch::optional<torch::Tensor> optional_valueC = torch::nullopt; torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
if (optional_valueA.has_value())
optional_valueC = torch::empty(nnz, optional_valueA.value().options());
AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] { AT_DISPATCH_ALL_TYPES(scalar_type, "spspmm", [&] {
AT_DISPATCH_HAS_VALUE(optional_valueC, [&] { AT_DISPATCH_HAS_VALUE(optional_valueA, [&] {
scalar_t *valA_data = nullptr, *valB_data = nullptr, *valC_data = nullptr; scalar_t *valA_data = nullptr, *valB_data = nullptr;
if (HAS_VALUE) { if (HAS_VALUE) {
valA_data = optional_valueA.value().data_ptr<scalar_t>(); valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>(); valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
} }
scalar_t valA;
rowA_start = 0, nnz = 0; int64_t nnz = 0, cA, cB;
std::vector<scalar_t> vals(K, 0); std::vector<scalar_t> tmp_vals(K, 0);
for (auto n = 1; n < rowptrA.numel(); n++) { std::vector<int64_t> cols;
rowA_end = rowptrA_data[n]; std::vector<scalar_t> vals;
for (auto eA = rowA_start; eA < rowA_end; eA++) { for (auto rA = 0; rA < rowptrA.numel() - 1; rA++) {
for (auto eA = rowptrA_data[rA]; eA < rowptrA_data[rA + 1]; eA++) {
cA = colA_data[eA]; cA = colA_data[eA];
if (HAS_VALUE) for (auto eB = rowptrB_data[cA]; eB < rowptrB_data[cA + 1]; eB++) {
valA = valA_data[eA];
rowB_start = rowptrB_data[cA], rowB_end = rowptrB_data[cA + 1];
for (auto eB = rowB_start; eB < rowB_end; eB++) {
cB = colB_data[eB]; cB = colB_data[eB];
if (HAS_VALUE) if (HAS_VALUE)
vals[cB] += valA * valB_data[eB]; tmp_vals[cB] += valA_data[eA] * valB_data[eB];
else else
vals[cB] += 1; tmp_vals[cB]++;
} }
} }
for (auto k = 0; k < K; k++) { for (auto k = 0; k < K; k++) {
if (vals[k] != 0) { if (tmp_vals[k] != 0) {
colC_data[nnz] = k; cols.push_back(k);
if (HAS_VALUE) if (HAS_VALUE)
valC_data[nnz] = vals[k]; vals.push_back(tmp_vals[k]);
nnz++; nnz++;
} }
vals[k] = (scalar_t)0; tmp_vals[k] = (scalar_t)0;
}
rowptrC_data[rA + 1] = nnz;
} }
rowA_start = rowA_end; colC = torch::from_blob(cols.data(), {nnz}, colA.options()).clone();
if (HAS_VALUE) {
optional_valueC = torch::from_blob(vals.data(), {nnz},
optional_valueA.value().options());
optional_valueC = optional_valueC.value().clone();
} }
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment