Commit 6a2a503e authored by quyuanhao123

Initial commit
Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
include README.md
include LICENSE
recursive-exclude test *
recursive-include csrc *
Metadata-Version: 2.1
Name: torch_scatter
Version: 2.0.9
Summary: PyTorch Extension Library of Optimized Scatter Operations
Home-page: https://github.com/rusty1s/pytorch_scatter
Author: Matthias Fey
Author-email: matthias.fey@tu-dortmund.de
License: MIT
Description: UNKNOWN
Keywords: pytorch,scatter,segment,gather
Platform: UNKNOWN
Requires-Python: >=3.6
Provides-Extra: test
[pypi-image]: https://badge.fury.io/py/torch-scatter.svg
[pypi-url]: https://pypi.python.org/pypi/torch-scatter
[testing-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml/badge.svg
[testing-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml
[linting-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml/badge.svg
[linting-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml
[docs-image]: https://readthedocs.org/projects/pytorch-scatter/badge/?version=latest
[docs-url]: https://pytorch-scatter.readthedocs.io/en/latest/?badge=latest
[coverage-image]: https://codecov.io/gh/rusty1s/pytorch_scatter/branch/master/graph/badge.svg
[coverage-url]: https://codecov.io/github/rusty1s/pytorch_scatter?branch=master
# PyTorch Scatter
[![PyPI Version][pypi-image]][pypi-url]
[![Testing Status][testing-image]][testing-url]
[![Linting Status][linting-image]][linting-url]
[![Docs Status][docs-image]][docs-url]
[![Code Coverage][coverage-image]][coverage-url]
<p align="center">
<img width="50%" src="https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true" />
</p>
--------------------------------------------------------------------------------
**[Documentation](https://pytorch-scatter.readthedocs.io)**
This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/), which are missing in the main package.
Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.
The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:
* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices
* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers
In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.
All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
## Installation
### Anaconda
**Update:** You can now install `pytorch-scatter` via [Anaconda](https://anaconda.org/pyg/pytorch-scatter) for all major OS/PyTorch/CUDA combinations 🤗
Given that you have [`pytorch >= 1.8.0` installed](https://pytorch.org/get-started/locally/), simply run
```
conda install pytorch-scatter -c pyg
```
### Binaries
We alternatively provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
#### PyTorch 1.10.0
To install the binaries for PyTorch 1.10.0, simply run
```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu102`, or `cu113` depending on your PyTorch installation.
| | `cpu` | `cu102` | `cu113` |
|-------------|-------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |
#### PyTorch 1.9.0/1.9.1
To install the binaries for PyTorch 1.9.0 and 1.9.1, simply run
```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu102`, or `cu111` depending on your PyTorch installation.
| | `cpu` | `cu102` | `cu111` |
|-------------|-------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1 and PyTorch 1.8.0/1.8.1 (following the same procedure).
### From source
Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
$ echo $CPATH
>>> /usr/local/cuda/include:...
```
Then run:
```
pip install torch-scatter
```
When running in a Docker container without an NVIDIA driver, PyTorch cannot detect the GPU's compute capabilities and the build may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```
## Example
```py
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
```
```
print(out)
tensor([[0, 0, 4, 3, 2, 0],
[2, 4, 3, 0, 0, 0]])
print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]])
```
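For sorted indices, the segment interfaces compute the same reductions as `scatter`. The following is a minimal sketch (assuming the default call signatures of `scatter`, `segment_coo` and `segment_csr` as documented in the function pages linked above):
```py
import torch
from torch_scatter import scatter, segment_coo, segment_csr

src = torch.tensor([10., 20., 30., 40.])
index = torch.tensor([0, 0, 1, 1])   # sorted "group-index" tensor
indptr = torch.tensor([0, 2, 4])     # compressed (CSR) pointers into src

print(scatter(src, index, reduce="sum"))      # tensor([30., 70.])
print(segment_coo(src, index, reduce="sum"))  # tensor([30., 70.])
print(segment_csr(src, indptr, reduce="sum")) # tensor([30., 70.])
```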
## Running tests
```
python setup.py test
```
## C++ API
`torch-scatter` also offers a C++ API that contains C++ equivalents of the Python functions.
```
mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install
```
#pragma once
#include <torch/extension.h>
#define MAX_TENSORINFO_DIMS 25
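// TensorInfo bundles a raw data pointer with a tensor's sizes and strides so
// that the indexing helpers below can compute memory offsets directly, even
// for non-contiguous tensors.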
template <typename scalar_t> struct TensorInfo {
TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
int st[MAX_TENSORINFO_DIMS]) {
data = p;
dims = dim;
AT_ASSERT(dims < MAX_TENSORINFO_DIMS);
for (int i = 0; i < dim; ++i) {
sizes[i] = sz[i];
strides[i] = st[i];
}
}
scalar_t *data;
int dims;
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
};
template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
int sizes[MAX_TENSORINFO_DIMS];
int strides[MAX_TENSORINFO_DIMS];
int dims = tensor.dim();
for (int i = 0; i < dims; ++i) {
sizes[i] = tensor.size(i);
strides[i] = tensor.stride(i);
}
return TensorInfo<scalar_t>(tensor.data_ptr<scalar_t>(), dims, sizes,
strides);
}
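// Converts a linear element index (contiguous, row-major order) into the
// memory offset of that element, using the recorded sizes and strides.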
template <typename scalar_t> struct IndexToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = 0;
for (int i = info.dims - 1; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
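// Variant of IndexToOffset for CSR-style "indptr" tensors: the last dimension
// stores sizes[dims - 1] pointers describing sizes[dims - 1] - 1 segments per
// row, so indices enumerate segments and the last pointer of each row is
// skipped.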
template <typename scalar_t> struct IndexPtrToOffset {
static inline int get(int idx, const TensorInfo<scalar_t> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
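// Resolves the runtime `reduce` string to a compile-time ReductionType
// constant named REDUCE and invokes the trailing lambda, mirroring the
// AT_DISPATCH_* dtype macros.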
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
static constexpr ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
static constexpr ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
static constexpr ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
static constexpr ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
static constexpr ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
static constexpr ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
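// Reducer bundles, per reduction type, the identity element (init), the
// element-wise accumulation (update, optionally tracking an argmin/argmax
// position) and the final write-back (write), which divides by the element
// count for "mean" and zero-fills empty "min"/"max" bins.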
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include "scatter_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t, REDUCE>::update(
out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e);
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
});
});
return std::make_tuple(out, arg_out);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include "segment_coo_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (auto i = 0; i < index.dim(); i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else {
auto tmp = index.select(dim, index.size(dim) - 1);
tmp = tmp.numel() > 1 ? tmp.max() : tmp;
sizes[dim] = 1 + *tmp.data_ptr<int64_t>();
}
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
} else if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
arg_out = torch::zeros(sizes, out.options());
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
auto K = src.numel() / index.numel();
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr;
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN)
count_data = arg_out.value().data_ptr<scalar_t>();
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = out_data[b * N * K + k];
row_start = 0;
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
} else {
next_idx = index_info.data[offset + (e + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
vals[k] = out_data[b * N * K + next_idx * K + k];
}
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
row_start = e + 1;
}
idx = next_idx;
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN)
arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1,
(scalar_t)1);
});
});
return std::make_tuple(out, arg_out);
}
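// gather_coo: the gather counterpart of segment_coo. For every position e
// along the last index dimension, out[..., e, :] is filled with
// src[..., index[e], :], re-reading from `src` only when the (sorted) index
// value changes.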
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) == index.size(i));
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < src.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
sizes[dim] = index.size(dim);
out = torch::empty(sizes, src.options());
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}
auto B = index.numel() / out.size(dim);
auto E = index.size(dim);
auto K = out.numel() / index.numel();
auto N = src.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t idx, next_idx;
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
out_data[b * E * K + e * K + k] = vals[k];
if (e < E - 1) {
next_idx = index_info.data[offset + (e + 1) * stride];
CHECK_INPUT(idx <= next_idx);
if (idx != next_idx) {
idx = next_idx;
for (auto k = 0; k < K; k++)
vals[k] = src_data[b * N * K + idx * K + k];
}
}
}
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
#include "segment_csr_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
} else {
sizes = src.sizes().vec();
sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full(out.sizes(), src.size(dim), indptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto k = 0; k < K; k++)
vals[k] = Reducer<scalar_t, REDUCE>::init();
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[offset + e * K + k], &args[k], e);
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
arg_out_data + n * K + k, args[k],
row_end - row_start);
}
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
CHECK_CPU(src);
CHECK_CPU(indptr);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= indptr.dim());
auto sizes = indptr.sizes().vec();
for (auto i = 0; i < indptr.dim() - 1; i++)
sizes[i] = src.size(i);
indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (src.numel() > 0)
sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
else
sizes[dim] = 0;
out = torch::empty(sizes, src.options());
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out;
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N;
auto E = out.size(dim);
auto indptr_info = getTensorInfo<int64_t>(indptr);
auto stride = indptr_info.strides[indptr_info.dims - 1];
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
std::vector<scalar_t> vals(K);
int64_t row_start, row_end;
for (auto n = 0; n < N; n++) {
auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
row_start = indptr_info.data[offset];
row_end = indptr_info.data[offset + stride];
for (auto k = 0; k < K; k++)
vals[k] = src_data[n * K + k];
offset = (n / (indptr.size(-1) - 1)) * E * K;
for (auto e = row_start; e < row_end; e++)
for (auto k = 0; k < K; k++)
out_data[offset + e * K + k] = vals[k];
}
});
return out;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce);
torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#pragma once
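// ATOMIC(NAME) expands to atomicCAS-based implementations of an atomic
// read-modify-write for integer types of 1, 2, 4 and 8 bytes and for
// floating-point types of 2, 4 and 8 bytes; the actual operation is supplied
// through the OP(X, Y) macro defined before each expansion below.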
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned int *address_as_ui = \
(unsigned int *)((char *)address - ((size_t)address & 2)); \
unsigned int old = *address_as_ui; \
unsigned int assumed; \
\
do { \
assumed = old; \
at::Half hsum; \
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
hsum = OP(hsum, val); \
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
: (old & 0xffff0000) | hsum.x; \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || TORCH_HIP_VERSION < 10000)
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
#else
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
#endif
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || TORCH_HIP_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
#define OP(X, Y) Y *X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(at::Half *address, at::Half val) {
AtomicMulDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(at::Half *address, at::Half val) {
AtomicDivDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(at::Half *address, at::Half val) {
AtomicMaxDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(at::Half *address, at::Half val) {
AtomicMinDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
#pragma once
#include <ATen/hip/detail/TensorInfo.cuh>
// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
idx /= info.sizes[info.dims - 1] - 1;
for (int i = info.dims - 2; i >= 0; --i) {
offset += (idx % info.sizes[i]) * info.strides[i];
idx /= info.sizes[i];
}
return offset;
}
};
#pragma once
#include <limits>
#include <map>
#include "atomics.cuh"
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val,
scalar_t new_val) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == SUM || REDUCE == MEAN)
atomAdd(address, val);
else if (REDUCE == MUL)
atomMul(address, val);
else if (REDUCE == DIV)
atomDiv(address, val);
else if (REDUCE == MIN)
atomMin(address, val);
else if (REDUCE == MAX)
atomMax(address, val);
}
};
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include "hip/hip_runtime.h"
#include "scatter_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
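// scatter_kernel launches one thread per element of `src`; each thread looks
// up its destination index and atomically reduces its value into
// out[b, idx, k]. For "min"/"max", scatter_arg_kernel runs as a second pass
// and records the source position e whose value ended up in `out`.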
template <typename scalar_t, ReductionType REDUCE>
__global__ void
scatter_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int E, int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
Reducer<scalar_t, REDUCE>::atomic_write(out_data + b * N * K + idx * K + k,
src_data[thread_idx]);
}
}
template <typename scalar_t>
__global__ void
scatter_arg_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
const scalar_t *out_data, int64_t *arg_out_data, int E,
int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int e = (thread_idx / K) % E;
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) {
arg_out_data[b * N * K + idx * K + k] = e;
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else {
sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
}
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
scatter_kernel<scalar_t, REDUCE>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, E, K, N, src.numel());
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX)
scatter_arg_kernel<scalar_t>
<<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
src_data, index_info, out_data, arg_out_data, E, K, N,
src.numel());
});
});
return std::make_tuple(out, arg_out);
}
#include "hip/hip_runtime.h"
#include "scatter_hip.h"
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/detail/IndexUtils.cuh>
#include <ATen/hip/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t, ReductionType REDUCE>
__global__ void
scatter_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
scalar_t *out_data, int E, int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
Reducer<scalar_t, REDUCE>::atomic_write(out_data + b * N * K + idx * K + k,
src_data[thread_idx]);
}
}
template <typename scalar_t>
__global__ void
scatter_arg_kernel(const scalar_t *src_data,
const at::cuda::detail::TensorInfo<int64_t, int> index_info,
const scalar_t *out_data, int64_t *arg_out_data, int E,
int K, int N, int numel) {
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int b = thread_idx / (E * K);
int e = (thread_idx / K) % E;
int k = thread_idx % K;
if (thread_idx < numel) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
thread_idx, index_info);
int64_t idx = index_info.data[offset];
if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) {
arg_out_data[b * N * K + idx * K + k] = e;
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CUDA(src);
CHECK_CUDA(index);
if (optional_out.has_value())
CHECK_CUDA(optional_out.value());
hipSetDevice(src.get_device());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else {
sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
}
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out);
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
hipLaunchKernelGGL(( scatter_kernel<scalar_t, REDUCE>)
, dim3(BLOCKS(src.numel())), dim3(THREADS), 0, stream,
src_data, index_info, out_data, E, K, N, src.numel());
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MIN || REDUCE == MAX)
hipLaunchKernelGGL(( scatter_arg_kernel<scalar_t>)
, dim3(BLOCKS(src.numel())), dim3(THREADS), 0, stream,
src_data, index_info, out_data, arg_out_data, E, K, N,
src.numel());
});
});
return std::make_tuple(out, arg_out);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
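// Generic __ldg overload used where the read-only data-cache load intrinsic
// is not available for a type: it simply dereferences the pointer.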
template<typename T>
__device__ T __ldg(const T* ptr) {
return *ptr;
}