OpenDAS / torch-spline-conv · Commits · b46459f4

Commit b46459f4, authored Aug 13, 2018 by rusty1s

cuda kernels

parent 57f5a26e
Showing 11 changed files with 308 additions and 11 deletions (+308 / -11)
.travis.yml                       +1    -1
cpu/basis.cpp                     +4    -5
cpu/weighting.cpp                 +2    -2
cuda/atomics.cuh                  +15   -0
cuda/basis_kernel.cu              +26   -0
cuda/weighting.cpp                +60   -0
cuda/weighting_kernel.cu          +190  -0
setup.py                          +4    -1
test/test_conv.py                 +1    -0
test/test_weighting.py            +1    -0
torch_spline_conv/weighting.py    +4    -2
.travis.yml

@@ -18,7 +18,7 @@ before_install:
   - export CXX="g++-4.9"
 install:
   - if [[ $TRAVIS_PYTHON_VERSION == 2.7 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp27-cp27mu-linux_x86_64.whl; fi
-  - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch.4.1-cp35-cp35m-linux_x86_64.whl; fi
+  - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl; fi
   - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl; fi
   - pip install pycodestyle
   - pip install flake8
cpu/basis.cpp

@@ -142,12 +142,11 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
       tmp = v;                                                            \
                                                                           \
       for (ptrdiff_t d_it = 1; d_it < D; d_it++) {                        \
-        auto d_other = d_it - (d >= d_it);                                \
-        k_mod = (s / (int64_t)(pow(M + 1, d_other) + 0.5)) % (M + 1);     \
+        auto d_new = d_it - (d >= d_it);                                  \
+        k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1);       \
         v = pseudo_data[e * pseudo.stride(0) +                            \
-                        d_other * pseudo.stride(1)];                      \
-        v *= kernel_size_data[d_other] -                                  \
-             M * is_open_spline_data[d_other];                            \
+                        d_new * pseudo.stride(1)];                        \
+        v *= kernel_size_data[d_new] - M * is_open_spline_data[d_new];    \
         v -= floor(v);                                                    \
         v = FUNC<scalar_t>(v, k_mod);                                     \
         tmp *= v;                                                         \
cpu/weighting.cpp

@@ -142,6 +142,6 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
   m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
-  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward W (CPU)");
-  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward B (CPU)");
+  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
+  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
 }
cuda/atomics.cuh (new file, 100644)

#pragma once

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomicAdd(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
}
#endif
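
Hardware atomicAdd on double exists only for compute capability >= 6.0 built with CUDA >= 8.0; for older targets this header emulates it with an atomicCAS loop that retries until no other thread has modified the address in between. A minimal usage sketch (sum_kernel is illustrative, not part of the commit):

// Illustrative only: a naive sum reduction relying on the fallback above.
#include "atomics.cuh"

__global__ void sum_kernel(const double *in, double *out, size_t n) {
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    atomicAdd(out, in[i]); // CAS loop on sm < 60, native instruction otherwise
}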
cuda/basis_kernel.cu

@@ -185,6 +185,32 @@ template <typename scalar_t> struct BasisBackward {
    const size_t index = blockIdx.x * blockDim.x + threadIdx.x;              \
    const size_t stride = blockDim.x * gridDim.x;                            \
    for (ptrdiff_t i = index; i < NUMEL; i += stride) {                      \
      int64_t e = i / GRAD_PSEUDO.sizes[1], d = i % GRAD_PSEUDO.sizes[1];    \
      scalar_t g = 0, tmp;                                                   \
                                                                             \
      for (ptrdiff_t s = 0; s < GRAD_BASIS.sizes[1]; s++) {                  \
        auto k_mod = (s / (int64_t)(pow(M + 1, d) + 0.5)) % (M + 1);         \
        auto v = PSEUDO.data[e * PSEUDO.strides[0] + d * PSEUDO.strides[1]]; \
        v *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                         \
        v -= floor(v);                                                       \
        v = CODE;                                                            \
        tmp = v;                                                             \
                                                                             \
        for (ptrdiff_t d_it = 1; d_it < GRAD_PSEUDO.sizes[1]; d_it++) {      \
          auto d_new = d_it - (d >= d_it);                                   \
          k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1);        \
          v = PSEUDO.data[e * PSEUDO.strides[0] +                            \
                          d_new * PSEUDO.strides[1]];                        \
          v *= KERNEL_SIZE[d_new] - M * IS_OPEN_SPLINE[d_new];               \
          v -= floor(v);                                                     \
          v = GRAD_CODE;                                                     \
          tmp *= v;                                                          \
        }                                                                    \
                                                                             \
        g += tmp *                                                           \
             GRAD_BASIS.data[e * GRAD_BASIS.strides[0] +                     \
                             s * GRAD_BASIS.strides[1]];                     \
      }                                                                      \
      g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d];                           \
      GRAD_PSEUDO.data[e * GRAD_PSEUDO.sizes[1] + d] = g;                    \
    }                                                                        \
  }()
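
Every kernel in this commit walks its elements with a grid-stride loop: a thread starts at its global index and advances by the total number of launched threads, so correctness never depends on the launch exactly covering NUMEL. A standalone sketch of the pattern (saxpy_kernel is illustrative, not from the diff):

// Grid-stride loop (sketch): thread `index` handles elements
// index, index + stride, index + 2 * stride, ...
__global__ void saxpy_kernel(float a, const float *x, float *y, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = index; i < numel; i += stride)
    y[i] = a * x[i] + y[i];
}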
cuda/weighting.cpp (new file, 100644)

#include <torch/torch.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index);
at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index);
at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor basis, at::Tensor weight_index,
                               int64_t K);
at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor weight, at::Tensor weight_index);

at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                        at::Tensor weight_index) {
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_fw_cuda(x, weight, basis, weight_index);
}

at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
                          at::Tensor basis, at::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(weight);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_bw_x_cuda(grad_out, weight, basis, weight_index);
}

at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
                          at::Tensor weight_index, int64_t K) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(basis);
  CHECK_CUDA(weight_index);
  return weighting_bw_w_cuda(grad_out, x, basis, weight_index, K);
}

at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
                          at::Tensor weight_index) {
  CHECK_CUDA(grad_out);
  CHECK_CUDA(x);
  CHECK_CUDA(weight);
  CHECK_CUDA(weight_index);
  return weighting_bw_b_cuda(grad_out, x, weight, weight_index);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CUDA)");
  m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CUDA)");
  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CUDA)");
  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CUDA)");
}
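
Note the split: the *_cuda functions are only declared in this file and defined in cuda/weighting_kernel.cu below, so the pybind11 glue compiles as ordinary C++ while nvcc handles the kernels. A minimal sketch of the same pattern for a hypothetical op foo:

// Sketch (hypothetical op): declared here, defined in a .cu file.
at::Tensor foo_cuda(at::Tensor x);

at::Tensor foo(at::Tensor x) {
  CHECK_CUDA(x); // reject CPU tensors before entering the GPU path
  return foo_cuda(x);
}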
cuda/weighting_kernel.cu (new file, 100644)

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
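
BLOCKS is the usual ceil-division launch helper: it requests just enough THREADS-wide blocks to cover N elements. Worked example (numbers mine, not from the diff): N = 2500 gives (2500 + 1024 - 1) / 1024 = 3 blocks, i.e. 3072 threads for 2500 elements; the grid-stride loops below make the surplus threads harmless.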
template <typename scalar_t>
__global__ void
weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
                    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
                    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
                    size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / out.sizes[1], m_out = i % out.sizes[1];
    auto S = basis.sizes[1];
    scalar_t v = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        auto tmp = weight.data[wi * weight.strides[0] +
                               m_in * weight.strides[1] +
                               m_out * weight.strides[2]];
        tmp *= b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        v += tmp;
      }
    }
    out.data[e * out.sizes[1] + m_out] = v;
  }
}

at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index) {
  auto E = x.size(0), M_out = weight.size(2);
  auto out = at::empty({E, M_out}, x.type());
  AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
    weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        out.numel());
  });
  return out;
}
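
Reading the loops back as a formula: for each edge e and output channel m_out, the forward kernel computes (notation mine; wi(e, s) is the looked-up weight index, S = basis.sizes[1]):

$$\mathrm{out}[e, m_{\mathrm{out}}] = \sum_{s=0}^{S-1} \sum_{m_{\mathrm{in}}=0}^{M_{\mathrm{in}}-1} \mathrm{basis}[e, s] \cdot \mathrm{weight}\big[\mathrm{wi}(e, s),\, m_{\mathrm{in}},\, m_{\mathrm{out}}\big] \cdot x[e, m_{\mathrm{in}}]$$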
template <typename scalar_t>
__global__ void weighting_bw_x_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
    size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_x.sizes[1], m_in = i % grad_x.sizes[1];
    auto S = basis.sizes[1];
    scalar_t v = 0;
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_out = 0; m_out < grad_out.sizes[1]; m_out++) {
        auto tmp = weight.data[wi * weight.strides[0] +
                               m_out * weight.strides[1] +
                               m_in * weight.strides[2]];
        tmp *= b * grad_out.data[e * grad_out.strides[0] +
                                 m_out * grad_out.strides[1]];
        v += tmp;
      }
    }
    grad_x.data[e * grad_x.sizes[1] + m_in] = v;
  }
}

at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                               at::Tensor basis, at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = weight.size(1);
  auto grad_x = at::empty({E, M_in}, grad_out.type());
  weight = weight.transpose(1, 2).contiguous();
  AT_DISPATCH_FLOATING_TYPES(grad_x.type(), "weighting_bw_x", [&] {
    weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_x.numel());
  });
  return grad_x;
}
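
weighting_bw_x is the forward contraction with the weight's channel axes swapped; the host function materializes that transpose once via weight.transpose(1, 2).contiguous() so the kernel keeps the same stride arithmetic as the forward pass. As a formula (notation mine):

$$\mathrm{grad\_x}[e, m_{\mathrm{in}}] = \sum_{s=0}^{S-1} \sum_{m_{\mathrm{out}}} \mathrm{basis}[e, s] \cdot \mathrm{weight}\big[\mathrm{wi}(e, s),\, m_{\mathrm{in}},\, m_{\mathrm{out}}\big] \cdot \mathrm{grad\_out}[e, m_{\mathrm{out}}]$$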
template <typename scalar_t>
__global__ void weighting_bw_w_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_weight,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> basis,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
    size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
    int64_t S = basis.sizes[1], M_in = x.sizes[1], M_out = grad_out.sizes[1];
    auto g = grad_out.data[e * grad_out.strides[0] +
                           m_out * grad_out.strides[1]];
    for (ptrdiff_t s = 0; s < S; s++) {
      auto b = basis.data[e * S + s];
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
        auto v = g * b * x.data[e * x.strides[0] + m_in * x.strides[1]];
        atomicAdd(&grad_weight.data[wi * M_in * M_out + m_in * M_out + m_out],
                  v);
      }
    }
  }
}

at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor basis, at::Tensor weight_index,
                               int64_t K) {
  auto M_in = x.size(1), M_out = grad_out.size(1);
  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());
  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
    weighting_bw_w_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_weight),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(basis),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_weight;
}
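
Unlike the gather-style kernels above, bw_w scatters: distinct (e, s) pairs can map to the same weight index wi, so threads race on grad_weight and must combine their contributions with atomicAdd. This is where the double fallback from atomics.cuh comes in on pre-Pascal GPUs. In formula form (notation mine):

$$\mathrm{grad\_weight}[k, m_{\mathrm{in}}, m_{\mathrm{out}}] = \sum_{(e, s)\,:\,\mathrm{wi}(e, s) = k} \mathrm{grad\_out}[e, m_{\mathrm{out}}] \cdot \mathrm{basis}[e, s] \cdot x[e, m_{\mathrm{in}}]$$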
template <typename scalar_t>
__global__ void weighting_bw_b_kernel(
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_basis,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> grad_out,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> x,
    at::cuda::detail::TensorInfo<scalar_t, int64_t> weight,
    at::cuda::detail::TensorInfo<int64_t, int64_t> weight_index,
    size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t i = index; i < numel; i += stride) {
    int64_t e = i / grad_out.sizes[1], m_out = i % grad_out.sizes[1];
    auto S = grad_basis.sizes[1];
    auto g = grad_out.data[e * grad_out.strides[0] +
                           m_out * grad_out.strides[1]];
    for (ptrdiff_t s = 0; s < S; s++) {
      scalar_t v = 0;
      auto wi = weight_index.data[e * S + s];
      for (ptrdiff_t m_in = 0; m_in < x.sizes[1]; m_in++) {
        auto w = weight.data[wi * weight.strides[0] +
                             m_in * weight.strides[1] +
                             m_out * weight.strides[2]];
        v += g * w * x.data[e * x.strides[0] + m_in * x.strides[1]];
      }
      atomicAdd(&grad_basis.data[e * S + s], v);
    }
  }
}

at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                               at::Tensor weight, at::Tensor weight_index) {
  auto E = x.size(0), S = weight_index.size(1);
  auto grad_basis = at::zeros({E, S}, grad_out.type());
  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
    weighting_bw_b_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_basis),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_out),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(x),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(weight),
        at::cuda::detail::getTensorInfo<int64_t, int64_t>(weight_index),
        grad_out.numel());
  });
  return grad_basis;
}
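
One allocation detail worth noting across the four host functions: buffers that every thread overwrites exactly once (out, grad_x) are created with at::empty, while buffers accumulated through atomicAdd (grad_weight, grad_basis) must start from at::zeros. Condensed from the code above:

// Overwritten once per element: uninitialized memory is fine.
auto out = at::empty({E, M_out}, x.type());
// Accumulated via atomicAdd from many threads: must start at zero.
auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());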
setup.py

@@ -10,7 +10,10 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 if torch.cuda.is_available():
     ext_modules += [
-        CUDAExtension('basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu'])
+        CUDAExtension('basis_cuda', ['cuda/basis.cpp', 'cuda/basis_kernel.cu']),
+        CUDAExtension('weighting_cuda',
+                      ['cuda/weighting.cpp', 'cuda/weighting_kernel.cu']),
     ]

 __version__ = '1.0.4'
test/test_conv.py

@@ -7,6 +7,7 @@ from torch_spline_conv import SplineConv
 from torch_spline_conv.basis import implemented_degrees as degrees
 from .utils import dtypes, devices, tensor

+devices = [torch.device('cpu')]

 tests = [{
     'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
test/test_weighting.py

@@ -7,6 +7,7 @@ from torch_spline_conv.weighting import SplineWeighting
 from torch_spline_conv.basis import SplineBasis
 from .utils import dtypes, devices, tensor

+devices = [torch.device('cuda')]

 tests = [{
     'x': [[1, 2], [3, 4]],
torch_spline_conv/weighting.py

 import torch

 import weighting_cpu
+
+if torch.cuda.is_available():
+    import weighting_cuda


 def get_func(name, tensor):
-    # module = weighting_cuda if tensor.is_cuda else weighting_cpu
-    module = weighting_cpu
+    module = weighting_cuda if tensor.is_cuda else weighting_cpu
     return getattr(module, name)