Commit 13dabd40 authored by rusty1s's avatar rusty1s
Browse files

graclus weight gpu bugfix

parent 7985cdd8
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
template<typename T> template<typename T>
struct THCNumerics { struct THCNumerics {
static inline __host__ __device__ T div(T a, T b) { return a / b; } static inline __host__ __device__ T div(T a, T b) { return a / b; }
static inline __host__ __device__ bool gt(T a, T b) { return a > b; } static inline __host__ __device__ bool gte(T a, T b) { return a >= b; }
}; };
#ifdef CUDA_HALF_TENSOR #ifdef CUDA_HALF_TENSOR
template<> template<>
struct THCNumerics<half> { struct THCNumerics<half> {
static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); } static inline __host__ __device__ half div(half a, half b) { return f2h(h2f(a) / h2f(b)); }
static inline __host__ __device__ bool gt(half a, half b) { return h2f(a) > h2f(b); } static inline __host__ __device__ bool gte(half a, half b) { return h2f(a) >= h2f(b); }
}; };
#endif // CUDA_HALF_TENSOR #endif // CUDA_HALF_TENSOR
......
...@@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro ...@@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro
tmp = weight[e]; tmp = weight[e];
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found. if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
// Find maximum weighted red neighbor. // Find maximum weighted red neighbor.
if (color[c] == -2 && THCNumerics<T>::gt(tmp, maxWeight)) { if (color[c] == -2 && THCNumerics<T>::gte(tmp, maxWeight)) {
matchedValue = c; matchedValue = c;
maxWeight = tmp; maxWeight = tmp;
} }
......
...@@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r ...@@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r
tmp = weight[e]; tmp = weight[e];
if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found. if (isDead && color[c] < 0) { isDead = false; } // Unmatched neighbor found.
// Find maximum weighted blue neighbor, who proposed to i. // Find maximum weighted blue neighbor, who proposed to i.
if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gt(tmp, maxWeight)) { if (color[c] == -1 && prop[c] == i && THCNumerics<T>::gte(tmp, maxWeight)) {
matchedValue = c; matchedValue = c;
maxWeight = tmp; maxWeight = tmp;
} }
......
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '1.0.2' __version__ = '1.0.3'
url = 'https://github.com/rusty1s/pytorch_cluster' url = 'https://github.com/rusty1s/pytorch_cluster'
install_requires = ['cffi'] install_requires = ['cffi']
......
from .graclus import graclus_cluster from .graclus import graclus_cluster
from .grid import grid_cluster from .grid import grid_cluster
__version__ = '1.0.2' __version__ = '1.0.3'
__all__ = ['graclus_cluster', 'grid_cluster', '__version__'] __all__ = ['graclus_cluster', 'grid_cluster', '__version__']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment