Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
13dabd40
Commit
13dabd40
authored
Apr 18, 2018
by
rusty1s
Browse files
graclus weight gpu bugfix
parent
7985cdd8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
6 additions
and
6 deletions
+6
-6
aten/THC/THCNumerics.cuh
aten/THC/THCNumerics.cuh
+2
-2
aten/THC/THCPropose.cuh
aten/THC/THCPropose.cuh
+1
-1
aten/THC/THCResponse.cuh
aten/THC/THCResponse.cuh
+1
-1
setup.py
setup.py
+1
-1
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
No files found.
aten/THC/THCNumerics.cuh
View file @
13dabd40
...
@@ -16,14 +16,14 @@
...
@@ -16,14 +16,14 @@
template
<
typename
T
>
template
<
typename
T
>
struct
THCNumerics
{
struct
THCNumerics
{
static
inline
__host__
__device__
T
div
(
T
a
,
T
b
)
{
return
a
/
b
;
}
static
inline
__host__
__device__
T
div
(
T
a
,
T
b
)
{
return
a
/
b
;
}
static
inline
__host__
__device__
bool
gt
(
T
a
,
T
b
)
{
return
a
>
b
;
}
static
inline
__host__
__device__
bool
gt
e
(
T
a
,
T
b
)
{
return
a
>
=
b
;
}
};
};
#ifdef CUDA_HALF_TENSOR
#ifdef CUDA_HALF_TENSOR
template
<
>
template
<
>
struct
THCNumerics
<
half
>
{
struct
THCNumerics
<
half
>
{
static
inline
__host__
__device__
half
div
(
half
a
,
half
b
)
{
return
f2h
(
h2f
(
a
)
/
h2f
(
b
));
}
static
inline
__host__
__device__
half
div
(
half
a
,
half
b
)
{
return
f2h
(
h2f
(
a
)
/
h2f
(
b
));
}
static
inline
__host__
__device__
bool
gt
(
half
a
,
half
b
)
{
return
h2f
(
a
)
>
h2f
(
b
);
}
static
inline
__host__
__device__
bool
gt
e
(
half
a
,
half
b
)
{
return
h2f
(
a
)
>
=
h2f
(
b
);
}
};
};
#endif // CUDA_HALF_TENSOR
#endif // CUDA_HALF_TENSOR
...
...
aten/THC/THCPropose.cuh
View file @
13dabd40
...
@@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro
...
@@ -32,7 +32,7 @@ __global__ void weightedProposeKernel(int64_t *color, int64_t *prop, int64_t *ro
tmp
=
weight
[
e
];
tmp
=
weight
[
e
];
if
(
isDead
&&
color
[
c
]
<
0
)
{
isDead
=
false
;
}
// Unmatched neighbor found.
if
(
isDead
&&
color
[
c
]
<
0
)
{
isDead
=
false
;
}
// Unmatched neighbor found.
// Find maximum weighted red neighbor.
// Find maximum weighted red neighbor.
if
(
color
[
c
]
==
-
2
&&
THCNumerics
<
T
>::
gt
(
tmp
,
maxWeight
))
{
if
(
color
[
c
]
==
-
2
&&
THCNumerics
<
T
>::
gt
e
(
tmp
,
maxWeight
))
{
matchedValue
=
c
;
matchedValue
=
c
;
maxWeight
=
tmp
;
maxWeight
=
tmp
;
}
}
...
...
aten/THC/THCResponse.cuh
View file @
13dabd40
...
@@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r
...
@@ -35,7 +35,7 @@ __global__ void weightedResponseKernel(int64_t *color, int64_t *prop, int64_t *r
tmp
=
weight
[
e
];
tmp
=
weight
[
e
];
if
(
isDead
&&
color
[
c
]
<
0
)
{
isDead
=
false
;
}
// Unmatched neighbor found.
if
(
isDead
&&
color
[
c
]
<
0
)
{
isDead
=
false
;
}
// Unmatched neighbor found.
// Find maximum weighted blue neighbor, who proposed to i.
// Find maximum weighted blue neighbor, who proposed to i.
if
(
color
[
c
]
==
-
1
&&
prop
[
c
]
==
i
&&
THCNumerics
<
T
>::
gt
(
tmp
,
maxWeight
))
{
if
(
color
[
c
]
==
-
1
&&
prop
[
c
]
==
i
&&
THCNumerics
<
T
>::
gt
e
(
tmp
,
maxWeight
))
{
matchedValue
=
c
;
matchedValue
=
c
;
maxWeight
=
tmp
;
maxWeight
=
tmp
;
}
}
...
...
setup.py
View file @
13dabd40
...
@@ -2,7 +2,7 @@ from os import path as osp
...
@@ -2,7 +2,7 @@ from os import path as osp
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
__version__
=
'1.0.
2
'
__version__
=
'1.0.
3
'
url
=
'https://github.com/rusty1s/pytorch_cluster'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'cffi'
]
install_requires
=
[
'cffi'
]
...
...
torch_cluster/__init__.py
View file @
13dabd40
from
.graclus
import
graclus_cluster
from
.graclus
import
graclus_cluster
from
.grid
import
grid_cluster
from
.grid
import
grid_cluster
__version__
=
'1.0.
2
'
__version__
=
'1.0.
3
'
__all__
=
[
'graclus_cluster'
,
'grid_cluster'
,
'__version__'
]
__all__
=
[
'graclus_cluster'
,
'grid_cluster'
,
'__version__'
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment