Commit 947e0369 authored by rusty1s

torch.half support for spmm (CPU)

parent 6e23f7fb
@@ -72,7 +72,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
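Why the cast moves: count is an int64_t, and once scalar_t can be at::Half the old ternary mixes an integer operand with a Half operand. Conversions exist in both directions (Half -> float -> int64_t and int64_t -> float -> Half), so the conditional expression has no unique common type and the build breaks. A minimal sketch of that failure mode, using a hypothetical HalfLike stand-in rather than at::Half itself:

#include <cstdint>

// Hypothetical stand-in for at::Half: convertible from and to float.
// That two-way convertibility is what makes the old ternary ambiguous.
struct HalfLike {
  float v;
  HalfLike(float f) : v(f) {}           // int64_t -> float -> HalfLike
  operator float() const { return v; }  // HalfLike -> float -> int64_t
};

HalfLike mean_write(HalfLike val, int64_t count) {
  // Old form: the '?:' operands are int64_t and HalfLike, and each can be
  // converted to the other, so there is no unique common type:
  //   return val / (count > 0 ? count : (HalfLike)1);  // does not compile
  // Form mirroring this commit: keep the ternary integral, cast the result:
  return val / (HalfLike)(count > 0 ? count : 1);
}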
@@ -44,57 +44,62 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
 
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
-    scalar_t *value_data = nullptr;
-    auto mat_data = mat.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      AT_DISPATCH_HAS_VALUE(optional_value, [&] {
-        if (HAS_VALUE) {
-          value_data = optional_value.value().data_ptr<scalar_t>();
-        }
-
-        int64_t grain_size = at::internal::GRAIN_SIZE / (K * (col.numel() / M));
-        at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
-          scalar_t val;
-          std::vector<scalar_t> vals(K);
-          int64_t row_start, row_end, b, m, c;
-          std::vector<int64_t> args(K);
-          for (auto i = begin; i < end; i++) {
-            b = i / M, m = i % M;
-            row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
-
-            for (auto k = 0; k < K; k++)
-              vals[k] = Reducer<scalar_t, REDUCE>::init();
-
-            auto offset = b * N * K;
-            for (auto e = row_start; e < row_end; e++) {
-              c = col_data[e];
-              if (HAS_VALUE)
-                val = value_data[e];
-              for (auto k = 0; k < K; k++) {
-                if (HAS_VALUE)
-                  Reducer<scalar_t, REDUCE>::update(
-                      &vals[k], val * mat_data[offset + c * K + k], &args[k],
-                      e);
-                else
-                  Reducer<scalar_t, REDUCE>::update(
-                      &vals[k], mat_data[offset + c * K + k], &args[k], e);
-              }
-            }
-            offset = b * M * K + m * K;
-            for (auto k = 0; k < K; k++)
-              Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
-                                               arg_out_data + offset + k,
-                                               args[k], row_end - row_start);
-          }
-        });
-      });
-    });
-  });
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
+        scalar_t *value_data = nullptr;
+        auto mat_data = mat.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+          AT_DISPATCH_HAS_VALUE(optional_value, [&] {
+            if (HAS_VALUE) {
+              value_data = optional_value.value().data_ptr<scalar_t>();
+            }
+
+            int64_t grain_size =
+                at::internal::GRAIN_SIZE / (K * (col.numel() / M));
+            at::parallel_for(
+                0, B * M, grain_size, [&](int64_t begin, int64_t end) {
+                  scalar_t val;
+                  std::vector<scalar_t> vals(K);
+                  int64_t row_start, row_end, b, m, c;
+                  std::vector<int64_t> args(K);
+                  for (auto i = begin; i < end; i++) {
+                    b = i / M, m = i % M;
+                    row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
+
+                    for (auto k = 0; k < K; k++)
+                      vals[k] = Reducer<scalar_t, REDUCE>::init();
+
+                    auto offset = b * N * K;
+                    for (auto e = row_start; e < row_end; e++) {
+                      c = col_data[e];
+                      if (HAS_VALUE)
+                        val = value_data[e];
+                      for (auto k = 0; k < K; k++) {
+                        if (HAS_VALUE)
+                          Reducer<scalar_t, REDUCE>::update(
+                              &vals[k], val * mat_data[offset + c * K + k],
+                              &args[k], e);
+                        else
+                          Reducer<scalar_t, REDUCE>::update(
+                              &vals[k], mat_data[offset + c * K + k], &args[k],
+                              e);
+                      }
+                    }
+                    offset = b * M * K + m * K;
+                    for (auto k = 0; k < K; k++)
+                      Reducer<scalar_t, REDUCE>::write(
+                          out_data + offset + k, vals[k],
+                          arg_out_data + offset + k, args[k],
+                          row_end - row_start);
+                  }
+                });
+          });
+        });
+      });
 
   return std::make_tuple(out, arg_out);
 }
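The only functional change in this hunk (and the next) is the dispatch macro: AT_DISPATCH_ALL_TYPES_AND takes one extra ScalarType beyond the standard set covered by AT_DISPATCH_ALL_TYPES, which is how torch.half tensors now reach the kernel; the body is merely re-indented under the deeper nesting. A minimal, self-contained sketch of the same pattern — the function fill_with_one and its behavior are illustrative, not part of this commit:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Fills a contiguous CPU tensor with ones, dispatching over all standard
// types plus Half -- the same macro pattern spmm_cpu switches to above.
void fill_with_one(at::Tensor t) {
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, t.scalar_type(), "fill_with_one", [&] {
        auto data = t.data_ptr<scalar_t>();  // scalar_t is bound per dtype
        for (int64_t i = 0; i < t.numel(); i++)
          data[i] = (scalar_t)1;
      });
}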
@@ -122,30 +127,32 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
   auto row_data = row.data_ptr<int64_t>();
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
-    auto mat_data = mat.data_ptr<scalar_t>();
-    auto grad_data = grad.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    scalar_t val;
-    int64_t row, col;
-    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-      for (int b = 0; b < B; b++) {
-        for (int e = 0; e < E; e++) {
-          row = row_data[e], col = col_data[e], val = (scalar_t)0;
-          for (int k = 0; k < K; k++) {
-            val += mat_data[b * N * K + col * K + k] *
-                   grad_data[b * M * K + row * K + k];
-          }
-          if (REDUCE == MEAN) {
-            int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
-            val /= (scalar_t)std::max(row_end - row_start, 1);
-          }
-          out_data[e] += val;
-        }
-      }
-    });
-  });
+  AT_DISPATCH_ALL_TYPES_AND(
+      at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
+        auto mat_data = mat.data_ptr<scalar_t>();
+        auto grad_data = grad.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        scalar_t val;
+        int64_t row, col;
+        AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+          for (int b = 0; b < B; b++) {
+            for (int e = 0; e < E; e++) {
+              row = row_data[e], col = col_data[e], val = (scalar_t)0;
+              for (int k = 0; k < K; k++) {
+                val += mat_data[b * N * K + col * K + k] *
+                       grad_data[b * M * K + row * K + k];
+              }
+              if (REDUCE == MEAN) {
+                int row_start = rowptr_data[row],
+                    row_end = rowptr_data[row + 1];
+                val /= (scalar_t)std::max(row_end - row_start, 1);
+              }
+              out_data[e] += val;
+            }
+          }
+        });
+      });
 
   return out;
 }
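For reference, the quantity spmm_value_bw_cpu accumulates per non-zero entry e (our notation, read directly off the loops above) is:

  out[e] = sum over b of [ (sum over k of mat[b, col[e], k] * grad[b, row[e], k]) / d(e) ]

where d(e) = max(rowptr[row[e] + 1] - rowptr[row[e]], 1) under "mean" reduction and d(e) = 1 otherwise, mirroring the per-row division the forward pass applies.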
@@ -45,6 +45,20 @@ def test_spmm(dtype, device, reduce):
     assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
 
 
+def test_spmm_half_precision():
+    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
+    src_dense[2:4, :] = 0  # Remove multiple rows.
+    src_dense[:, 2:4] = 0  # Remove multiple columns.
+    src = SparseTensor.from_dense(src_dense)
+
+    other = torch.randn((2, 8, 2), dtype=torch.half, device='cpu')
+
+    expected = src_dense @ other
+    out = src @ other
+    assert torch.allclose(expected, out, atol=1e-6)
+
+
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
     src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,