Commit 3e87af1c authored by rusty1s

torch.half support

parent 8c25ddef
@@ -44,8 +44,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);
 
-  AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     scalar_t *value_data = nullptr;
     auto mat_data = mat.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -58,8 +57,7 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
     int64_t grain_size = at::internal::GRAIN_SIZE /
                          (K * std::max(col.numel() / M, (int64_t)1));
 
-    at::parallel_for(
-        0, B * M, grain_size, [&](int64_t begin, int64_t end) {
+    at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
       scalar_t val;
       std::vector<scalar_t> vals(K);
       int64_t row_start, row_end, b, m, c;
@@ -81,20 +79,18 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
           for (auto k = 0; k < K; k++) {
             if (HAS_VALUE)
               Reducer<scalar_t, REDUCE>::update(
-                  &vals[k], val * mat_data[offset + c * K + k],
-                  &args[k], e);
+                  &vals[k], val * mat_data[offset + c * K + k], &args[k],
+                  e);
             else
               Reducer<scalar_t, REDUCE>::update(
-                  &vals[k], mat_data[offset + c * K + k], &args[k],
-                  e);
+                  &vals[k], mat_data[offset + c * K + k], &args[k], e);
           }
         }
         offset = b * M * K + m * K;
         for (auto k = 0; k < K; k++)
-          Reducer<scalar_t, REDUCE>::write(
-              out_data + offset + k, vals[k],
-              arg_out_data + offset + k, args[k],
-              row_end - row_start);
+          Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
+                                           arg_out_data + offset + k,
+                                           args[k], row_end - row_start);
       }
     });
   });
@@ -127,8 +123,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
   auto row_data = row.data_ptr<int64_t>();
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
-  AT_DISPATCH_ALL_TYPES_AND(
-      at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto grad_data = grad.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -144,8 +139,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
                  grad_data[b * M * K + row * K + k];
         }
         if (REDUCE == MEAN) {
-          int row_start = rowptr_data[row],
-              row_end = rowptr_data[row + 1];
+          int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
           val /= (scalar_t)std::max(row_end - row_start, 1);
         }
         out_data[e] += val;
...
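
Note on the dispatch change above: `AT_DISPATCH_ALL_TYPES_AND` takes one extra scalar type as its first argument and instantiates the lambda body for it in addition to the standard integer and floating types, which is how the CPU kernels pick up `torch.half`. A minimal sketch of the pattern, not from this commit (the `fill_one` helper is illustrative and assumes a contiguous CPU tensor):

```cpp
#include <torch/extension.h>

void fill_one(torch::Tensor t) {
  // The *_AND variant compiles the lambda once per standard scalar type
  // plus once more with scalar_t = at::Half.
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, t.scalar_type(), "_", [&] {
    auto data = t.data_ptr<scalar_t>();
    for (int64_t i = 0; i < t.numel(); i++)
      data[i] = (scalar_t)1; // at::Half constructs from arithmetic types
  });
}
```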
@@ -73,7 +73,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
...
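
The `Reducer` tweak moves the cast outside the conditional. One plausible reason: with `scalar_t = at::Half` and an integral `count`, the old `count > 0 ? count : (scalar_t)1` mixes an integer and an `at::Half` in the two branches of `?:`, forcing the compiler to pick a common type, which is ambiguous for `c10::Half` (it converts both to and from `float`). Keeping the ternary purely integral and converting once sidesteps that. A reduced sketch of the fixed shape (`safe_mean` is a hypothetical helper, not from the commit):

```cpp
#include <c10/util/Half.h>
#include <cstdint>

template <typename scalar_t>
scalar_t safe_mean(scalar_t sum, int64_t count) {
  // Integral ternary first, a single conversion afterwards; this also
  // instantiates cleanly for scalar_t = at::Half.
  return sum / (scalar_t)(count > 0 ? count : 1);
}
```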
@@ -132,7 +132,8 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
   auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
@@ -219,7 +220,8 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
   auto col_data = col.data_ptr<int64_t>();
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto grad_data = grad.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
...
@@ -5,3 +5,15 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+
+__device__ __inline__ at::Half __shfl_sync(const unsigned mask,
+                                           const at::Half var,
+                                           const unsigned int srcLane) {
+  return __shfl_sync(mask, (__half)var, srcLane);
+}
+
+__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
+                                                const at::Half var,
+                                                const unsigned int delta) {
+  return __shfl_down_sync(mask, (__half)var, delta);
+}
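
The two overloads above exist because CUDA's warp-shuffle intrinsics are declared for the builtin types (including `__half`) but not for `at::Half`; the wrappers cast to `__half`, shuffle, and let the result convert back. A hedged usage sketch (`warp_sum` is illustrative, not part of the commit; it assumes a full 32-lane warp and a GPU architecture with `__half` shuffle support):

```cuda
template <typename scalar_t>
__device__ scalar_t warp_sum(scalar_t val) {
  // With the at::Half overloads in scope, this generic warp reduction
  // also instantiates for scalar_t = at::Half.
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}
```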
@@ -162,7 +162,7 @@ public:
     if (torch::autograd::any_variable_requires_grad({mat})) {
       row = row.index_select(0, csr2csc);
       rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
-      rowcount.clamp_(1);
+      rowcount.masked_fill_(rowcount < 1, 1);
       if (has_value > 0)
         rowcount = value.index_select(0, csr2csc).div(rowcount);
...
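
`rowcount.clamp_(1)` and `rowcount.masked_fill_(rowcount < 1, 1)` compute the same result here, since row counts are non-negative and only zeros get lifted to one. The switch is presumably because the in-place `clamp_` kernel was not implemented for `torch.half` on all backends at the time, whereas comparison plus `masked_fill_` is. A small equivalence sketch (`floor_counts_at_one` is a hypothetical helper):

```cpp
#include <torch/torch.h>

torch::Tensor floor_counts_at_one(torch::Tensor rowcount) {
  // Equivalent to rowcount.clamp_(1) for non-negative counts, but only
  // requires comparison and masked_fill_ to exist for the dtype.
  return rowcount.masked_fill_(rowcount < 1, 1);
}
```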
@@ -40,27 +40,16 @@ def test_spmm(dtype, device, reduce):
     out = matmul(src, other, reduce)
     out.backward(grad_out)
 
-    assert torch.allclose(expected, out, atol=1e-6)
-    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
-    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
-
-
-def test_spmm_half_precision():
-    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
-    src_dense[2:4, :] = 0  # Remove multiple rows.
-    src_dense[:, 2:4] = 0  # Remove multiple columns.
-    src = SparseTensor.from_dense(src_dense)
-
-    other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
-
-    expected = (src_dense.to(torch.float) @ other).to(torch.half)
-    out = src @ other.to(torch.half)
     assert torch.allclose(expected, out, atol=1e-2)
+    assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
+    assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)
 
 
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                        device=device)
...
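
The looser tolerance follows from half precision itself: float16 carries a 10-bit mantissa, so its machine epsilon is 2^-10 ≈ 9.8e-4, and a single assert threshold shared across the parametrized dtypes has to accommodate the least precise one; hence `atol=1e-6` (fine for float32/float64) becomes `atol=1e-2`. The dedicated `test_spmm_half_precision` function is dropped because `torch.half` now runs through the same parametrized `test_spmm`.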
@@ -9,6 +9,9 @@ from .utils import grad_dtypes, devices, tensor
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
     valueA = tensor([1, 2, 3, 4, 5], dtype, device)
     indexB = torch.tensor([[0, 2], [1, 0]], device=device)
@@ -21,6 +24,9 @@ def test_spspmm(dtype, device):
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_sparse_tensor_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     x = SparseTensor(
         row=torch.tensor(
             [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
...
@@ -2,8 +2,8 @@ import torch
 reductions = ['sum', 'add', 'mean', 'min', 'max']
 
-dtypes = [torch.float, torch.double, torch.int, torch.long]
-grad_dtypes = [torch.float, torch.double]
+dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
+grad_dtypes = [torch.half, torch.float, torch.double]
 
 devices = [torch.device('cpu')]
 if torch.cuda.is_available():
...
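
Adding `torch.half` to `dtypes` and `grad_dtypes` is what turns half precision on for every test parametrized over these lists; the early `return  # TODO` guards added to the spspmm tests above then opt sparse-sparse matmul back out, since it goes through routines that do not yet have half-precision kernels.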