Commit fdcab318 authored by rusty1s

added broadcasting capabilities to softmax

parent d4325fd1
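In short, `scatter_softmax` and `scatter_log_softmax` now accept an `index` tensor with fewer dimensions than `src`; the index is expanded internally before the segment-wise softmax is taken. A minimal usage sketch (shapes chosen for illustration only, mirroring the new `test_softmax_broadcasting` test below; it assumes `scatter_softmax` is importable from the package root, as in the tests):

import torch
from torch_scatter import scatter_softmax  # import path assumed from the tests

# src carries an extra feature dimension that index does not have
src = torch.randn(10, 5)                               # shape (E, F)
index = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])   # shape (E,)

# index is broadcast along the feature dimension, then softmax is
# computed per index value over dim=0
out = scatter_softmax(src, index, dim=0)               # shape (10, 5)

# rows sharing an index value sum to one in every column
print(out[:2].sum(dim=0))   # ~ tensor([1., 1., 1., 1., 1.])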
@@ -2,22 +2,25 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr);
-void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
-void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
+at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
+                                at::optional<at::Tensor> out_opt);
+at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
+                                at::Tensor out);
 
-at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr) {
+at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr,
+                           at::optional<at::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
-  return segment_add_csr_cuda(src, indptr);
+  if (out_opt.has_value())
+    CHECK_CUDA(out_opt.value());
+  return segment_add_csr_cuda(src, indptr, out_opt);
 }
 
-void segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
+at::Tensor segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
-  segment_add_coo_cuda(src, index, out);
+  return segment_add_coo_cuda(src, index, out);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
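The new `out_opt` argument lets callers pass a preallocated output tensor through the Python wrapper (`segment_add_csr(src, indptr, out=None)`, shown further down). A hedged sketch of the intended call pattern, assuming a CUDA build of the extension and a hypothetical `torch_scatter.segment` import path; the shape constraints follow the checks added in `segment_add_csr_cuda` (non-reduced dims must match `src`, the reduced dim must equal `indptr.size(-1) - 1`):

import torch
from torch_scatter.segment import segment_add_csr  # hypothetical import path

src = torch.randn(6, 4, device='cuda')               # reduced over dim 0
indptr = torch.tensor([0, 2, 5, 6], device='cuda')   # CSR pointers -> 3 segments

# preallocate once and reuse across calls
out = torch.empty(3, 4, device='cuda')               # indptr.numel() - 1 == 3 rows
segment_add_csr(src, indptr, out=out)

# same result as letting the kernel allocate the output itself
assert torch.allclose(out, segment_add_csr(src, indptr))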
@@ -92,17 +92,28 @@ __global__ void segment_add_csr_broadcast_kernel(
   }
 }
 
-at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
+at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
+                                at::optional<at::Tensor> out_opt) {
   AT_ASSERTM(src.dim() >= indptr.dim());
   for (int i = 0; i < indptr.dim() - 1; i++)
     AT_ASSERTM(src.size(i) == indptr.size(i));
   src = src.contiguous();
 
   auto reduce_dim = indptr.dim() - 1;
-  auto sizes = src.sizes().vec();
-  sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-  auto out = at::empty(sizes, src.options());
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != reduce_dim)
+        AT_ASSERTM(src.size(i) == out.size(i));
+    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
+    out = at::empty(sizes, src.options());
+  }
 
   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
@@ -164,6 +175,7 @@ __global__ void segment_add_coo_kernel(
   for (int i = 1; i < 32; i *= 2) {
     tmp = __shfl_up_sync(FULL_MASK, val, i);
     next_idx = __shfl_up_sync(FULL_MASK, idx, i);
+    assert(idx >= next_idx);
     if (lane_idx >= i && idx == next_idx)
       val += tmp;
   }
@@ -202,6 +214,7 @@ __global__ void segment_add_coo_broadcast_kernel(
      int idx2 = __ldg(index_info.data + offset +
                       i * index_info.strides[index_info.dims - 1]);
+     assert(idx1 <= idx2);
      if (idx1 == idx2) {
        val += src_data[K * (row_start + i) + col_idx];
      } else {
@@ -215,7 +228,8 @@ __global__ void segment_add_coo_broadcast_kernel(
   }
 }
 
-void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
+at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
+                                at::Tensor out) {
   AT_ASSERTM(src.dim() >= index.dim());
   for (int i = 0; i < index.dim(); i++)
     AT_ASSERTM(src.size(i) == index.size(i));
@@ -257,4 +271,6 @@ void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
        <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
           0, stream>>>(src_data, index_info, out_data, E, K);
   });
+
+  return out;
 }
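The two device-side `assert`s added above make an implicit precondition visible: the warp-level reduction shuffles partial sums up the warp and only merges lanes whose indices are equal, which is only correct when `index` is sorted (non-decreasing) along the reduce dimension. A hedged sketch of the corresponding expectation at the Python level, assuming a CUDA build and the `segment_add_coo(src, index, dim_size=None)` wrapper shown below:

import torch
from torch_scatter.segment import segment_add_coo  # hypothetical import path

src = torch.randn(6, 4, device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 2], device='cuda')  # sorted, as the kernel asserts

out = segment_add_coo(src, index)  # expected shape (3, 4): one row per segment

# an unsorted index such as [1, 0, 2, 1, 0, 2] breaks the kernel's assumption
# and would trip the new device-side asserts in a debug build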
@@ -22,11 +22,21 @@ def test_softmax(dtype, device):
     expected = torch.stack([
         out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
-    ], dim=0)
+    ], dim=0).to(device)
 
     assert torch.allclose(out, expected)
 
 
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_softmax_broadcasting(dtype, device):
+    src = torch.randn(10, 5, dtype=dtype, device=device)
+    index = tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4], torch.long, device)
+
+    out = scatter_softmax(src, index, dim=0).view(5, 2, 5)
+    out = out.sum(dim=1)
+    assert torch.allclose(out, torch.ones_like(out))
+
+
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_log_softmax(dtype, device):
     src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
@@ -42,6 +52,6 @@ def test_log_softmax(dtype, device):
     expected = torch.stack([
         out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
-    ], dim=0)
+    ], dim=0).to(device)
 
     assert torch.allclose(out, expected)
@@ -20,5 +20,5 @@ def test_logsumexp(dtype, device):
     out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
     out4 = torch.tensor(-1, dtype=dtype)
 
-    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
+    expected = torch.stack([out0, out1, out2, out3, out4], dim=0).to(device)
     assert torch.allclose(out, expected)
 import torch
 
 from torch_scatter import scatter_add, scatter_max
+from torch_scatter.utils.gen import broadcast
 
 
 def scatter_softmax(src, index, dim=-1, eps=1e-12):
@@ -31,6 +32,7 @@ def scatter_softmax(src, index, dim=-1, eps=1e-12):
         raise ValueError('`scatter_softmax` can only be computed over tensors '
                          'with floating point data types.')
 
+    src, index = broadcast(src, index, dim)
     max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)
@@ -73,6 +75,7 @@ def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
         raise ValueError('`scatter_log_softmax` can only be computed over '
                          'tensors with floating point data types.')
 
+    src, index = broadcast(src, index, dim)
     max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
     max_per_src_element = max_value_per_index.gather(dim, index)
...
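Only the max-stabilization step of the composites is visible above. For orientation, here is a rough sketch of the full recipe these functions follow (a sketch under that assumption, not the file's exact code); `index` is taken to already have `src`'s shape, i.e. after `broadcast()`:

import torch
from torch_scatter import scatter_add, scatter_max

def scatter_softmax_sketch(src, index, dim=-1, eps=1e-12):
    # subtract the per-segment maximum for numerical stability
    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered = src - max_per_src_element

    # exponentiate and normalize by the per-segment sum
    exp = recentered.exp()
    sum_per_index = scatter_add(exp, index, dim=dim)
    normalizer = sum_per_index.gather(dim, index) + eps
    return exp / normalizer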
 import torch
 
-from torch_scatter.utils.gen import gen
 from torch_scatter.add import scatter_add
 
 if torch.cuda.is_available():
@@ -11,8 +10,8 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     return scatter_add(src, index, dim, out, dim_size, fill_value)
 
 
-def segment_add_csr(src, indptr):
-    return torch_scatter.segment_cuda.segment_add_csr(src, indptr)
+def segment_add_csr(src, indptr, out=None):
+    return torch_scatter.segment_cuda.segment_add_csr(src, indptr, out)
 
 
 def segment_add_coo(src, index, dim_size=None):
...
@@ -11,10 +11,9 @@ def maybe_dim_size(index, dim_size=None):
     return index.max().item() + 1 if index.numel() > 0 else 0
 
 
-def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
+def broadcast(src, index, dim):
     dim = range(src.dim())[dim]  # Get real dim value.
 
-    # Automatically expand index tensor to the right dimensions.
     if index.dim() == 1:
         index_size = list(repeat(1, src.dim()))
         index_size[dim] = src.size(dim)
@@ -33,9 +32,17 @@ def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
         expand_size = []
         for s, i in zip(src.size(), index.size()):
             expand_size += [-1 if s == i and s != 1 and i != 1 else max(i, s)]
         src = src.expand(expand_size)
 
     index = index.expand_as(src)
 
+    return src, index
+
+
+def gen(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
+    src, index = broadcast(src, index, dim)
+    dim = range(src.dim())[dim]  # Get real dim value.
+
     # Generate output tensor if not given.
     if out is None:
         out_size = list(src.size())
...
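The refactor extracts the index-expansion logic of `gen` into a standalone `broadcast(src, index, dim)` helper, which the composite softmax functions now call before `scatter_max`/`gather`. Roughly, its effect on shapes (a small sketch of the expected behaviour):

import torch
from torch_scatter.utils.gen import broadcast

src = torch.randn(10, 5)
index = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4])   # shape (10,)

src, index = broadcast(src, index, dim=0)
print(index.shape)   # torch.Size([10, 5]): the 1-D index is viewed as (10, 1)
                     # and expanded to src's shape so gather()/scatter() line up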