Commit 7a4b01d4 authored by rusty1s's avatar rusty1s
Browse files

new tensor library

parent 0ad730f3
#include <TH/TH.h>
#define THGreedy_(NAME) TH_CONCAT_4(TH,Real,Greedy_,NAME)
#define TH_GREEDY_CLUSTER(cluster, row, col, deg, SELECT) { \
int64_t *clusterData = cluster->storage->data + cluster->storageOffset; \
int64_t *rowData = row->storage->data + row->storageOffset; \
int64_t *colData = col->storage->data + col->storageOffset; \
int64_t *degData = deg->storage->data + deg->storageOffset; \
ptrdiff_t rowIdx = 0, neighborIdx; \
int64_t rowValue, colValue, clusterValue, tmp; \
while(rowIdx < THLongTensor_nElement(row)) { \
rowValue = rowData[rowIdx]; \
if (clusterData[rowValue] < 0) { \
colValue = rowValue; \
SELECT \
clusterValue = rowValue < colValue ? rowValue : colValue; \
clusterData[rowValue] = clusterValue; \
clusterData[colValue] = clusterValue; \
} \
rowIdx += degData[rowValue]; \
} \
}
void THGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg) {
TH_GREEDY_CLUSTER(cluster, row, col, deg,
for (neighborIdx = rowIdx; neighborIdx < rowIdx + degData[rowValue]; neighborIdx++) {
tmp = colData[neighborIdx];
if (clusterData[tmp] < 0) {
colValue = tmp;
break;
}
}
)
}
#include "generic/THGreedy.c"
#include "THGenerateAllTypes.h"
void THGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg);
void THByteGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THByteTensor *weight);
void THCharGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THCharTensor *weight);
void THShortGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THShortTensor *weight);
void THIntGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THIntTensor *weight);
void THLongGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THLongTensor *weight);
void THFloatGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THFloatTensor *weight);
void THDoubleGreedy_cluster(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THDoubleTensor *weight);
#include <TH/TH.h>
#define THGrid_(NAME) TH_CONCAT_4(TH,Real,Grid_,NAME)
#include "generic/THGrid.c"
#include "THGenerateAllTypes.h"
void THByteGrid_cluster(THLongTensor *cluster, THByteTensor *pos, THByteTensor *size,
THLongTensor *count);
void THCharGrid_cluster(THLongTensor *cluster, THCharTensor *pos, THCharTensor *size,
THLongTensor *count);
void THShortGrid_cluster(THLongTensor *cluster, THShortTensor *pos, THShortTensor *size,
THLongTensor *count);
void THIntGrid_cluster(THLongTensor *cluster, THIntTensor *pos, THIntTensor *size,
THLongTensor *count);
void THLongGrid_cluster(THLongTensor *cluster, THLongTensor *pos, THLongTensor *size,
THLongTensor *count);
void THFloatGrid_cluster(THLongTensor *cluster, THFloatTensor *pos, THFloatTensor *size,
THLongTensor *count);
void THDoubleGrid_cluster(THLongTensor *cluster, THDoubleTensor *pos, THDoubleTensor *size,
THLongTensor *count);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THGreedy.c"
#else
void THGreedy_(cluster)(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THTensor *weight) {
real *weightData = weight->storage->data + weight->storageOffset;
real maxWeight = 0, tmpWeight;
TH_GREEDY_CLUSTER(cluster, row, col, deg,
for (neighborIdx = rowIdx; neighborIdx < rowIdx + degData[rowValue]; neighborIdx++) {
tmp = colData[neighborIdx];
tmpWeight = weightData[neighborIdx];
if (clusterData[tmp] < 0 && tmpWeight > maxWeight) {
colValue = tmp;
maxWeight = tmpWeight;
}
}
)
}
#endif // TH_GENERIC_FILE
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THGrid.c"
#else
void THGrid_(cluster)(THLongTensor *cluster, THTensor *pos, THTensor *size, THLongTensor *count) {
real *sizeData = size->storage->data + size->storageOffset;
int64_t *countData = count->storage->data + count->storageOffset;
int64_t dims = THLongTensor_nElement(count);
THLongTensor_unsqueeze1d(cluster, NULL, 1);
ptrdiff_t d; int64_t coef, value;
TH_TENSOR_DIM_APPLY2(int64_t, cluster, real, pos, 1,
coef = 1; value = 0;
for (d = 0; d < dims; d++) {
value += coef * (int64_t) (*(pos_data + d * pos_stride) / sizeData[d]);
coef *= countData[d];
}
cluster_data[0] = value;
)
THLongTensor_squeeze1d(cluster, NULL, 1);
}
#endif // TH_GENERIC_FILE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment