This directory contains eggs that were downloaded by setuptools to build, test, and run plug-ins.
This directory caches those eggs to prevent repeated downloads.
However, it is safe to delete this directory.
Copyright Jason R. Coombs
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to
deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
sell copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
Metadata-Version: 2.1
Name: pytest-runner
Version: 6.0.0
Summary: Invoke py.test as distutils command with dependency resolution
Home-page: https://github.com/pytest-dev/pytest-runner/
Author: Jason R. Coombs
Author-email: jaraco@jaraco.com
License: UNKNOWN
Platform: UNKNOWN
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Framework :: Pytest
Requires-Python: >=3.7
License-File: LICENSE
Provides-Extra: docs
Requires-Dist: sphinx ; extra == 'docs'
Requires-Dist: jaraco.packaging (>=9) ; extra == 'docs'
Requires-Dist: rst.linker (>=1.9) ; extra == 'docs'
Requires-Dist: jaraco.tidelift (>=1.4) ; extra == 'docs'
Provides-Extra: testing
Requires-Dist: pytest (>=6) ; extra == 'testing'
Requires-Dist: pytest-checkdocs (>=2.4) ; extra == 'testing'
Requires-Dist: pytest-flake8 ; extra == 'testing'
Requires-Dist: pytest-cov ; extra == 'testing'
Requires-Dist: pytest-enabler (>=1.0.1) ; extra == 'testing'
Requires-Dist: pytest-virtualenv ; extra == 'testing'
Requires-Dist: types-setuptools ; extra == 'testing'
Requires-Dist: pytest-black (>=0.3.7) ; (platform_python_implementation != "PyPy") and extra == 'testing'
Requires-Dist: pytest-mypy (>=0.9.1) ; (platform_python_implementation != "PyPy") and extra == 'testing'
.. image:: https://img.shields.io/pypi/v/pytest-runner.svg
:target: `PyPI link`_
.. image:: https://img.shields.io/pypi/pyversions/pytest-runner.svg
:target: `PyPI link`_
.. _PyPI link: https://pypi.org/project/pytest-runner
.. image:: https://github.com/pytest-dev/pytest-runner/workflows/tests/badge.svg
:target: https://github.com/pytest-dev/pytest-runner/actions?query=workflow%3A%22tests%22
:alt: tests
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black
:alt: Code style: Black
.. .. image:: https://readthedocs.org/projects/skeleton/badge/?version=latest
.. :target: https://skeleton.readthedocs.io/en/latest/?badge=latest
.. image:: https://img.shields.io/badge/skeleton-2022-informational
:target: https://blog.jaraco.com/skeleton
.. image:: https://tidelift.com/badges/package/pypi/pytest-runner
:target: https://tidelift.com/subscription/pkg/pypi-pytest-runner?utm_source=pypi-pytest-runner&utm_medium=readme
Setup scripts can use pytest-runner to add setup.py test support for pytest.
Deprecation Notice
==================
pytest-runner depends on deprecated features of setuptools and relies on features that break security
mechanisms in pip. For example, ``setup_requires`` and ``tests_require`` bypass ``pip --require-hashes``.
See also `pypa/setuptools#1684 <https://github.com/pypa/setuptools/issues/1684>`_.
It is recommended that you:

- Remove ``'pytest-runner'`` from your ``setup_requires``, preferably removing the ``setup_requires`` option.
- Remove ``'pytest'`` and any other testing requirements from ``tests_require``, preferably removing the ``tests_require`` option.
- Select a tool to bootstrap and then run tests, such as tox (see the sketch after this list).
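As a minimal sketch of that last recommendation, assuming the package declares a
``testing`` extra that pulls in pytest (an assumption, not something this README
specifies), a ``tox.ini`` along these lines bootstraps an isolated environment
and runs the suite::

    [tox]
    envlist = py

    [testenv]
    # install the package together with its (assumed) 'testing' extra,
    # then run pytest against the installed package
    extras = testing
    commands = pytest {posargs}
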
Usage
=====
- Add 'pytest-runner' to your 'setup_requires'. Pin to '>=2.0,<3dev' (or
similar) to avoid pulling in incompatible versions.
- Include 'pytest' and any other testing requirements in 'tests_require'.
- Invoke tests with ``setup.py pytest``.
- Pass ``--index-url`` to have test requirements downloaded from an alternate
index URL (unnecessary if specified for easy_install in setup.cfg).
- Pass additional py.test command-line options using ``--addopts``.
- Set permanent options for the ``python setup.py pytest`` command (like ``index-url``)
in the ``[pytest]`` section of ``setup.cfg``.
- Set permanent options for the ``py.test`` run (like ``addopts`` or ``pep8ignore``) in the ``[pytest]``
section of ``pytest.ini`` or ``tox.ini`` or put them in the ``[tool:pytest]``
section of ``setup.cfg``. See `pytest issue 567
<https://github.com/pytest-dev/pytest/issues/567>`_.
- Optionally, set ``test=pytest`` in the ``[aliases]`` section of ``setup.cfg``
  to cause ``python setup.py test`` to invoke pytest (see the ``setup.cfg`` sketch after this list).
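Putting the last two items together, a ``setup.cfg`` might look like the
following (a sketch; the ``addopts`` value is a placeholder)::

    [aliases]
    test = pytest

    [tool:pytest]
    addopts = --verbose
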
Example
=======
The simplest usage looks like this in setup.py::

    setup(
        setup_requires=[
            'pytest-runner',
        ],
        tests_require=[
            'pytest',
        ],
    )
Additional dependencies required to run the tests (e.g. mock or pytest
plugins) may be added to tests_require and will be downloaded and
required by the session before invoking pytest (see the sketch below).
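For example, a suite that relies on pytest plugins might declare them
alongside pytest itself (the specific plugins named here are illustrative
only, not requirements of pytest-runner)::

    setup(
        setup_requires=[
            'pytest-runner',
        ],
        tests_require=[
            'pytest',
            'pytest-cov',  # illustrative plugin for coverage reporting
            'mock',
        ],
    )
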
Follow `this search on github
<https://github.com/search?utf8=%E2%9C%93&q=filename%3Asetup.py+pytest-runner&type=Code&ref=searchresults>`_
for examples of real-world usage.
Standalone Example
==================
This technique is deprecated - if you have standalone scripts
you wish to invoke with dependencies, `use pip-run
<https://pypi.org/project/pip-run>`_.
Although ``pytest-runner`` is typically used to add pytest test
runner support to maintained packages, ``pytest-runner`` may
also be used to create standalone tests. Consider `this example
failure <https://gist.github.com/jaraco/d979a558bc0bf2194c23>`_,
reported in `jsonpickle #117
<https://github.com/jsonpickle/jsonpickle/issues/117>`_
or `this MongoDB test
<https://gist.github.com/jaraco/0b9e482f5c0a1300dc9a>`_
demonstrating a technique that works even when dependencies
are required in the test.
Either example file may be cloned or downloaded and simply run on
any system with Python and Setuptools. It will download the
specified dependencies and run the tests. Afterward, the
cloned directory can be removed, and with it all trace of
invoking the test. No other dependencies are needed and no
system configuration is altered.
Then, anyone trying to replicate the failure can do so easily
and with all the power of pytest (rewritten assertions,
rich comparisons, interactive debugging, extensibility through
plugins, etc).
As a result, the communication barrier for describing and
replicating failures is made almost trivially low.
Considerations
==============
Conditional Requirement
-----------------------
Because it uses Setuptools setup_requires, pytest-runner will install itself
on every invocation of setup.py. In some cases, this causes delays for
invocations of setup.py that will never invoke pytest-runner. To help avoid
this contingency, consider requiring pytest-runner only when pytest
is invoked::
    import sys

    needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv)
    pytest_runner = ['pytest-runner'] if needs_pytest else []

    # ...

    setup(
        # ...
        setup_requires=[
            # ... (other setup requirements)
        ] + pytest_runner,
    )
For Enterprise
==============
Available as part of the Tidelift Subscription.
This project and the maintainers of thousands of other packages are working with Tidelift to deliver one enterprise subscription that covers all of the open source you use.
`Learn more <https://tidelift.com/subscription/pkg/pypi-pytest-runner?utm_source=pypi-pytest-runner&utm_medium=referral&utm_campaign=github>`_.
Security Contact
================
To report a security vulnerability, please use the
`Tidelift security contact <https://tidelift.com/security>`_.
Tidelift will coordinate the fix and disclosure.
ptr/__init__.py,sha256=0UfzhCooVgCNTBwVEOPOVGEPck4pnl_6PTfsC-QzNGM,6730
pytest_runner-6.0.0.dist-info/LICENSE,sha256=2z8CRrH5J48VhFuZ_sR4uLUG63ZIeZNyL4xuJUKF-vg,1050
pytest_runner-6.0.0.dist-info/METADATA,sha256=xa7jfGba2yXK6_27FdHmVJzb9SifCjm_EBVxNXC8R6w,7381
pytest_runner-6.0.0.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
pytest_runner-6.0.0.dist-info/entry_points.txt,sha256=BqezBqeO63XyzSYmHYE58gKEFIjJUd-XdsRQkXHy2ig,58
pytest_runner-6.0.0.dist-info/top_level.txt,sha256=DPzHbWlKG8yq8EOD5UgEvVNDWeJRPyimrwfShwV6Iuw,4
pytest_runner-6.0.0.dist-info/RECORD,,
Wheel-Version: 1.0
Generator: bdist_wheel (0.37.1)
Root-Is-Purelib: true
Tag: py3-none-any
[distutils.commands]
ptr = ptr:PyTest
pytest = ptr:PyTest
[docs]
sphinx
jaraco.packaging>=9
rst.linker>=1.9
jaraco.tidelift>=1.4
[testing]
pytest>=6
pytest-checkdocs>=2.4
pytest-flake8
pytest-cov
pytest-enabler>=1.0.1
pytest-virtualenv
types-setuptools
pytest-black>=0.3.7
pytest-mypy>=0.9.1
"""
Implementation
"""
import os as _os
import shlex as _shlex
import contextlib as _contextlib
import sys as _sys
import operator as _operator
import itertools as _itertools
import warnings as _warnings
import pkg_resources
import setuptools.command.test as orig
from setuptools import Distribution
@_contextlib.contextmanager
def _save_argv(repl=None):
    saved = _sys.argv[:]
    if repl is not None:
        _sys.argv[:] = repl
    try:
        yield saved
    finally:
        _sys.argv[:] = saved
class CustomizedDist(Distribution):
    allow_hosts = None
    index_url = None

    def fetch_build_egg(self, req):
        """Specialized version of Distribution.fetch_build_egg
        that respects allow_hosts and index_url."""
        from setuptools.command.easy_install import easy_install

        dist = Distribution({'script_args': ['easy_install']})
        dist.parse_config_files()
        opts = dist.get_option_dict('easy_install')
        keep = (
            'find_links',
            'site_dirs',
            'index_url',
            'optimize',
            'allow_hosts',
        )
        for key in list(opts):
            if key not in keep:
                del opts[key]  # don't use any other settings
        if self.dependency_links:
            links = self.dependency_links[:]
            if 'find_links' in opts:
                links = opts['find_links'][1].split() + links
            opts['find_links'] = ('setup', links)
        if self.allow_hosts:
            opts['allow_hosts'] = ('test', self.allow_hosts)
        if self.index_url:
            opts['index_url'] = ('test', self.index_url)
        install_dir_func = getattr(self, 'get_egg_cache_dir', _os.getcwd)
        install_dir = install_dir_func()
        cmd = easy_install(
            dist,
            args=["x"],
            install_dir=install_dir,
            exclude_scripts=True,
            always_copy=False,
            build_directory=None,
            editable=False,
            upgrade=False,
            multi_version=True,
            no_report=True,
            user=False,
        )
        cmd.ensure_finalized()
        return cmd.easy_install(req)
class PyTest(orig.test):
    """
    >>> import setuptools
    >>> dist = setuptools.Distribution()
    >>> cmd = PyTest(dist)
    """

    user_options = [
        ('extras', None, "Install (all) setuptools extras when running tests"),
        (
            'index-url=',
            None,
            "Specify an index url from which to retrieve dependencies",
        ),
        (
            'allow-hosts=',
            None,
            "Whitelist of comma-separated hosts to allow "
            "when retrieving dependencies",
        ),
        (
            'addopts=',
            None,
            "Additional options to be passed verbatim to the pytest runner",
        ),
    ]

    def initialize_options(self):
        self.extras = False
        self.index_url = None
        self.allow_hosts = None
        self.addopts = []
        self.ensure_setuptools_version()

    @staticmethod
    def ensure_setuptools_version():
        """
        Due to the fact that pytest-runner is often required (via
        setup-requires directive) by toolchains that never invoke
        it (i.e. they're only installing the package, not testing it),
        instead of declaring the dependency in the package
        metadata, assert the requirement at run time.
        """
        pkg_resources.require('setuptools>=27.3')

    def finalize_options(self):
        if self.addopts:
            self.addopts = _shlex.split(self.addopts)

    @staticmethod
    def marker_passes(marker):
        """
        Given an environment marker, return True if the marker is valid
        and matches this environment.
        """
        return (
            not marker
            or not pkg_resources.invalid_marker(marker)
            and pkg_resources.evaluate_marker(marker)
        )

    def install_dists(self, dist):
        """
        Extend install_dists to include extras support
        """
        return _itertools.chain(
            orig.test.install_dists(dist), self.install_extra_dists(dist)
        )

    def install_extra_dists(self, dist):
        """
        Install extras that are indicated by markers or
        install all extras if '--extras' is indicated.
        """
        extras_require = dist.extras_require or {}
        spec_extras = (
            (spec.partition(':'), reqs) for spec, reqs in extras_require.items()
        )
        matching_extras = (
            reqs
            for (name, sep, marker), reqs in spec_extras
            # include unnamed extras or all if self.extras indicated
            if (not name or self.extras)
            # never include extras that fail to pass marker eval
            and self.marker_passes(marker)
        )
        results = list(map(dist.fetch_build_eggs, matching_extras))
        return _itertools.chain.from_iterable(results)

    @staticmethod
    def _warn_old_setuptools():
        msg = (
            "pytest-runner will stop working on this version of setuptools; "
            "please upgrade to setuptools 30.4 or later or pin to "
            "pytest-runner < 5."
        )
        ver_str = pkg_resources.get_distribution('setuptools').version
        ver = pkg_resources.parse_version(ver_str)
        if ver < pkg_resources.parse_version('30.4'):
            _warnings.warn(msg)

    def run(self):
        """
        Override run to ensure requirements are available in this session (but
        don't install them anywhere).
        """
        self._warn_old_setuptools()
        dist = CustomizedDist()
        for attr in 'allow_hosts index_url'.split():
            setattr(dist, attr, getattr(self, attr))
        for attr in (
            'dependency_links install_requires tests_require extras_require '
        ).split():
            setattr(dist, attr, getattr(self.distribution, attr))
        installed_dists = self.install_dists(dist)
        if self.dry_run:
            self.announce('skipping tests (dry run)')
            return
        paths = map(_operator.attrgetter('location'), installed_dists)
        with self.paths_on_pythonpath(paths):
            with self.project_on_sys_path():
                return self.run_tests()

    @property
    def _argv(self):
        return ['pytest'] + self.addopts

    def run_tests(self):
        """
        Invoke pytest, replacing argv. Return result code.
        """
        with _save_argv(_sys.argv[:1] + self.addopts):
            result_code = __import__('pytest').main()
            if result_code:
                raise SystemExit(result_code)
Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
include README.md
include LICENSE
recursive-exclude test *
recursive-include csrc *
Metadata-Version: 1.2
Name: torch_spline_conv
Version: 1.2.1
Summary: Implementation of the Spline-Based Convolution Operator of SplineCNN in PyTorch
Home-page: https://github.com/rusty1s/pytorch_spline_conv
Author: Matthias Fey
Author-email: matthias.fey@tu-dortmund.de
License: MIT
Description: UNKNOWN
Keywords: pytorch,geometric-deep-learning,graph-neural-networks,spline-cnn
Platform: UNKNOWN
Requires-Python: >=3.6
[pypi-image]: https://badge.fury.io/py/torch-spline-conv.svg
[pypi-url]: https://pypi.python.org/pypi/torch-spline-conv
[build-image]: https://travis-ci.org/rusty1s/pytorch_spline_conv.svg?branch=master
[build-url]: https://travis-ci.org/rusty1s/pytorch_spline_conv
[coverage-image]: https://codecov.io/gh/rusty1s/pytorch_spline_conv/branch/master/graph/badge.svg
[coverage-url]: https://codecov.io/github/rusty1s/pytorch_spline_conv?branch=master
# Spline-Based Convolution Operator of SplineCNN
[![PyPI Version][pypi-image]][pypi-url]
[![Build Status][build-image]][build-url]
[![Code Coverage][coverage-image]][coverage-url]
--------------------------------------------------------------------------------
This is a PyTorch implementation of the spline-based convolution operator of SplineCNN, as described in our paper:
Matthias Fey, Jan Eric Lenssen, Frank Weichert, Heinrich Müller: [SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels](https://arxiv.org/abs/1711.08920) (CVPR 2018)
The operator works on all floating point data types and is implemented for both CPU and GPU.
## Installation
### Binaries
We provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://pytorch-geometric.com/whl).
#### PyTorch 1.7.0
To install the binaries for PyTorch 1.7.0, simply run
```
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu101`, `cu102`, or `cu110` depending on your PyTorch installation.
| | `cpu` | `cu92` | `cu101` | `cu102` | `cu110` |
|-------------|-------|--------|---------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ❌ | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | | |
#### PyTorch 1.6.0
To install the binaries for PyTorch 1.6.0, simply run
```
pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.6.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu101` or `cu102` depending on your PyTorch installation.
| | `cpu` | `cu92` | `cu101` | `cu102` |
|-------------|-------|--------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ❌ | ✅ | ✅ |
| **macOS** | ✅ | | | |
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0 and PyTorch 1.5.0 (following the same procedure).
### From source
Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
$ echo $CPATH
>>> /usr/local/cuda/include:...
```
Then run:
```
pip install torch-spline-conv
```
When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```
## Usage
```python
from torch_spline_conv import spline_conv
out = spline_conv(x,
                  edge_index,
                  pseudo,
                  weight,
                  kernel_size,
                  is_open_spline,
                  degree=1,
                  norm=True,
                  root_weight=None,
                  bias=None)
```
Applies the spline-based convolution operator
<p align="center">
<img width="50%" src="https://user-images.githubusercontent.com/6945922/38684093-36d9c52e-3e6f-11e8-9021-db054223c6b9.png" />
</p>
over several node features of an input graph.
The kernel function is defined over the weighted B-spline tensor product basis, as shown below for different B-spline degrees.
<p align="center">
<img width="45%" src="https://user-images.githubusercontent.com/6945922/38685443-3a2a0c68-3e72-11e8-8e13-9ce9ad8fe43e.png" />
<img width="45%" src="https://user-images.githubusercontent.com/6945922/38685459-42b2bcae-3e72-11e8-88cc-4b61e41dbd93.png" />
</p>
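For reference, the operator shown in the images above can be written out as follows. This is a sketch reconstructed from the SplineCNN paper rather than taken from this README: `N(i)` denotes the neighbors of node `i`, `u(i, j)` the pseudo-coordinates of edge `(i, j)`, `P` the index set of the B-spline tensor product basis, and `N_{i,p_i}^m` the degree-`m` B-spline basis functions:

```latex
% Sketch of the spline-based convolution operator, following SplineCNN:
(f \star g)(i) = \frac{1}{|\mathcal{N}(i)|}
    \sum_{l=1}^{M_{\mathrm{in}}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(\mathbf{u}(i, j)),
\qquad
g_l(\mathbf{u}) = \sum_{\mathbf{p} \in \mathcal{P}}
    w_{\mathbf{p}, l} \cdot \prod_{i=1}^{d} N_{i, p_i}^{m}(u_i)
```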
### Parameters
* **x** *(Tensor)* - Input node features of shape `(number_of_nodes x in_channels)`.
* **edge_index** *(LongTensor)* - Graph edges, given by source and target indices, of shape `(2 x number_of_edges)`.
* **pseudo** *(Tensor)* - Edge attributes, *i.e.* pseudo-coordinates, of shape `(number_of_edges x number_of_edge_attributes)` in the fixed interval [0, 1].
* **weight** *(Tensor)* - Trainable weight parameters of shape `(kernel_size x in_channels x out_channels)`.
* **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
* **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
* **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
* **norm** *(bool, optional)* - Whether to normalize output by node degree. (default: `True`)
* **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
* **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)
### Returns
* **out** *(Tensor)* - Output node features of shape `(number_of_nodes x out_channels)`.
### Example
```python
import torch
from torch_spline_conv import spline_conv
x = torch.rand((4, 2), dtype=torch.float) # 4 nodes with 2 features each
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) # 6 edges
pseudo = torch.rand((6, 2), dtype=torch.float) # two-dimensional edge attributes
weight = torch.rand((25, 2, 4), dtype=torch.float) # 25 parameters for in_channels x out_channels
kernel_size = torch.tensor([5, 5]) # 5 parameters in each edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8) # only use open B-splines
degree = 1 # B-spline degree of 1
norm = True # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float) # separately weight root nodes
bias = None # do not apply an additional bias
out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, norm, root_weight, bias)
print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each
```
## Cite
Please cite our paper if you use this code in your own work:
```
@inproceedings{Fey/etal/2018,
title={{SplineCNN}: Fast Geometric Deep Learning with Continuous {B}-Spline Kernels},
author={Fey, Matthias and Lenssen, Jan Eric and Weichert, Frank and M{\"u}ller, Heinrich},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2018},
}
```
## Running tests
```
python setup.py test
```
## C++ API
`torch-spline-conv` also offers a C++ API that contains a C++ equivalent of the Python models.
```
mkdir build
cd build
# Add -DWITH_CUDA=on to enable CUDA support if needed
cmake ..
make
make install
```
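The sources below register the basis computation as a custom operator named `torch_spline_conv::spline_basis`. As a minimal sketch (assuming the extension has been built; the shared-library filename here is illustrative, not the project's actual artifact name), the registered operator can be reached from Python through `torch.ops` once the library is loaded:

```python
import torch

# Illustrative filename; load the extension library your build actually produces.
torch.ops.load_library('_basis_cpu.so')

pseudo = torch.rand(6, 2)                                 # 6 edges, 2 pseudo-coordinates in [0, 1]
kernel_size = torch.tensor([5, 5])                        # weights per edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)  # open B-splines in both dimensions

# Returns the B-spline basis products and the matching weight indices.
basis, weight_index = torch.ops.torch_spline_conv.spline_basis(
    pseudo, kernel_size, is_open_spline, 1)               # degree-1 basis
print(basis.size())  # torch.Size([6, 4]), since (degree + 1)^2 = 4
```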
#include <Python.h>
#include <torch/script.h>
#include "cpu/basis_cpu.h"
#ifdef WITH_HIP
#include "hip/basis_hip.h"
#endif
#ifdef _WIN32
#ifdef WITH_HIP
PyMODINIT_FUNC PyInit__basis_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__basis_cpu(void) { return NULL; }
#endif
#endif
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw(torch::Tensor pseudo, torch::Tensor kernel_size,
                torch::Tensor is_open_spline, int64_t degree) {
  if (pseudo.device().is_cuda()) {
#ifdef WITH_HIP
    return spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, degree);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_basis_fw_cpu(pseudo, kernel_size, is_open_spline, degree);
  }
}

torch::Tensor spline_basis_bw(torch::Tensor grad_basis, torch::Tensor pseudo,
                              torch::Tensor kernel_size,
                              torch::Tensor is_open_spline, int64_t degree) {
  if (grad_basis.device().is_cuda()) {
#ifdef WITH_HIP
    return spline_basis_bw_cuda(grad_basis, pseudo, kernel_size,
                                is_open_spline, degree);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return spline_basis_bw_cpu(grad_basis, pseudo, kernel_size, is_open_spline,
                               degree);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class SplineBasis : public torch::autograd::Function<SplineBasis> {
public:
  static variable_list forward(AutogradContext *ctx, Variable pseudo,
                               Variable kernel_size, Variable is_open_spline,
                               int64_t degree) {
    ctx->saved_data["degree"] = degree;
    auto result = spline_basis_fw(pseudo, kernel_size, is_open_spline, degree);
    auto basis = std::get<0>(result), weight_index = std::get<1>(result);
    ctx->save_for_backward({pseudo, kernel_size, is_open_spline});
    ctx->mark_non_differentiable({weight_index});
    return {basis, weight_index};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_basis = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
    auto degree = ctx->saved_data["degree"].toInt();
    auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
                                       is_open_spline, degree);
    return {grad_pseudo, Variable(), Variable(), Variable()};
  }
};

std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
             torch::Tensor is_open_spline, int64_t degree) {
  pseudo = pseudo.contiguous();
  auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
  return std::make_tuple(result[0], result[1]);
}

static auto registry = torch::RegisterOperators().op(
    "torch_spline_conv::spline_basis", &spline_basis);
#include "basis_cpu.h"
#include "utils.h"
template <typename scalar_t, int64_t degree> struct Basis {
  static inline scalar_t forward(scalar_t v, int64_t k_mod) {
    if (degree == 1) {
      return 1. - v - k_mod + 2. * v * k_mod;
    } else if (degree == 2) {
      if (k_mod == 0)
        return 0.5 * v * v - v + 0.5;
      else if (k_mod == 1)
        return -v * v + v + 0.5;
      else
        return 0.5 * v * v;
    } else if (degree == 3) {
      if (k_mod == 0)
        return (1. - v) * (1. - v) * (1. - v) / 6.;
      else if (k_mod == 1)
        return (3. * v * v * v - 6. * v * v + 4.) / 6.;
      else if (k_mod == 2)
        return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
      else
        return v * v * v / 6.;
    } else {
      return (scalar_t)-1.;
    }
  }

  static inline scalar_t backward(scalar_t v, int64_t k_mod) {
    if (degree == 1) {
      return 2 * k_mod - 1;
    } else if (degree == 2) {
      if (k_mod == 0)
        return v - 1.;
      else if (k_mod == 1)
        return -2. * v + 1.;
      else
        return v;
    } else if (degree == 3) {
      if (k_mod == 0)
        return (-v * v + 2. * v - 1.) / 2.;
      else if (k_mod == 1)
        return (3. * v * v - 4. * v) / 2.;
      else if (k_mod == 2)
        return (-3. * v * v + 2. * v + 1.) / 2.;
      else
        return v * v / 2.;
    } else {
      return (scalar_t)-1.;
    }
  }
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
                    torch::Tensor is_open_spline, int64_t degree) {
  CHECK_CPU(pseudo);
  CHECK_CPU(kernel_size);
  CHECK_CPU(is_open_spline);

  CHECK_INPUT(kernel_size.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
  CHECK_INPUT(is_open_spline.dim());
  CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());

  auto E = pseudo.size(0);
  auto D = pseudo.size(1);
  auto S = (int64_t)(pow(degree + 1, D) + 0.5);

  auto basis = at::empty({E, S}, pseudo.options());
  auto weight_index = at::empty({E, S}, kernel_size.options());

  auto kernel_size_data = kernel_size.data_ptr<int64_t>();
  auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
    auto pseudo_data = pseudo.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();

    AT_DISPATCH_DEGREE_TYPES(degree, [&] {
      int64_t k, wi, wi_offset;
      scalar_t b;

      for (int64_t e = 0; e < E; e++) {
        for (int64_t s = 0; s < S; s++) {
          k = s, wi = 0, wi_offset = 1, b = (scalar_t)1.;

          for (int64_t d = 0; d < D; d++) {
            int64_t k_mod = k % (DEGREE + 1);
            k /= DEGREE + 1;

            auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
            v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];

            wi += (((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset;
            wi_offset *= kernel_size_data[d];

            v -= floor(v);
            v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
            b *= v;
          }

          basis_data[e * S + s] = b;
          weight_index_data[e * S + s] = wi;
        }
      }
    });
  });

  return std::make_tuple(basis, weight_index);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
                                  torch::Tensor pseudo,
                                  torch::Tensor kernel_size,
                                  torch::Tensor is_open_spline,
                                  int64_t degree) {
  CHECK_CPU(grad_basis);
  CHECK_CPU(pseudo);
  CHECK_CPU(kernel_size);
  CHECK_CPU(is_open_spline);

  CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
  CHECK_INPUT(kernel_size.dim() == 1);
  CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
  CHECK_INPUT(is_open_spline.dim());
  CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());

  auto E = pseudo.size(0);
  auto D = pseudo.size(1);
  auto S = grad_basis.size(1);

  auto grad_pseudo = at::empty({E, D}, pseudo.options());

  auto kernel_size_data = kernel_size.data_ptr<int64_t>();
  auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();

  AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
    auto pseudo_data = pseudo.data_ptr<scalar_t>();
    auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();

    AT_DISPATCH_DEGREE_TYPES(degree, [&] {
      scalar_t g, tmp;

      for (int64_t e = 0; e < E; e++) {
        for (int64_t d = 0; d < D; d++) {
          g = (scalar_t)0.;
          for (int64_t s = 0; s < S; s++) {
            int64_t k_mod =
                (s / (int64_t)(pow(DEGREE + 1, d) + 0.5)) % (DEGREE + 1);

            auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
            v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
            v -= floor(v);
            v = Basis<scalar_t, DEGREE>::backward(v, k_mod);
            tmp = v;

            for (int64_t d_it = 1; d_it < D; d_it++) {
              int64_t d_new = d_it - (d >= d_it);
              k_mod =
                  (s / (int64_t)(pow(DEGREE + 1, d_new) + 0.5)) % (DEGREE + 1);
              v = pseudo_data[e * pseudo.stride(0) + d_new * pseudo.stride(1)];
              v *=
                  kernel_size_data[d_new] - DEGREE * is_open_spline_data[d_new];
              v -= floor(v);
              v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
              tmp *= v;
            }
            g += tmp * grad_basis_data[e * grad_basis.stride(0) +
                                       s * grad_basis.stride(1)];
          }
          g *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
          grad_pseudo_data[e * D + d] = g;
        }
      }
    });
  });

  return grad_pseudo;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
                    torch::Tensor is_open_spline, int64_t degree);

torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
                                  torch::Tensor pseudo,
                                  torch::Tensor kernel_size,
                                  torch::Tensor is_open_spline, int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_DEGREE_TYPES(degree, ...)                                  \
  [&] {                                                                        \
    switch (degree) {                                                          \
    case 1: {                                                                  \
      static constexpr int64_t DEGREE = 1;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 2: {                                                                  \
      static constexpr int64_t DEGREE = 2;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    case 3: {                                                                  \
      static constexpr int64_t DEGREE = 3;                                     \
      return __VA_ARGS__();                                                    \
    }                                                                          \
    default:                                                                   \
      AT_ERROR("Basis degree not implemented");                                \
    }                                                                          \
  }()
#include "weighting_cpu.h"
#include "utils.h"
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
                                      torch::Tensor basis,
                                      torch::Tensor weight_index) {
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));

  auto E = x.size(0);
  auto M_in = x.size(1);
  auto M_out = weight.size(2);
  auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t v;
    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        v = 0;
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto tmp =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            v += tmp;
          }
        }
        out_data[e * M_out + m_out] = v;
      }
    }
  });

  return out;
}
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
                                        torch::Tensor weight,
                                        torch::Tensor basis,
                                        torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = weight.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_x_data = grad_x.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            grad_x_data[e * M_in + m_in] += g * b * w;
          }
        }
      }
    }
  });

  return grad_x;
}
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor basis,
                                             torch::Tensor weight_index,
                                             int64_t kernel_size) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
          }
        }
      }
    }
  });

  return grad_weight;
}
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
                                            torch::Tensor x,
                                            torch::Tensor weight,
                                            torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  auto grad_basis = at::zeros({E, S}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          scalar_t b = 0;
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
            b += w;
          }
          grad_basis_data[e * S + s] += g * b;
        }
      }
    }
  });

  return grad_basis;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
                                      torch::Tensor basis,
                                      torch::Tensor weight_index);

torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
                                        torch::Tensor weight,
                                        torch::Tensor basis,
                                        torch::Tensor weight_index);

torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor basis,
                                             torch::Tensor weight_index,
                                             int64_t kernel_size);

torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
                                            torch::Tensor x,
                                            torch::Tensor weight,
                                            torch::Tensor weight_index);
#pragma once
static inline __device__ void atomAdd(float *address, float val) {
  atomicAdd(address, val);
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || TORCH_HIP_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
  atomicAdd(address, val);
}
#endif