Commit 3e87af1c authored by rusty1s's avatar rusty1s
Browse files

torch.half support

parent 8c25ddef
......@@ -44,62 +44,58 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
scalar_t *value_data = nullptr;
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
scalar_t *value_data = nullptr;
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
AT_DISPATCH_HAS_VALUE(optional_value, [&] {
if (HAS_VALUE) {
value_data = optional_value.value().data_ptr<scalar_t>();
}
int64_t grain_size = at::internal::GRAIN_SIZE /
(K * std::max(col.numel() / M, (int64_t)1));
at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
c = col_data[e];
if (HAS_VALUE)
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + c * K + k], &args[k],
e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + c * K + k], &args[k], e);
}
}
int64_t grain_size = at::internal::GRAIN_SIZE /
(K * std::max(col.numel() / M, (int64_t)1));
at::parallel_for(
0, B * M, grain_size, [&](int64_t begin, int64_t end) {
scalar_t val;
std::vector<scalar_t> vals(K);
int64_t row_start, row_end, b, m, c;
std::vector<int64_t> args(K);
for (auto i = begin; i < end; i++) {
b = i / M, m = i % M;
row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
auto offset = b * N * K;
for (auto e = row_start; e < row_end; e++) {
c = col_data[e];
if (HAS_VALUE)
val = value_data[e];
for (auto k = 0; k < K; k++) {
if (HAS_VALUE)
Reducer<scalar_t, REDUCE>::update(
&vals[k], val * mat_data[offset + c * K + k],
&args[k], e);
else
Reducer<scalar_t, REDUCE>::update(
&vals[k], mat_data[offset + c * K + k], &args[k],
e);
}
}
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + offset + k, vals[k],
arg_out_data + offset + k, args[k],
row_end - row_start);
}
});
});
offset = b * M * K + m * K;
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
arg_out_data + offset + k,
args[k], row_end - row_start);
}
});
});
});
});
return std::make_tuple(out, arg_out);
}
......@@ -127,32 +123,30 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
int64_t row, col;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int e = 0; e < E; e++) {
row = row_data[e], col = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
if (REDUCE == MEAN) {
int row_start = rowptr_data[row],
row_end = rowptr_data[row + 1];
val /= (scalar_t)std::max(row_end - row_start, 1);
}
out_data[e] += val;
}
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t val;
int64_t row, col;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (int b = 0; b < B; b++) {
for (int e = 0; e < E; e++) {
row = row_data[e], col = col_data[e], val = (scalar_t)0;
for (int k = 0; k < K; k++) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
});
});
if (REDUCE == MEAN) {
int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
val /= (scalar_t)std::max(row_end - row_start, 1);
}
out_data[e] += val;
}
}
});
});
return out;
}
......@@ -73,7 +73,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (count > 0 ? count : (scalar_t)1);
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
......
......@@ -132,7 +132,8 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......@@ -219,7 +220,8 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
auto col_data = col.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
......
......@@ -5,3 +5,15 @@
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
__device__ __inline__ at::Half __shfl_sync(const unsigned mask,
const at::Half var,
const unsigned int srcLane) {
return __shfl_sync(mask, (__half)var, srcLane);
}
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const at::Half var,
const unsigned int delta) {
return __shfl_down_sync(mask, (__half)var, delta);
}
......@@ -162,7 +162,7 @@ public:
if (torch::autograd::any_variable_requires_grad({mat})) {
row = row.index_select(0, csr2csc);
rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
rowcount.clamp_(1);
rowcount.masked_fill_(rowcount < 1, 1);
if (has_value > 0)
rowcount = value.index_select(0, csr2csc).div(rowcount);
......
......@@ -40,27 +40,16 @@ def test_spmm(dtype, device, reduce):
out = matmul(src, other, reduce)
out.backward(grad_out)
assert torch.allclose(expected, out, atol=1e-6)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
def test_spmm_half_precision():
src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
src_dense[2:4, :] = 0 # Remove multiple rows.
src_dense[:, 2:4] = 0 # Remove multiple columns.
src = SparseTensor.from_dense(src_dense)
other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
expected = (src_dense.to(torch.float) @ other).to(torch.half)
out = src @ other.to(torch.half)
assert torch.allclose(expected, out, atol=1e-2)
assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
if dtype == torch.half:
return # TODO
src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
device=device)
......
......@@ -9,6 +9,9 @@ from .utils import grad_dtypes, devices, tensor
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_spspmm(dtype, device):
if dtype == torch.half:
return # TODO
indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
valueA = tensor([1, 2, 3, 4, 5], dtype, device)
indexB = torch.tensor([[0, 2], [1, 0]], device=device)
......@@ -21,6 +24,9 @@ def test_spspmm(dtype, device):
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_sparse_tensor_spspmm(dtype, device):
if dtype == torch.half:
return # TODO
x = SparseTensor(
row=torch.tensor(
[0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
......
......@@ -2,8 +2,8 @@ import torch
reductions = ['sum', 'add', 'mean', 'min', 'max']
dtypes = [torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.float, torch.double]
dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
grad_dtypes = [torch.half, torch.float, torch.double]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment