Commit 70d03142 authored by limm's avatar limm
Browse files

support v1.2.2

parent 378d2b88
#!/bin/bash
# Install Miniconda for the current CI platform and create a "test"
# conda environment with the requested Python version.
# Expects: TRAVIS_OS_NAME (linux|osx|windows), PYTHON_VERSION.
case "${TRAVIS_OS_NAME}" in
  linux)
    wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh
    chmod +x miniconda.sh
    ./miniconda.sh -b   # -b: batch (non-interactive) install
    PATH=/home/travis/miniconda3/bin:${PATH}
    ;;
  osx)
    wget -nv https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh
    chmod +x miniconda.sh
    ./miniconda.sh -b
    PATH=/Users/travis/miniconda3/bin:${PATH}
    ;;
  windows)
    choco install openssl.light
    choco install miniconda3
    PATH=/c/tools/miniconda3/Scripts:$PATH
    ;;
esac
conda update --yes conda
conda create --yes -n test python="${PYTHON_VERSION}"
#!/bin/bash
# Translate the (TRAVIS_OS_NAME, IDX) build-matrix entry into the CUDA
# environment variables consumed by the later install/build steps:
#   TOOLKIT             conda package spec of the matching cudatoolkit
#   CUDA_SHORT          short CUDA version, e.g. 10.1
#   CUDA                full Ubuntu package version (linux only)
#   UBUNTU_VERSION      target NVIDIA repo flavour (linux only)
#   CUDA_URL/CUDA_FILE  installer download location (windows only)
case "${TRAVIS_OS_NAME}-${IDX}" in
  linux-cpu|windows-cpu)
    export TOOLKIT=cpuonly
    ;;
  linux-cu92)
    export CUDA_SHORT=9.2
    export CUDA=9.2.148-1
    export UBUNTU_VERSION=ubuntu1604
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  linux-cu101)
    export IDX=cu101
    export CUDA_SHORT=10.1
    export CUDA=10.1.243-1
    export UBUNTU_VERSION=ubuntu1804
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  linux-cu102)
    export CUDA_SHORT=10.2
    export CUDA=10.2.89-1
    export UBUNTU_VERSION=ubuntu1804
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  linux-cu110)
    export CUDA_SHORT=11.0
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  windows-cu92)
    export CUDA_SHORT=9.2
    export CUDA_URL=https://developer.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod2/local_installers2
    export CUDA_FILE=cuda_${CUDA_SHORT}.148_win10
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  windows-cu101)
    export CUDA_SHORT=10.1
    export CUDA_URL=https://developer.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers
    export CUDA_FILE=cuda_${CUDA_SHORT}.105_418.96_win10.exe
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  windows-cu102)
    export CUDA_SHORT=10.2
    export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers
    export CUDA_FILE=cuda_${CUDA_SHORT}.89_441.22_win10.exe
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  windows-cu110)
    export CUDA_SHORT=11.0
    export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.2/local_installers
    export CUDA_FILE=cuda_${CUDA_SHORT}.2_451.48_win10.exe
    export TOOLKIT="cudatoolkit=${CUDA_SHORT}"
    ;;
  osx-cpu)
    export TOOLKIT=""
    ;;
esac
# Tell the extension build whether CUDA kernels should be compiled.
if [ "${IDX}" = "cpu" ]; then
  export FORCE_ONLY_CPU=1
else
  export FORCE_CUDA=1
fi
# Install the CUDA toolkit selected by the environment prepared earlier
# (CUDA_SHORT, CUDA, UBUNTU_VERSION, CUDA_URL, CUDA_FILE).
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "${IDX}" != "cpu" ] && [ "${IDX}" != "cu110" ]; then
  # CUDA <= 10.2 from NVIDIA's Ubuntu repository.
  INSTALLER="cuda-repo-${UBUNTU_VERSION}_${CUDA}_amd64.deb"
  # Fetch over HTTPS: this package is installed as root, so a tampered
  # plain-HTTP download would be a root-level compromise.
  wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/${INSTALLER}"
  sudo dpkg -i "${INSTALLER}"
  wget -nv "https://developer.download.nvidia.com/compute/cuda/repos/${UBUNTU_VERSION}/x86_64/7fa2af80.pub"
  sudo apt-key add 7fa2af80.pub
  sudo apt update -qq
  sudo apt install "cuda-core-${CUDA_SHORT/./-}" "cuda-nvcc-${CUDA_SHORT/./-}" "cuda-libraries-dev-${CUDA_SHORT/./-}"
  sudo apt clean
  # Export so child processes (nvcc, the later Python build) see them.
  export CUDA_HOME=/usr/local/cuda-${CUDA_SHORT}
  # ':+' avoids a trailing ':' (which would add the cwd to the search
  # path) when LD_LIBRARY_PATH was previously unset.
  export LD_LIBRARY_PATH=${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
  export PATH=${CUDA_HOME}/bin:${PATH}
  nvcc --version
  # Fix cublas on CUDA 10.1: its headers/libs live under the 10.2 target dir.
  if [ -d "/usr/local/cuda-10.2/targets/x86_64-linux/include" ]; then
    sudo cp -r /usr/local/cuda-10.2/targets/x86_64-linux/include/* "${CUDA_HOME}/include/"
  fi
  if [ -d "/usr/local/cuda-10.2/targets/x86_64-linux/lib" ]; then
    sudo cp -r /usr/local/cuda-10.2/targets/x86_64-linux/lib/* "${CUDA_HOME}/lib/"
  fi
fi
if [ "${TRAVIS_OS_NAME}" = "linux" ] && [ "${IDX}" = "cu110" ]; then
  # CUDA 11.0 ships as a local (self-contained) repository package.
  wget -nv https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-ubuntu1804.pin
  sudo mv cuda-ubuntu1804.pin /etc/apt/preferences.d/cuda-repository-pin-600
  wget -nv https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda-repo-ubuntu1804-11-0-local_11.0.3-450.51.06-1_amd64.deb
  sudo dpkg -i cuda-repo-ubuntu1804-11-0-local_11.0.3-450.51.06-1_amd64.deb
  sudo apt-key add /var/cuda-repo-ubuntu1804-11-0-local/7fa2af80.pub
  sudo apt update -qq
  sudo apt install cuda-nvcc-11-0 cuda-libraries-dev-11-0
  sudo apt clean
  # Export so child processes (nvcc, the later Python build) see them.
  export CUDA_HOME=/usr/local/cuda-${CUDA_SHORT}
  export LD_LIBRARY_PATH=${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
  export PATH=${CUDA_HOME}/bin:${PATH}
  nvcc --version
fi
if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${IDX}" != "cpu" ]; then
  # Install NVIDIA drivers, see:
  # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102
  curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip"
  7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32"
  # Install CUDA (silent install of just the compile-time components):
  wget -nv "${CUDA_URL}/${CUDA_FILE}"
  PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow"
  # Export so child processes (nvcc, the later Python build) see them.
  export CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v${CUDA_SHORT}
  export PATH=${CUDA_HOME}/bin:$PATH
  export PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH
  nvcc --version
fi
#!/bin/bash
# On Linux, install GCC 7 from the toolchain PPA and make it the default
# gcc/g++ so the CUDA extensions build with a supported host compiler.
if [ "${TRAVIS_OS_NAME}" = "linux" ]; then
  sudo add-apt-repository ppa:ubuntu-toolchain-r/test --yes
  sudo apt update
  sudo apt install gcc-7 g++-7 --yes
  # Register gcc-7 as the "gcc" alternative, with g++-7 slaved to it so
  # both compilers are switched together.
  sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 --slave /usr/bin/g++ g++ /usr/bin/g++-7
  sudo update-alternatives --config gcc
  for compiler in gcc g++; do
    "${compiler}" --version
  done
fi
#!/bin/bash
# Patch the installed PyTorch headers so they compile under nvcc on Windows.
# https://github.com/pytorch/pytorch/commit/d2e16dd888a9b5fd55bd475d4fcffb70f388d4f0

# Single source of truth for the header location (was duplicated inline).
TORCH_INCLUDE=/c/tools/miniconda3/envs/test/lib/site-packages/torch/include

# sed_in_place <sed-expression> <file>
# Apply a sed expression to <file> in place (keeping a .bak copy), and
# skip quietly when the file is absent instead of letting sed error out
# (e.g. on PyTorch versions that lay headers out differently).
sed_in_place() {
  local expression=$1
  local file=$2
  if [ -f "${file}" ]; then
    sed -i.bak -e "${expression}" "${file}"
  fi
}

if [ "${TRAVIS_OS_NAME}" = "windows" ]; then
  echo "Fix nvcc for PyTorch"
  sed_in_place 's/CONSTEXPR_EXCEPT_WIN_CUDA/const/g' "${TORCH_INCLUDE}/torch/csrc/jit/api/module.h"
  sed_in_place 's/return \*(this->value)/return \*((type\*)this->value)/g' "${TORCH_INCLUDE}/pybind11/cast.h"
fi

if [ "${TRAVIS_OS_NAME}" = "windows" ] && [ "${TORCH_VERSION}" = "1.7.0" ]; then
  echo "Fix nvcc for PyTorch 1.7.0"
  sed_in_place '/static constexpr Symbol Kind/d' "${TORCH_INCLUDE}/torch/csrc/jit/ir/ir.h"
fi
[metadata]
description-file = README.md
long_description=file: README.md
long_description_content_type=text/markdown
classifiers =
Development Status :: 5 - Production/Stable
License :: OSI Approved :: MIT License
Programming Language :: Python
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3 :: Only
[aliases]
test=pytest
test = pytest
[tool:pytest]
addopts = --cov
addopts = --capture=no
import os
import glob
import os
import os.path as osp
import platform
import sys
from itertools import product
from setuptools import setup, find_packages
import torch
from torch.utils.cpp_extension import BuildExtension
from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
from setuptools import find_packages, setup
from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)
__version__ = '1.2.2'
URL = 'https://github.com/rusty1s/pytorch_spline_conv'
WITH_CUDA = False
if torch.cuda.is_available():
WITH_CUDA = CUDA_HOME is not None or torch.version.hip
WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu']
if os.getenv('FORCE_CUDA', '0') == '1':
suffices = ['cuda', 'cpu']
......@@ -23,20 +32,48 @@ BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1'
def get_extensions():
extensions = []
extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
extensions_dir = osp.join('csrc')
main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
# remove generated 'hip' files, in case of rebuilds
main_files = [path for path in main_files if 'hip' not in path]
for main, suffix in product(main_files, suffices):
define_macros = []
undef_macros = []
extra_compile_args = {'cxx': ['-O2']}
if not os.name == 'nt': # Not on Windows:
extra_compile_args['cxx'] += ['-Wno-sign-compare']
extra_link_args = ['-s']
info = parallel_info()
if ('backend: OpenMP' in info and 'OpenMP not found' not in info
and sys.platform != 'darwin'):
extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
if sys.platform == 'win32':
extra_compile_args['cxx'] += ['/openmp']
else:
extra_compile_args['cxx'] += ['-fopenmp']
else:
print('Compiling without OpenMP...')
# Compile for mac arm64
if (sys.platform == 'darwin' and platform.machine() == 'arm64'):
extra_compile_args['cxx'] += ['-arch', 'arm64']
extra_link_args += ['-arch', 'arm64']
if suffix == 'cuda':
define_macros += [('WITH_CUDA', None)]
nvcc_flags = os.getenv('NVCC_FLAGS', '')
nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr', '-O2']
nvcc_flags += ['-O2']
extra_compile_args['nvcc'] = nvcc_flags
if torch.version.hip:
# USE_ROCM was added to later versions of PyTorch
# Define here to support older PyTorch versions as well:
define_macros += [('USE_ROCM', None)]
undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
else:
nvcc_flags += ['--expt-relaxed-constexpr']
name = main.split(os.sep)[-1][:-4]
sources = [main]
......@@ -55,6 +92,7 @@ def get_extensions():
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
undef_macros=undef_macros,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
......@@ -64,31 +102,42 @@ def get_extensions():
install_requires = []
setup_requires = ['pytest-runner']
tests_require = ['pytest', 'pytest-cov']
test_requires = [
'pytest',
'pytest-cov',
]
# work-around hipify abs paths
include_package_data = True
if torch.cuda.is_available() and torch.version.hip:
include_package_data = False
setup(
name='torch_spline_conv',
version='1.2.1',
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url='https://github.com/rusty1s/pytorch_spline_conv',
version=__version__,
description=('Implementation of the Spline-Based Convolution Operator of '
'SplineCNN in PyTorch'),
author='Matthias Fey',
author_email='matthias.fey@tu-dortmund.de',
url=URL,
download_url=f'{URL}/archive/{__version__}.tar.gz',
keywords=[
'pytorch',
'geometric-deep-learning',
'graph-neural-networks',
'spline-cnn',
],
license='MIT',
python_requires='>=3.6',
python_requires='>=3.7',
install_requires=install_requires,
setup_requires=setup_requires,
tests_require=tests_require,
extras_require={
'test': test_requires,
},
ext_modules=get_extensions() if not BUILD_DOCS else [],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
'build_ext':
BuildExtension.with_options(no_python_abi_suffix=True, use_ninja=False)
},
packages=find_packages(),
include_package_data=include_package_data,
)
......@@ -3,8 +3,7 @@ from itertools import product
import pytest
import torch
from torch_spline_conv import spline_basis
from .utils import dtypes, devices, tensor
from torch_spline_conv.testing import devices, dtypes, tensor
tests = [{
'pseudo': [[0], [0.0625], [0.25], [0.75], [0.9375], [1]],
......@@ -29,12 +28,18 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_basis_forward(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
pseudo = tensor(test['pseudo'], dtype, device)
kernel_size = tensor(test['kernel_size'], torch.long, device)
is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
basis = tensor(test['basis'], dtype, device)
weight_index = tensor(test['weight_index'], dtype, device)
degree = 1
basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
degree)
assert basis.tolist() == test['basis']
assert weight_index.tolist() == test['weight_index']
assert torch.allclose(basis, basis)
assert torch.allclose(weight_index, weight_index)
......@@ -4,8 +4,7 @@ import pytest
import torch
from torch.autograd import gradcheck
from torch_spline_conv import spline_conv
from .utils import dtypes, devices, tensor
from torch_spline_conv.testing import devices, dtypes, tensor
degrees = [1, 2, 3]
......@@ -43,6 +42,9 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_conv_forward(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
x = tensor(test['x'], dtype, device)
edge_index = tensor(test['edge_index'], torch.long, device)
pseudo = tensor(test['pseudo'], dtype, device)
......@@ -51,14 +53,17 @@ def test_spline_conv_forward(test, dtype, device):
is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
root_weight = tensor(test['root_weight'], dtype, device)
bias = tensor(test['bias'], dtype, device)
expected = tensor(test['expected'], dtype, device)
out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, True, root_weight, bias)
assert out.tolist() == test['expected']
error = 1e-2 if dtype == torch.bfloat16 else 1e-7
assert torch.allclose(out, expected, rtol=error, atol=error)
@pytest.mark.parametrize('degree,device', product(degrees, devices))
def test_spline_basis_backward(degree, device):
def test_spline_conv_backward(degree, device):
x = torch.rand((3, 2), dtype=torch.double, device=device)
x.requires_grad_()
edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
......
......@@ -4,8 +4,7 @@ import pytest
import torch
from torch.autograd import gradcheck
from torch_spline_conv import spline_basis, spline_weighting
from .utils import dtypes, devices, tensor
from torch_spline_conv.testing import devices, dtypes, tensor
tests = [{
'x': [[1, 2], [3, 4]],
......@@ -21,13 +20,17 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_weighting_forward(test, dtype, device):
if dtype == torch.bfloat16 and device == torch.device('cuda:0'):
return
x = tensor(test['x'], dtype, device)
weight = tensor(test['weight'], dtype, device)
basis = tensor(test['basis'], dtype, device)
weight_index = tensor(test['weight_index'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
out = spline_weighting(x, weight, basis, weight_index)
assert out.tolist() == test['expected']
assert torch.allclose(out, expected)
@pytest.mark.parametrize('device', devices)
......
Metadata-Version: 2.1
Name: torch-spline-conv
Version: 1.2.1
Summary: Implementation of the Spline-Based Convolution Operator of SplineCNN in PyTorch
Home-page: https://github.com/rusty1s/pytorch_spline_conv
Author: Matthias Fey
Author-email: matthias.fey@tu-dortmund.de
License: MIT
Keywords: pytorch,geometric-deep-learning,graph-neural-networks,spline-cnn
Requires-Python: >=3.6
License-File: LICENSE
LICENSE
MANIFEST.in
README.md
setup.cfg
setup.py
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/basis.cpp
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/version.cpp
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/weighting.cpp
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/cpu/basis_cpu.cpp
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/cpu/weighting_cpu.cpp
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/hip/basis_hip_hip.hip
/work/home/quyuanhao123/software/test_ocp/torch_spline_conv-1.2.1/csrc/hip/weighting_hip_hip.hip
csrc/basis.cpp
csrc/spline_conv.h
csrc/version.cpp
csrc/weighting.cpp
csrc/cpu/basis_cpu.cpp
csrc/cpu/basis_cpu.h
csrc/cpu/utils.h
csrc/cpu/weighting_cpu.cpp
csrc/cpu/weighting_cpu.h
csrc/hip/atomics.cuh
csrc/hip/basis_hip.h
csrc/hip/basis_hip.hip
csrc/hip/basis_hip_hip.hip
csrc/hip/utils.cuh
csrc/hip/weighting_hip.h
csrc/hip/weighting_hip.hip
csrc/hip/weighting_hip_hip.hip
torch_spline_conv/__init__.py
torch_spline_conv/basis.py
torch_spline_conv/conv.py
torch_spline_conv/weighting.py
torch_spline_conv.egg-info/PKG-INFO
torch_spline_conv.egg-info/SOURCES.txt
torch_spline_conv.egg-info/dependency_links.txt
torch_spline_conv.egg-info/top_level.txt
\ No newline at end of file
......@@ -3,20 +3,23 @@ import os.path as osp
import torch
__version__ = '1.2.1'
suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
__version__ = '1.2.2'
for library in ['_version', '_basis', '_weighting']:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
if torch.cuda.is_available(): # pragma: no cover
cuda_version = torch.ops.torch_spline_conv.cuda_version()
if cuda_version == -1:
major = minor = 0
elif cuda_version < 10000:
cuda_spec = importlib.machinery.PathFinder().find_spec(
f'{library}_cuda', [osp.dirname(__file__)])
cpu_spec = importlib.machinery.PathFinder().find_spec(
f'{library}_cpu', [osp.dirname(__file__)])
spec = cuda_spec or cpu_spec
if spec is not None:
torch.ops.load_library(spec.origin)
else: # pragma: no cover
raise ImportError(f"Could not find module '{library}_cpu' in "
f"{osp.dirname(__file__)}")
cuda_version = torch.ops.torch_spline_conv.cuda_version()
if torch.version.cuda is not None and cuda_version != -1: # pragma: no cover
if cuda_version < 10000:
major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2])
else:
major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3])
......@@ -31,8 +34,8 @@ if torch.cuda.is_available(): # pragma: no cover
f'matches your PyTorch install.')
from .basis import spline_basis # noqa
from .weighting import spline_weighting # noqa
from .conv import spline_conv # noqa
from .weighting import spline_weighting # noqa
__all__ = [
'spline_basis',
......
......@@ -69,10 +69,10 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
# Weight root node separately (if wished).
if root_weight is not None:
out = out + torch.matmul(x, root_weight)
out += x @ root_weight
# Add bias (if wished).
if bias is not None:
out = out + bias
out += bias
return out
from typing import Any
import torch
dtypes = [torch.float, torch.double]
dtypes = [torch.float, torch.double, torch.bfloat16]
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices += [torch.device(f'cuda:{torch.cuda.current_device()}')]
devices += [torch.device('cuda:0')]
def tensor(x, dtype, device):
def tensor(x: Any, dtype: torch.dtype, device: torch.device):
return None if x is None else torch.tensor(x, dtype=dtype, device=device)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment