Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
878fa61f
Commit
878fa61f
authored
Mar 31, 2018
by
rusty1s
Browse files
greedy done
parent
a7c265bf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
34 deletions
+72
-34
aten/THC/THCGreedy.cu
aten/THC/THCGreedy.cu
+7
-34
aten/THC/THCPropose.cu
aten/THC/THCPropose.cu
+30
-0
aten/THC/THCResponse.cu
aten/THC/THCResponse.cu
+35
-0
No files found.
aten/THC/THCGreedy.cu
View file @
878fa61f
...
...
@@ -3,36 +3,12 @@
#include "common.cuh"
#include "THCDegree.cu"
#include "THCColor.cu"
__global__
void
proposeKernel
(
int64_t
*
tensor
,
int64_t
*
color
,
int64_t
*
row
,
int64_t
*
col
,
int64_t
*
deg
,
int64_t
*
cumDeg
,
ptrdiff_t
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
if
(
color
[
i
]
!=
-
1
)
continue
;
// Only visit blue nodes.
ptrdiff_t
c
;
for
(
ptrdiff_t
e
=
cumDeg
[
i
];
e
<
cumDeg
[
i
]
+
deg
[
i
];
e
++
)
{
c
=
col
[
e
];
if
(
color
[
c
]
==
-
2
)
{
tensor
[
i
]
=
c
;
break
;
}
// Propose to first red node.
}
if
(
tensor
[
i
]
<
0
)
color
[
i
]
=
i
;
// Mark node as dead.
}
}
void
THCGreedy_propose
(
THCState
*
state
,
THCudaLongTensor
*
tensor
,
THCudaLongTensor
*
color
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
deg
,
THCudaLongTensor
*
cumDeg
)
{
ptrdiff_t
nNodes
=
THCudaLongTensor_nElement
(
state
,
color
);
int64_t
*
tensorData
=
THCudaLongTensor_data
(
state
,
tensor
);
int64_t
*
colorData
=
THCudaLongTensor_data
(
state
,
color
);
int64_t
*
rowData
=
THCudaLongTensor_data
(
state
,
row
);
int64_t
*
colData
=
THCudaLongTensor_data
(
state
,
col
);
int64_t
*
degData
=
THCudaLongTensor_data
(
state
,
deg
);
int64_t
*
cumDegData
=
THCudaLongTensor_data
(
state
,
cumDeg
);
KERNEL_RUN
(
proposeKernel
,
nNodes
,
tensorData
,
colorData
,
rowData
,
colData
,
degData
,
cumDegData
);
}
#include "THCPropose.cu"
#include "THCResponse.cu"
void
THCGreedy
(
THCState
*
state
,
THCudaLongTensor
*
cluster
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
)
{
THCAssertSameGPU
(
THCudaLongTensor_checkGPU
(
state
,
4
,
cluster
,
row
,
col
));
THCAssertSameGPU
(
THCudaLongTensor_checkGPU
(
state
,
3
,
cluster
,
row
,
col
));
int
nNodes
=
THCudaLongTensor_nElement
(
state
,
cluster
);
...
...
@@ -45,14 +21,11 @@ void THCGreedy(THCState *state, THCudaLongTensor *cluster, THCudaLongTensor *row
THCudaLongTensor
*
cumDeg
=
THCudaLongTensor_newWithSize1d
(
state
,
nNodes
);
THCudaLongTensor_cumsum
(
state
,
cumDeg
,
deg
,
0
);
THCGreedy_assignColor
(
state
,
cluster
);
THCGreedy_propose
(
state
,
prop
,
cluster
,
row
,
col
,
deg
,
cumDeg
);
while
(
!
THCGreedy_assignColor
(
state
,
cluster
))
{
THCGreedy_propose
(
state
,
cluster
,
prop
,
row
,
col
,
deg
,
cumDeg
);
THCGreedy_response
(
state
,
cluster
,
prop
,
row
,
col
,
deg
,
cumDeg
);
};
/* while(!THCGreedy_assignColor(state, cluster)) { */
/* printf("DRIN"); */
// call propose step
// call response step
/* }; */
THCudaLongTensor_free
(
state
,
prop
);
THCudaLongTensor_free
(
state
,
deg
);
THCudaLongTensor_free
(
state
,
cumDeg
);
...
...
aten/THC/THCPropose.cu
0 → 100644
View file @
878fa61f
#include "common.cuh"
__global__
void
proposeKernel
(
int64_t
*
color
,
int64_t
*
prop
,
int64_t
*
row
,
int64_t
*
col
,
int64_t
*
deg
,
int64_t
*
cumDeg
,
ptrdiff_t
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
if
(
color
[
i
]
!=
-
1
)
continue
;
// Only visit blue nodes.
ptrdiff_t
c
;
for
(
ptrdiff_t
e
=
cumDeg
[
i
]
-
deg
[
i
];
e
<
cumDeg
[
i
];
e
++
)
{
c
=
col
[
e
];
if
(
color
[
c
]
==
-
2
)
{
// Red neighbor found.
prop
[
i
]
=
c
;
// Propose!
break
;
}
}
if
(
prop
[
i
]
<
0
)
color
[
i
]
=
i
;
// Mark node as dead.
}
}
void
THCGreedy_propose
(
THCState
*
state
,
THCudaLongTensor
*
color
,
THCudaLongTensor
*
prop
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
deg
,
THCudaLongTensor
*
cumDeg
)
{
ptrdiff_t
nNodes
=
THCudaLongTensor_nElement
(
state
,
color
);
int64_t
*
colorData
=
THCudaLongTensor_data
(
state
,
color
);
int64_t
*
propData
=
THCudaLongTensor_data
(
state
,
prop
);
int64_t
*
rowData
=
THCudaLongTensor_data
(
state
,
row
);
int64_t
*
colData
=
THCudaLongTensor_data
(
state
,
col
);
int64_t
*
degData
=
THCudaLongTensor_data
(
state
,
deg
);
int64_t
*
cumDegData
=
THCudaLongTensor_data
(
state
,
cumDeg
);
KERNEL_RUN
(
proposeKernel
,
nNodes
,
colorData
,
propData
,
rowData
,
colData
,
degData
,
cumDegData
);
}
aten/THC/THCResponse.cu
0 → 100644
View file @
878fa61f
#include "common.cuh"
__global__
void
responseKernel
(
int64_t
*
color
,
int64_t
*
prop
,
int64_t
*
row
,
int64_t
*
col
,
int64_t
*
deg
,
int64_t
*
cumDeg
,
ptrdiff_t
nNodes
)
{
KERNEL_LOOP
(
i
,
nNodes
)
{
if
(
color
[
i
]
!=
-
2
)
continue
;
// Only visit red nodes.
ptrdiff_t
c
;
int64_t
neighborColor
,
minValue
;
bool
isDead
=
true
;
for
(
ptrdiff_t
e
=
cumDeg
[
i
]
-
deg
[
i
];
e
<
cumDeg
[
i
];
e
++
)
{
c
=
col
[
e
];
neighborColor
=
color
[
c
];
if
(
neighborColor
==
-
1
&&
prop
[
c
]
==
i
)
{
// Blue neighbor found which proposed to node i.
minValue
=
min
(
i
,
c
);
color
[
i
]
=
minValue
;
color
[
c
]
=
minValue
;
break
;
}
if
(
neighborColor
<
0
)
isDead
=
false
;
}
if
(
isDead
&&
color
[
i
]
<
0
)
color
[
i
]
=
i
;
// Mark node as dead.
}
}
void
THCGreedy_response
(
THCState
*
state
,
THCudaLongTensor
*
color
,
THCudaLongTensor
*
prop
,
THCudaLongTensor
*
row
,
THCudaLongTensor
*
col
,
THCudaLongTensor
*
deg
,
THCudaLongTensor
*
cumDeg
)
{
ptrdiff_t
nNodes
=
THCudaLongTensor_nElement
(
state
,
color
);
int64_t
*
colorData
=
THCudaLongTensor_data
(
state
,
color
);
int64_t
*
propData
=
THCudaLongTensor_data
(
state
,
prop
);
int64_t
*
rowData
=
THCudaLongTensor_data
(
state
,
row
);
int64_t
*
colData
=
THCudaLongTensor_data
(
state
,
col
);
int64_t
*
degData
=
THCudaLongTensor_data
(
state
,
deg
);
int64_t
*
cumDegData
=
THCudaLongTensor_data
(
state
,
cumDeg
);
KERNEL_RUN
(
responseKernel
,
nNodes
,
colorData
,
propData
,
rowData
,
colData
,
degData
,
cumDegData
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment