Commit 0799bc08 authored by limm

support v2.1.0

parent 50e05e1e
cmake_minimum_required(VERSION 3.0)
project(torchscatter)
set(CMAKE_CXX_STANDARD 14)
set(TORCHSCATTER_VERSION 2.1.0)
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
if(WITH_CUDA)
enable_language(CUDA)
add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
add_definitions(-DWITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()
if (WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Development)
endif()
find_package(Torch REQUIRED)
file(GLOB HEADERS csrc/*.h)
file(GLOB OPERATOR_SOURCES csrc/cpu/*.h csrc/cpu/*.cpp csrc/*.cpp)
if(WITH_CUDA)
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu)
endif()
add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif()
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter)
target_include_directories(${PROJECT_NAME} INTERFACE
$<BUILD_INTERFACE:${HEADERS}>
$<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
set(TORCHSCATTER_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchScatter" CACHE STRING "install path for TorchScatterConfig.cmake")
configure_package_config_file(cmake/TorchScatterConfig.cmake.in
"${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfig.cmake"
INSTALL_DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR})
write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfigVersion.cmake
VERSION ${TORCHSCATTER_VERSION}
COMPATIBILITY AnyNewerVersion)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfigVersion.cmake
DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR})
install(TARGETS ${PROJECT_NAME}
EXPORT TorchScatterTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
install(EXPORT TorchScatterTargets
NAMESPACE TorchScatter::
DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR})
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME})
install(FILES
csrc/cpu/scatter_cpu.h
csrc/cpu/segment_coo_cpu.h
csrc/cpu/segment_csr_cpu.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu)
if(WITH_CUDA)
install(FILES
csrc/cuda/scatter_cuda.h
csrc/cuda/segment_coo_cuda.h
csrc/cuda/segment_csr_cuda.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda)
endif()
if(WITH_CUDA)
set_property(TARGET torch_cuda PROPERTY INTERFACE_COMPILE_OPTIONS "")
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
endif()
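As a rough sketch of how the shared library produced by this CMakeLists.txt can be consumed directly from Python (outside of the regular pip package), the built library can be loaded with `torch.ops.load_library`. The library path and the `torch_scatter::scatter_max` operator name below are assumptions for illustration only, not something this CMake file guarantees:

```python
# Minimal sketch: loading the CMake-built shared library straight into PyTorch.
# Assumes the library was installed as /usr/local/lib/libtorchscatter.so and that
# it registers its operators under the "torch_scatter" namespace (both assumptions).
import torch

torch.ops.load_library("/usr/local/lib/libtorchscatter.so")  # hypothetical install path

src = torch.tensor([[2.0, 0.0, 1.0, 4.0, 3.0]])
index = torch.tensor([[4, 5, 4, 2, 3]])
# Assumed operator schema mirroring the Python-level scatter_max signature.
out, argmax = torch.ops.torch_scatter.scatter_max(src, index, -1, None, None)
print(out, argmax)
```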
Metadata-Version: 2.1
Name: torch_scatter
Version: 2.0.9
Summary: PyTorch Extension Library of Optimized Scatter Operations
Home-page: https://github.com/rusty1s/pytorch_scatter
Author: Matthias Fey
Author-email: matthias.fey@tu-dortmund.de
License: MIT
Description: UNKNOWN
Keywords: pytorch,scatter,segment,gather
Platform: UNKNOWN
Requires-Python: >=3.6
Provides-Extra: test
# <div align="center"><strong>PyTorch Scatter</strong></div>
## Introduction
PyTorch Scatter is a small extension library of highly optimized sparse update (scatter and segment) operations for PyTorch, which are missing from the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group index" tensor. Segment operations require the "group index" tensor to be sorted, whereas scatter operations are not subject to this requirement. Official PyTorch Scatter GitHub repository: [https://github.com/rusty1s/pytorch_scatter](https://github.com/rusty1s/pytorch_scatter)
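As a minimal illustration of the distinction described above (not part of the original README): `scatter` accepts indices in any order, while `segment_coo` assumes a sorted index and `segment_csr` takes the equivalent CSR index pointer; all three reduce the same groups to the same result.

```python
import torch
from torch_scatter import scatter, segment_coo, segment_csr

src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
index = torch.tensor([0, 0, 1, 1, 2])     # sorted "group index"
indptr = torch.tensor([0, 2, 4, 5])       # CSR boundaries of the same groups

out1 = scatter(src, index, dim=0, reduce='sum')   # works for unsorted indices too
out2 = segment_coo(src, index, reduce='sum')      # requires a sorted index
out3 = segment_csr(src, indptr, reduce='sum')     # requires a CSR index pointer
assert torch.allclose(out1, out2) and torch.allclose(out1, out3)  # tensor([3., 7., 5.])
```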
## Installation
### Install via pip
pytorch-scatter whl packages can be downloaded from [http://10.6.10.68:8000/customized/torch-scatter/dtk2404](http://10.6.10.68:8000/customized/torch-scatter/dtk2404); currently only a Python 3.8 whl package is provided.
```shell
pip install torch_scatter*  # the downloaded torch_scatter whl package
```
### Install from source
#### Build environment setup
- Install the required dependencies:
```shell
pip install -r requirements.txt
```
- Download dtk24.04 from the 光合开发者社区 homepage, extract it to /opt/, and create a symlink:
```shell
cd /opt && ln -s dtk-24.04 dtk
source /opt/dtk/env.sh
```
- Install pytorch. pytorch whl packages can be downloaded from [http://10.6.10.68:8000/debug/pytorch/dtk24.10/hipify/](http://10.6.10.68:8000/debug/pytorch/dtk24.04/hipify/); download the pytorch whl package matching your python and dtk versions and install it:
```shell
pip install torch*  # the downloaded torch whl package
```
#### Build and install from source
```shell
git clone -b 2.1.0-release http://developer.hpccube.com/codes/aicomponent/torch-scatter.git
cd torch-scatter
python setup.py bdist_wheel
pip install dist/*.whl
```
## Unit tests
```shell
cd torch-scatter
python setup.py test
```
## Verification
```python
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
```
Expected output:
```
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])
print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
        [1, 4, 3, 5, 5, 5]])
```
## Known Issue
When running the unit tests after installation, the error `ImportError: Could not find module '_version_cpu'` may be raised. Search for the library file from the root directory / and symlink the library files into the package directory:
```
find / -name "_version_cpu.so"
cd /torch-scatter/torch_scatter
ln -s /usr/local/lib/python3.8/site-packages/torch-scatter/* .
```
- This library has not been adapted for CPU-only environments and supports DCU only; please run it in an environment with a DCU card.
- To use the full set of pyg features, also run `pip install torch-geometric`.
## References
- [README_ORIGIN](README_ORIGIN.md)
- [https://github.com/rusty1s/pytorch_scatter](https://github.com/rusty1s/pytorch_scatter)
@@ -54,39 +54,41 @@ conda install pytorch-scatter -c pyg

We alternatively provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).

#### PyTorch 1.13

To install the binaries for PyTorch 1.13.0, simply run

```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu116`, or `cu117` depending on your PyTorch installation.

|             | `cpu` | `cu116` | `cu117` |
|-------------|-------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      |
| **Windows** | ✅    | ✅      | ✅      |
| **macOS**   | ✅    |         |         |

#### PyTorch 1.12

To install the binaries for PyTorch 1.12.0, simply run

```
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+${CUDA}.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, or `cu116` depending on your PyTorch installation.

|             | `cpu` | `cu102` | `cu113` | `cu116` |
|-------------|-------|---------|---------|---------|
| **Linux**   | ✅    | ✅      | ✅      | ✅      |
| **Windows** | ✅    |         | ✅      | ✅      |
| **macOS**   | ✅    |         |         |         |

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2 and PyTorch 1.11.0 (following the same procedure).
For older versions, you need to explicitly specify the latest supported version number or install via `pip install --no-index` in order to prevent a manual installation from source.
You can look up the latest supported version number [here](https://data.pyg.org/whl).

### From source

@@ -141,18 +143,19 @@ tensor([[5, 5, 3, 4, 0, 1]

## Running tests

```
pytest
```

## C++ API

`torch-scatter` also offers a C++ API that contains C++ equivalents of the Python functions.
For this, we need to add `TorchLib` to the `-DCMAKE_PREFIX_PATH` (*e.g.*, it may exist in `{CONDA}/lib/python{X.X}/site-packages/torch` if installed via `conda`):

```
mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support, if needed
cmake -DCMAKE_PREFIX_PATH="..." ..
make
make install
```
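A convenient way to obtain a suitable `-DCMAKE_PREFIX_PATH` is to ask the installed PyTorch itself; a small sketch, assuming a working `torch` installation:

```python
# Print the CMake prefix path shipped with the installed PyTorch; the printed
# directory can be passed to cmake as -DCMAKE_PREFIX_PATH="<that path>".
import torch
print(torch.utils.cmake_prefix_path)
```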
import time
import itertools
import argparse
import torch
from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr
from scatter_segment import short_rows, long_rows, download, bold
@torch.no_grad()
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes[1:]:
try:
x = torch.randn((dim_size, size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = x.index_select(0, row)
out2 = gather_coo(x, row)
out3 = gather_csr(x, rowptr)
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
def time_func(func, x):
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.perf_counter()
if not args.with_backward:
with torch.no_grad():
for _ in range(iters):
func(x)
else:
x = x.requires_grad_()
for _ in range(iters):
out = func(x)
torch.autograd.grad(out, x, out, only_inputs=True)
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
return float('inf')
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def select(x):
return x.index_select(0, row)
def gather(x):
return x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
def gat_coo(x):
return gather_coo(x, row)
def gat_csr(x):
return gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], []
for size in sizes:
try:
x = torch.randn((dim_size, size), device=args.device)
t1 += [time_func(select, x)]
t2 += [time_func(gather, x)]
t3 += [time_func(gat_coo, x)]
t4 += [time_func(gat_csr, x)]
del x
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
for t in (t1, t2, t3, t4):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4])
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist()
name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SELECT ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('GAT ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('GAT_COO')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print('\t'.join([bold('GAT_CSR')] +
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
timing(dataset)
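The correctness check in the benchmark above asserts that `gather_coo` and `gather_csr` reproduce a plain `index_select`. A minimal standalone sketch of the same equivalence, using tiny hand-made inputs rather than the benchmark datasets:

```python
import torch
from torch_scatter import gather_coo, gather_csr

x = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
row = torch.tensor([0, 0, 1, 2, 2])      # sorted COO row index
rowptr = torch.tensor([0, 2, 3, 5])      # CSR row pointer for the same rows

out1 = x.index_select(0, row)            # reference result
out2 = gather_coo(x, row)                # gather driven by the COO index
out3 = gather_csr(x, rowptr)             # gather driven by the CSR pointer
assert torch.equal(out1, out2) and torch.equal(out1, out3)
```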
import time
import os.path as osp
import itertools
import argparse
import wget
import torch
from scipy.io import loadmat
from torch_scatter import scatter, segment_coo, segment_csr
short_rows = [
('DIMACS10', 'citationCiteseer'),
('SNAP', 'web-Stanford'),
]
long_rows = [
('Janna', 'StocF-1465'),
('GHS_psdef', 'ldoor'),
]
def download(dataset):
url = 'https://sparse.tamu.edu/mat/{}/{}.mat'
for group, name in itertools.chain(long_rows, short_rows):
if not osp.exists(f'{name}.mat'):
print(f'Downloading {group}/{name}:')
wget.download(url.format(group, name))
print('')
def bold(text, flag=True):
return f'\033[1m{text}\033[0m' if flag else text
@torch.no_grad()
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
out3 = segment_csr(x, rowptr, reduce='mean')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
out2 = segment_coo(x, row, reduce='min')
out3 = segment_csr(x, rowptr, reduce='min')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
out2 = segment_coo(x, row, reduce='max')
out3 = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
def time_func(func, x):
try:
if torch.cuda.is_available():
torch.cuda.synchronize()
t = time.perf_counter()
if not args.with_backward:
with torch.no_grad():
for _ in range(iters):
func(x)
else:
x = x.requires_grad_()
for _ in range(iters):
out = func(x)
out = out[0] if isinstance(out, tuple) else out
torch.autograd.grad(out, x, out, only_inputs=True)
if torch.cuda.is_available():
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
return float('inf')
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row2 = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def sca1_row(x):
out = x.new_zeros(dim_size, *x.size()[1:])
row_tmp = row.view(-1, 1).expand_as(x) if x.dim() > 1 else row
return out.scatter_add_(0, row_tmp, x)
def sca1_col(x):
out = x.new_zeros(dim_size, *x.size()[1:])
row2_tmp = row2.view(-1, 1).expand_as(x) if x.dim() > 1 else row2
return out.scatter_add_(0, row2_tmp, x)
def sca2_row(x):
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
def sca2_col(x):
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce)
def seg_csr(x):
return segment_csr(x, rowptr, reduce=args.reduce)
def dense1(x):
return getattr(torch, args.reduce)(x, dim=-2)
def dense2(x):
return getattr(torch, args.reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6, t7, t8 = [], [], [], [], [], [], [], []
for size in sizes:
try:
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
t1 += [time_func(sca1_row, x)]
t2 += [time_func(sca1_col, x)]
t3 += [time_func(sca2_row, x)]
t4 += [time_func(sca2_col, x)]
t5 += [time_func(seg_coo, x)]
t6 += [time_func(seg_csr, x)]
del x
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
for t in (t1, t2, t3, t4, t5, t6):
t.append(float('inf'))
try:
x = torch.randn((dim_size, int(avg_row_len + 1), size),
device=args.device)
t7 += [time_func(dense1, x)]
x = x.view(dim_size, size, int(avg_row_len + 1))
t8 += [time_func(dense2, x)]
del x
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
for t in (t7, t8):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4, t5, t6, t7, t8])
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist()
name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SCA1_ROW')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('SCA1_COL')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('SCA2_ROW')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print('\t'.join([bold('SCA2_COL')] +
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
print('\t'.join([bold('SEG_COO ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])]))
print('\t'.join([bold('SEG_CSR ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])]))
print('\t'.join([bold('DENSE1 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t7, winner[6])]))
print('\t'.join([bold('DENSE2 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t8, winner[7])]))
print()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
timing(dataset)
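Both benchmark scripts derive their `rowptr` (CSR) and `row` (COO) index tensors from a SciPy sparse matrix loaded via `loadmat`. A small sketch of that conversion on a toy matrix (the matrix values below are made up purely for illustration):

```python
import numpy as np
import torch
from scipy.sparse import csr_matrix

# Toy 3x3 sparse matrix standing in for the SuiteSparse matrices used above.
mat = csr_matrix(np.array([[1, 0, 2],
                           [0, 0, 3],
                           [4, 5, 0]]))

rowptr = torch.from_numpy(mat.indptr).to(torch.long)     # CSR row pointer
row = torch.from_numpy(mat.tocoo().row).to(torch.long)   # sorted COO row index
dim_size = rowptr.size(0) - 1                            # number of rows
avg_row_len = row.size(0) / dim_size                     # nonzeros per row
print(rowptr, row, dim_size, avg_row_len)
```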
# TorchScatterConfig.cmake
# --------------------
#
# Exported targets:: Scatter
#
@PACKAGE_INIT@
set(PN TorchScatter)
set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@")
set(${PN}_LIBRARY "")
set(${PN}_DEFINITIONS USING_${PN})
check_required_components(${PN})
if(NOT (CMAKE_VERSION VERSION_LESS 3.0))
#-----------------------------------------------------------------------------
# Don't include targets if this file is being picked up by another
# project which has already built this as a subproject
#-----------------------------------------------------------------------------
if(NOT TARGET ${PN}::TorchScatter)
include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake")
if(NOT TARGET torch_library)
find_package(Torch REQUIRED)
endif()
if(NOT TARGET Python3::Python)
find_package(Python3 COMPONENTS Development)
endif()
target_link_libraries(TorchScatter::TorchScatter INTERFACE ${TORCH_LIBRARIES} Python3::Python)
if(@WITH_CUDA@)
target_compile_definitions(TorchScatter::TorchScatter INTERFACE WITH_CUDA)
endif()
endif()
endif()
```
./build_conda.sh 3.9 1.13.0 cu116 # python, pytorch and cuda version
```
#!/bin/bash
export PYTHON_VERSION=$1
export TORCH_VERSION=$2
export CUDA_VERSION=$3
export CONDA_PYTORCH_CONSTRAINT="pytorch==${TORCH_VERSION%.*}.*"
if [ "${CUDA_VERSION}" = "cpu" ]; then
export CONDA_CUDATOOLKIT_CONSTRAINT="cpuonly # [not osx]"
else
case $CUDA_VERSION in
cu117)
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==11.7.*"
;;
cu116)
if [ "${TORCH_VERSION}" = "1.12.0" ]; then
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.6.*"
else
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==11.6.*"
fi
;;
cu115)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.5.*"
;;
cu113)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.3.*"
;;
cu111)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.1.*"
;;
cu102)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.2.*"
;;
cu101)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.1.*"
;;
*)
echo "Unrecognized CUDA_VERSION=$CUDA_VERSION"
exit 1
;;
esac
fi
echo "PyTorch $TORCH_VERSION+$CUDA_VERSION"
echo "- $CONDA_PYTORCH_CONSTRAINT"
echo "- $CONDA_CUDATOOLKIT_CONSTRAINT"
if [ "${TORCH_VERSION}" = "1.12.0" ] && [ "${CUDA_VERSION}" = "cu116" ]; then
conda build . -c pytorch -c default -c nvidia -c conda-forge --output-folder "$HOME/conda-bld"
else
conda build . -c pytorch -c default -c nvidia --output-folder "$HOME/conda-bld"
fi
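For readers who find the case statement above hard to scan, the CUDA_VERSION-to-constraint mapping it implements can be summarized as a lookup table. A purely illustrative Python sketch of the same mapping (not used by the build itself):

```python
# Same mapping as the case statement in build_conda.sh, expressed as a dictionary.
CUDATOOLKIT_CONSTRAINT = {
    'cpu':   'cpuonly  # [not osx]',
    'cu117': 'pytorch-cuda==11.7.*',
    'cu116': 'pytorch-cuda==11.6.*',   # or cudatoolkit==11.6.* for PyTorch 1.12.0
    'cu115': 'cudatoolkit==11.5.*',
    'cu113': 'cudatoolkit==11.3.*',
    'cu111': 'cudatoolkit==11.1.*',
    'cu102': 'cudatoolkit==10.2.*',
    'cu101': 'cudatoolkit==10.1.*',
}

def constraint(cuda_version: str) -> str:
    """Return the conda package constraint for a given CUDA_VERSION tag."""
    if cuda_version not in CUDATOOLKIT_CONSTRAINT:
        raise ValueError(f'Unrecognized CUDA_VERSION={cuda_version}')
    return CUDATOOLKIT_CONSTRAINT[cuda_version]

print(constraint('cu116'))
```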
package:
name: pytorch-scatter
version: 2.1.0
source:
path: ../..
requirements:
build:
- {{ compiler('c') }} # [win]
host:
- pip
- python {{ environ.get('PYTHON_VERSION') }}
- {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
- {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
run:
- python {{ environ.get('PYTHON_VERSION') }}
- {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
- {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
build:
string: py{{ environ.get('PYTHON_VERSION').replace('.', '') }}_torch_{{ environ['TORCH_VERSION'] }}_{{ environ['CUDA_VERSION'] }}
script: pip install .
script_env:
- FORCE_CUDA
- TORCH_CUDA_ARCH_LIST
test:
imports:
- torch_scatter
about:
home: https://github.com/rusty1s/pytorch_scatter
license: MIT
summary: PyTorch Extension Library of Optimized Scatter Operations
#pragma once
-#include <torch/extension.h>
+#include "../extensions.h"
#define MAX_TENSORINFO_DIMS 25
......
@@ -57,7 +57,7 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
  auto N = out.size(dim);
  auto index_info = getTensorInfo<int64_t>(index);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "scatter_cpu", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
......
#pragma once
-#include <torch/extension.h>
+#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
......
@@ -69,7 +69,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_coo_cpu", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
    scalar_t *count_data = nullptr;
@@ -178,7 +178,7 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_coo_cpu", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
......
#pragma once
-#include <torch/extension.h>
+#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
......
@@ -57,7 +57,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
  std::vector<int64_t> args(K);
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_csr_cpu", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
@@ -135,7 +135,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
  auto indptr_info = getTensorInfo<int64_t>(indptr);
  auto stride = indptr_info.strides[indptr_info.dims - 1];
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_csr_cpu", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();
......
#pragma once
-#include <torch/extension.h>
+#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
......
#pragma once
-#include <torch/extension.h>
+#include "../extensions.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#pragma once
#define ATOMIC(NAME) \
template <typename scalar, size_t size> struct Atomic##NAME##IntegerImpl; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 1> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)(address - ((size_t)address & 3)); \
uint32_t old = *address_as_ui; \
uint32_t shift = ((size_t)address & 3) * 8; \
uint32_t sum; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, scalar((old >> shift) & 0xff)); \
old = (old & ~(0x000000ff << shift)) | (sum << shift); \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 2> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = \
(uint32_t *)((char *)address - ((size_t)address & 2)); \
uint32_t old = *address_as_ui; \
uint32_t sum; \
uint32_t newval; \
uint32_t assumed; \
\
do { \
assumed = old; \
sum = OP(val, (size_t)address & 2 ? scalar(old >> 16) \
: scalar(old & 0xffff)); \
newval = (size_t)address & 2 ? (old & 0xffff) | (sum << 16) \
: (old & 0xffff0000) | sum; \
old = atomicCAS(address_as_ui, assumed, newval); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
uint32_t *address_as_ui = (uint32_t *)address; \
uint32_t old = *address_as_ui; \
uint32_t assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ui, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##IntegerImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long *address_as_ull = (unsigned long long *)address; \
unsigned long long old = *address_as_ull; \
unsigned long long assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_ull, assumed, OP(val, (scalar)old)); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar, size_t size> struct Atomic##NAME##DecimalImpl; \
\
template <> struct Atomic##NAME##DecimalImpl<at::Half, 2> { \
inline __device__ void operator()(at::Half *address, at::Half val) { \
unsigned int *address_as_ui = \
(unsigned int *)((char *)address - ((size_t)address & 2)); \
unsigned int old = *address_as_ui; \
unsigned int assumed; \
\
do { \
assumed = old; \
at::Half hsum; \
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
hsum = OP(hsum, val); \
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
: (old & 0xffff0000) | hsum.x; \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <> struct Atomic##NAME##DecimalImpl<at::BFloat16, 2> { \
inline __device__ void operator()(at::BFloat16 *address, at::BFloat16 val){\
unsigned int *address_as_ui = \
(unsigned int *)((char *)address - ((size_t)address & 2)); \
unsigned int old = *address_as_ui; \
unsigned int assumed; \
\
do { \
assumed = old; \
at::BFloat16 hsum; \
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); \
hsum = OP(hsum, val); \
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) \
: (old & 0xffff0000) | hsum.x; \
old = atomicCAS(address_as_ui, assumed, old); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 4> { \
inline __device__ void operator()(scalar *address, scalar val) { \
int *address_as_i = (int *)address; \
int old = *address_as_i; \
int assumed; \
\
do { \
assumed = old; \
old = atomicCAS(address_as_i, assumed, \
__float_as_int(OP(val, __int_as_float(assumed)))); \
} while (assumed != old); \
} \
}; \
\
template <typename scalar> struct Atomic##NAME##DecimalImpl<scalar, 8> { \
inline __device__ void operator()(scalar *address, scalar val) { \
unsigned long long int *address_as_ull = \
(unsigned long long int *)address; \
unsigned long long int old = *address_as_ull; \
unsigned long long int assumed; \
\
do { \
assumed = old; \
old = atomicCAS( \
address_as_ull, assumed, \
__double_as_longlong(OP(val, __longlong_as_double(assumed)))); \
} while (assumed != old); \
} \
};
#define OP(X, Y) Y + X
ATOMIC(Add)
#undef OP
static inline __device__ void atomAdd(uint8_t *address, uint8_t val) {
AtomicAddIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomAdd(int8_t *address, int8_t val) {
AtomicAddIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomAdd(int16_t *address, int16_t val) {
AtomicAddIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomAdd(int32_t *address, int32_t val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(int64_t *address, int64_t val) {
AtomicAddIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700 || CUDA_VERSION < 10000))
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
AtomicAddDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
#else
static inline __device__ void atomAdd(at::Half *address, at::Half val) {
atomicAdd(reinterpret_cast<__half *>(address), val);
}
#endif
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
AtomicAddDecimalImpl<double, sizeof(double)>()(address, val);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
static inline __device__ void atomAdd(at::BFloat16 *address, at::BFloat16 val) {
AtomicAddDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) Y *X
ATOMIC(Mul)
#undef OP
static inline __device__ void atomMul(uint8_t *address, uint8_t val) {
AtomicMulIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMul(int8_t *address, int8_t val) {
AtomicMulIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMul(int16_t *address, int16_t val) {
AtomicMulIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMul(int32_t *address, int32_t val) {
AtomicMulIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomMul(int64_t *address, int64_t val) {
AtomicMulIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMul(float *address, float val) {
AtomicMulDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMul(at::Half *address, at::Half val) {
AtomicMulDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMul(double *address, double val) {
AtomicMulDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMul(at::BFloat16 *address, at::BFloat16 val) {
AtomicMulDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) Y / X
ATOMIC(Div)
#undef OP
static inline __device__ void atomDiv(uint8_t *address, uint8_t val) {
AtomicDivIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomDiv(int8_t *address, int8_t val) {
AtomicDivIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomDiv(int16_t *address, int16_t val) {
AtomicDivIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomDiv(int32_t *address, int32_t val) {
AtomicDivIntegerImpl<int32_t, sizeof(int32_t)>()(address, val);
}
static inline __device__ void atomDiv(int64_t *address, int64_t val) {
AtomicDivIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomDiv(at::Half *address, at::Half val) {
AtomicDivDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomDiv(float *address, float val) {
AtomicDivDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomDiv(double *address, double val) {
AtomicDivDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomDiv(at::BFloat16 *address, at::BFloat16 val) {
AtomicDivDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) max(Y, X)
ATOMIC(Max)
#undef OP
static inline __device__ void atomMax(uint8_t *address, uint8_t val) {
AtomicMaxIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMax(int8_t *address, int8_t val) {
AtomicMaxIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMax(int16_t *address, int16_t val) {
AtomicMaxIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMax(int32_t *address, int32_t val) {
atomicMax(address, val);
}
static inline __device__ void atomMax(int64_t *address, int64_t val) {
AtomicMaxIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMax(at::Half *address, at::Half val) {
AtomicMaxDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMax(float *address, float val) {
AtomicMaxDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMax(double *address, double val) {
AtomicMaxDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMax(at::BFloat16 *address, at::BFloat16 val) {
AtomicMaxDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
#define OP(X, Y) min(Y, X)
ATOMIC(Min)
#undef OP
static inline __device__ void atomMin(uint8_t *address, uint8_t val) {
AtomicMinIntegerImpl<uint8_t, sizeof(uint8_t)>()(address, val);
}
static inline __device__ void atomMin(int8_t *address, int8_t val) {
AtomicMinIntegerImpl<int8_t, sizeof(int8_t)>()(address, val);
}
static inline __device__ void atomMin(int16_t *address, int16_t val) {
AtomicMinIntegerImpl<int16_t, sizeof(int16_t)>()(address, val);
}
static inline __device__ void atomMin(int32_t *address, int32_t val) {
atomicMin(address, val);
}
static inline __device__ void atomMin(int64_t *address, int64_t val) {
AtomicMinIntegerImpl<int64_t, sizeof(int64_t)>()(address, val);
}
static inline __device__ void atomMin(at::Half *address, at::Half val) {
AtomicMinDecimalImpl<at::Half, sizeof(at::Half)>()(address, val);
}
static inline __device__ void atomMin(float *address, float val) {
AtomicMinDecimalImpl<float, sizeof(float)>()(address, val);
}
static inline __device__ void atomMin(double *address, double val) {
AtomicMinDecimalImpl<double, sizeof(double)>()(address, val);
}
static inline __device__ void atomMin(at::BFloat16 *address, at::BFloat16 val) {
AtomicMinDecimalImpl<at::BFloat16, sizeof(at::BFloat16)>()(address, val);
}
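These CAS-based atomics back the CUDA scatter kernels, and because floating-point addition is not associative, atomically accumulated results can differ slightly between runs depending on the order in which threads commit their updates. This is why the correctness checks in the benchmark scripts compare results with `torch.allclose(..., atol=1e-4)` rather than exact equality. A minimal CPU-side sketch of the underlying effect (the numbers are illustrative only):

```python
import torch

torch.manual_seed(0)
x = torch.randn(100000, dtype=torch.float32)

s1 = x.sum()                    # one summation order
s2 = x.flip(0).sum()            # a different summation order
print((s1 - s2).abs().item())   # tiny, usually nonzero, difference
assert torch.allclose(s1, s2, atol=1e-4)
```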