Commit 9da489a5 authored by rusty1s's avatar rusty1s
Browse files

serial define macro

parent 205d7a9b
......@@ -2,7 +2,21 @@
#define TH_GENERIC_FILE "generic/serial_cpu.c"
#else
void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THTensor *weight, THLongTensor *degree) {
void cluster_(serial)(THLongTensor *output, THLongTensor *row, THLongTensor *col, THTensor *weight, THLongTensor *degree) {
real *weight_data = weight->storage->data + weight->storageOffset;
real max_weight, w;
int64_t d, c;
SERIAL(output, row, col, degree,
max_weight = 0;
for (d = 0; d < degree_data[row_value]; d++) {
c = col_data[e + d];
w = weight_data[e + d];
if (output_data[c] < 0 && w >= max_weight) {
col_value = c;
max_weight = w;
}
}
)
}
#endif
......
......@@ -2,42 +2,42 @@
#define cluster_(NAME) TH_CONCAT_4(cluster_, NAME, _, Real)
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
int64_t *output_data = output->storage->data + output->storageOffset;
int64_t *row_data = row->storage->data + row->storageOffset;
int64_t *col_data = col->storage->data + col->storageOffset;
int64_t *degree_data = degree->storage->data + degree->storageOffset;
int64_t e = 0, row_value, col_value, i, value;
while(e < THLongTensor_nElement(row)) {
row_value = row_data[e];
if (output_data[row_value] < 0) { // Node is unmatched.
// Find next unmatched neighbor.
col_value = -1;
for (i = 0; i < degree_data[row_value]; i++) {
value = col_data[e + i];
if (output_data[value] < 0) { // Neighbor found. Save and abort.
col_value = value;
break;
}
}
#define SERIAL(output, row, col, degree, SELECT) { \
int64_t *output_data = output->storage->data + output->storageOffset; \
int64_t *row_data = row->storage->data + row->storageOffset; \
int64_t *col_data = col->storage->data + col->storageOffset; \
int64_t *degree_data = degree->storage->data + degree->storageOffset; \
\
int64_t e = 0, row_value, col_value, value; \
while(e < THLongTensor_nElement(row)) { \
row_value = row_data[e]; \
if (output_data[row_value] >= 0) { \
col_value = -1; \
SELECT \
if (col_value < 0) { \
output_data[row_value] = row_value; \
} \
else { \
value = row_value < col_value ? row_value : col_value; \
output_data[row_value] = value; \
output_data[col_value] = value; \
} \
} \
e += degree_data[row_value]; \
} \
}
// Set cluster output for new matched nodes (one or two).
if (col_value < 0) {
output_data[row_value] = row_value;
}
else {
i = row_value < col_value ? row_value : col_value;
output_data[row_value] = i;
output_data[col_value] = i;
void cluster_serial(THLongTensor *output, THLongTensor *row, THLongTensor *col, THLongTensor *degree) {
int64_t d, c;
SERIAL(output, row, col, degree,
for (d = 0; d < degree_data[row_value]; d++) {
c = col_data[e + d];
if (output_data[c] < 0) {
col_value = c;
break;
}
}
// Jump to next row.
e += degree_data[row_value];
}
)
}
#include "generic/serial_cpu.c"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment