Commit 9da489a5 authored by rusty1s's avatar rusty1s
Browse files

serial define macro

parent 205d7a9b
...@@ -3,6 +3,20 @@ ...@@ -3,6 +3,20 @@
#else #else
void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THTensor *weight, THLongTensor *degree) { void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THTensor *weight, THLongTensor *degree) {
real *weight_data = weight->storage->data + weight->storageOffset;
real max_weight, w;
int64_t d, c;
SERIAL(output, row, col, degree,
max_weight = 0;
for (d = 0; d < degree_data[row_value]; d++) {
c = col_data[e + d];
w = weight_data[e + d];
if (output_data[c] < 0 && w >= max_weight) {
col_value = c;
max_weight = w;
}
}
)
} }
#endif #endif
......
...@@ -2,42 +2,42 @@ ...@@ -2,42 +2,42 @@
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real) #define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real)
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) { #define SERIAL(output, row, col, degree, SELECT) { \
int64_t *output_data = output->storage->data + output->storageOffset; int64_t *output_data = output->storage->data + output->storageOffset; \
int64_t *row_data = row->storage->data + row->storageOffset; int64_t *row_data = row->storage->data + row->storageOffset; \
int64_t *col_data = col->storage->data + col->storageOffset; int64_t *col_data = col->storage->data + col->storageOffset; \
int64_t *degree_data = degree->storage->data + degree->storageOffset; int64_t *degree_data = degree->storage->data + degree->storageOffset; \
\
int64_t e = 0, row_value, col_value, i, value; int64_t e = 0, row_value, col_value, value; \
while(e < THLongTensor_nElement(row)) { \
while(e < THLongTensor_nElement(row)) { row_value = row_data[e]; \
row_value = row_data[e]; if (output_data[row_value] >= 0) { \
if (output_data[row_value] < 0) { // Node is unmatched. col_value = -1; \
SELECT \
if (col_value < 0) { \
output_data[row_value] = row_value; \
} \
else { \
value = row_value < col_value ? row_value : col_value; \
output_data[row_value] = value; \
output_data[col_value] = value; \
} \
} \
e += degree_data[row_value]; \
} \
}
// Find next unmatched neighbor. void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
col_value = -1; int64_t d, c;
for (i = 0; i < degree_data[row_value]; i++) { SERIAL(output, row, col, degree,
value = col_data[e + i]; for (d = 0; d < degree_data[row_value]; d++) {
if (output_data[value] < 0) { // Neighbor found. Save and abort. c = col_data[e + d];
col_value = value; if (output_data[c] < 0) {
col_value = c;
break; break;
} }
} }
)
// Set cluster output for new matched nodes (one or two).
if (col_value < 0) {
output_data[row_value] = row_value;
}
else {
i = row_value < col_value ? row_value : col_value;
output_data[row_value] = i;
output_data[col_value] = i;
}
}
// Jump to next row.
e += degree_data[row_value];
}
} }
#include "generic/serial_cpu.c" #include "generic/serial_cpu.c"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment