Commit 947e0369 authored by rusty1s

torch.half support for spmm (CPU)

parent 6e23f7fb
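The commit enables half-precision CPU kernels by switching the type dispatch from AT_DISPATCH_ALL_TYPES to AT_DISPATCH_ALL_TYPES_AND with at::ScalarType::Half, as the diffs below show. A minimal sketch of that dispatch pattern (the function name and tensor argument are placeholders, not taken from the repository):

```cpp
// Sketch only: AT_DISPATCH_ALL_TYPES_AND behaves like AT_DISPATCH_ALL_TYPES
// but additionally instantiates the lambda for one extra scalar type, here
// at::Half, so the same templated kernel body also compiles for torch.half.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

void dispatch_example(const at::Tensor &mat) {
  AT_DISPATCH_ALL_TYPES_AND(
      at::ScalarType::Half, mat.scalar_type(), "dispatch_example", [&] {
        // Inside the lambda, scalar_t is bound to the dispatched dtype
        // (float, double, the integral types, or at::Half).
        const scalar_t *data = mat.data_ptr<scalar_t>();
        (void)data; // stand-in for the actual kernel body
      });
}
```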
@@ -72,7 +72,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
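In the Reducer::write hunk above, the mean divisor is now computed entirely in int64_t and converted to scalar_t once at the end; the previous form placed a scalar_t literal inside the conditional next to the int64_t count, which no longer mixes cleanly once scalar_t can be at::Half. A stand-alone illustration of the updated expression (assumed semantics, not library code):

```cpp
// Illustration only: guard against an empty row (count == 0), keep the
// divisor in int64_t, and cast it to the value type in a single place.
#include <ATen/ATen.h>

template <typename scalar_t>
scalar_t write_mean(scalar_t val, int64_t count) {
  return val / (scalar_t)(count > 0 ? count : 1);
}

// write_mean<at::Half>(at::Half(3.0f), 2) gives 1.5 in half precision;
// write_mean<float>(3.0f, 0) divides by 1 and returns 3.0f unchanged.
```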
@@ -44,57 +44,62 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
}
int64_t grain_size = at::internal::GRAIN_SIZE / (K * (col.numel() / M));
at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
c = col_data[e];
if (HAS_VALUE)
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + c * K + k], &args[k],
e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + c * K + k], &args[k], e);
}
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
}
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
arg_out_data + offset + k,
args[k], row_end - row_start);
}
int64_t grain_size =
at::internal::GRAIN_SIZE / (K * (col.numel() / M));
at::parallel_for(
0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
c = col_data[e];
if (HAS_VALUE)
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + c * K + k],
&args[k], e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + c * K + k], &args[k],
e);
}
}
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + offset + k, vals[k],
arg_out_data + offset + k, args[k],
row_end - row_start);
}
});
});
});
});
});
});
return std::make_tuple(out, arg_out);
}
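For orientation in the spmm_cpu hunk above: each task of the parallel_for handles one (batch, row) pair, walks the CSR row via rowptr/col, and reduces the gathered rows of mat feature-wise before writing K outputs. A simplified single-threaded, single-batch sketch with sum reduction (illustrative only; names mirror the kernel but this is not the library's code):

```cpp
// CSR SpMM with sum reduction: out[m, k] = sum over nonzeros e in row m of
// value[e] * mat[col[e], k]. M rows, K feature columns, row-major buffers.
#include <cstdint>
#include <vector>

void spmm_sum(const int64_t *rowptr, const int64_t *col, const float *value,
              const float *mat, float *out, int64_t M, int64_t K) {
  for (int64_t m = 0; m < M; m++) {
    std::vector<float> vals(K, 0.0f);
    for (int64_t e = rowptr[m]; e < rowptr[m + 1]; e++) {
      const int64_t c = col[e];
      for (int64_t k = 0; k < K; k++)
        vals[k] += value[e] * mat[c * K + k];
    }
    for (int64_t k = 0; k < K; k++)
      out[m * K + k] = vals[k];
  }
}
```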
@@ -122,30 +127,32 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_value_bw", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
int64_t row, col;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int e = 0; e < E; e++) {
row = row_data[e], col = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
if (REDUCE == MEAN) {
int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
val /= (scalar_t)std::max(row_end - row_start, 1);
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
int64_t row, col;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int e = 0; e < E; e++) {
row = row_data[e], col = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
if (REDUCE == MEAN) {
int row_start = rowptr_data[row],
row_end = rowptr_data[row + 1];
val /= (scalar_t)std::max(row_end - row_start, 1);
}
out_data[e] += val;
}
}
out_data[e] += val;
}
}
});
});
});
});
return out;
}
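The spmm_value_bw_cpu hunk above accumulates the gradient of each stored nonzero: for an entry at (r, c), it is the dot product of grad row r with mat row c, divided by the degree of row r when the forward reduction was "mean". A single-batch sketch under those assumptions (illustrative, not the library's code):

```cpp
// Value-gradient of SpMM for "sum"/"mean" reductions, single batch:
// out[e] += <grad[row[e], :], mat[col[e], :]> (optionally / deg(row[e])).
#include <algorithm>
#include <cstdint>

void spmm_value_bw(const int64_t *row, const int64_t *rowptr,
                   const int64_t *col, const float *mat, const float *grad,
                   float *out, int64_t E, int64_t K, bool mean) {
  for (int64_t e = 0; e < E; e++) {
    float val = 0.0f;
    for (int64_t k = 0; k < K; k++)
      val += mat[col[e] * K + k] * grad[row[e] * K + k];
    if (mean)
      val /= std::max<int64_t>(rowptr[row[e] + 1] - rowptr[row[e]], 1);
    out[e] += val;
  }
}
```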
@@ -45,6 +45,20 @@ def test_spmm(dtype, device, reduce):
assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
def test_spmm_half_precision():
src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
src_dense[2:4, :] = 0 # Remove multiple rows.
src_dense[:, 2:4] = 0 # Remove multiple columns.
src = SparseTensor.from_dense(src_dense)
other = torch.randn((2, 8, 2), dtype=torch.half, device='cpu')
expected = src_dense @ other
out = src @ other
assert torch.allclose(expected, out, atol=1e-6)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,