Commit 3b318ccb authored by rusty1s's avatar rusty1s
Browse files

clean up

parent 81940250
#include <TH/TH.h> #include <TH/TH.h>
#define THGreedy_(NAME) TH_CONCAT_4(TH,Real,Greedy_,NAME) #define THGreedy_(NAME) TH_CONCAT_4(TH,Real,Greedy_,NAME)
#define DATA(TENSOR) TENSOR->storage->data + TENSOR->storageOffset
#define TH_GREEDY_CLUSTER(cluster, row, col, deg, SELECT) { \ #define TH_GREEDY_CLUSTER(cluster, row, col, deg, SELECT) { \
int64_t *clusterData = cluster->storage->data + cluster->storageOffset; \ THLongTensor_fill(cluster, -1); \
int64_t *rowData = row->storage->data + row->storageOffset; \ int64_t *clusterData = DATA(cluster); \
int64_t *colData = col->storage->data + col->storageOffset; \ int64_t *rowData = DATA(row); \
int64_t *degData = deg->storage->data + deg->storageOffset; \ int64_t *colData = DATA(col); \
int64_t *degData = DATA(deg); \
ptrdiff_t rowIdx = 0, neighborIdx; \ ptrdiff_t rowIdx = 0, neighborIdx; \
int64_t rowValue, colValue, clusterValue, tmp; \ int64_t rowValue, colValue, clusterValue, tmp; \
while(rowIdx < THLongTensor_nElement(row)) { \ while(rowIdx < THLongTensor_nElement(row)) { \
rowValue = rowData[rowIdx]; \ rowValue = rowData[rowIdx]; \
printf("rowValue = %lli, ", rowValue); \
if (clusterData[rowValue] < 0) { \ if (clusterData[rowValue] < 0) { \
colValue = rowValue; \ colValue = rowValue; \
SELECT \ SELECT \
clusterValue = rowValue < colValue ? rowValue : colValue; \ clusterValue = rowValue < colValue ? rowValue : colValue; \
printf("%lli", clusterValue); \
clusterData[rowValue] = clusterValue; \ clusterData[rowValue] = clusterValue; \
clusterData[colValue] = clusterValue; \ clusterData[colValue] = clusterValue; \
} \ } \
......
#include <TH/TH.h> #include <TH/TH.h>
#define THGrid_(NAME) TH_CONCAT_4(TH,Real,Grid_,NAME) #define THGrid_(NAME) TH_CONCAT_4(TH,Real,Grid_,NAME)
#define DATA(TENSOR) TENSOR->storage->data + TENSOR->storageOffset
#include "generic/THGrid.c" #include "generic/THGrid.c"
#include "THGenerateAllTypes.h" #include "THGenerateAllTypes.h"
void THByteGrid(THLongTensor *cluster, void THByteGrid_cluster(THLongTensor *cluster,
THByteTensor *pos, THByteTensor *pos,
THByteTensor *size, THByteTensor *size,
THLongTensor *count); THLongTensor *count);
void THCharGrid(THLongTensor *cluster, void THCharGrid_cluster(THLongTensor *cluster,
THCharTensor *pos, THCharTensor *pos,
THCharTensor *size, THCharTensor *size,
THLongTensor *count); THLongTensor *count);
void THShortGrid(THLongTensor *cluster, void THShortGrid_cluster(THLongTensor *cluster,
THShortTensor *pos, THShortTensor *pos,
THShortTensor *size, THShortTensor *size,
THLongTensor *count); THLongTensor *count);
void THIntGrid(THLongTensor *cluster, void THIntGrid_cluster(THLongTensor *cluster,
THIntTensor *pos, THIntTensor *pos,
THIntTensor *size, THIntTensor *size,
THLongTensor *count); THLongTensor *count);
void THLongGrid(THLongTensor *cluster, void THLongGrid_cluster(THLongTensor *cluster,
THLongTensor *pos, THLongTensor *pos,
THLongTensor *size, THLongTensor *size,
THLongTensor *count); THLongTensor *count);
void THFloatGrid(THLongTensor *cluster, void THFloatGrid_cluster(THLongTensor *cluster,
THFloatTensor *pos, THFloatTensor *pos,
THFloatTensor *size, THFloatTensor *size,
THLongTensor *count); THLongTensor *count);
void THDoubleGrid(THLongTensor *cluster, void THDoubleGrid_cluster(THLongTensor *cluster,
THDoubleTensor *pos, THDoubleTensor *pos,
THDoubleTensor *size, THDoubleTensor *size,
THLongTensor *count); THLongTensor *count);
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
void THGreedy_(cluster)(THLongTensor *cluster, THLongTensor *row, THLongTensor *col, void THGreedy_(cluster)(THLongTensor *cluster, THLongTensor *row, THLongTensor *col,
THLongTensor *deg, THTensor *weight) { THLongTensor *deg, THTensor *weight) {
real *weightData = weight->storage->data + weight->storageOffset; real *weightData = DATA(weight);
real maxWeight = 0, tmpWeight; real maxWeight = 0, tmpWeight;
TH_GREEDY_CLUSTER(cluster, row, col, deg, TH_GREEDY_CLUSTER(cluster, row, col, deg,
for (neighborIdx = rowIdx; neighborIdx < rowIdx + degData[rowValue]; neighborIdx++) { for (neighborIdx = rowIdx; neighborIdx < rowIdx + degData[rowValue]; neighborIdx++) {
......
...@@ -3,20 +3,21 @@ ...@@ -3,20 +3,21 @@
#else #else
void THGrid_(cluster)(THLongTensor *cluster, THTensor *pos, THTensor *size, THLongTensor *count) { void THGrid_(cluster)(THLongTensor *cluster, THTensor *pos, THTensor *size, THLongTensor *count) {
real *sizeData = size->storage->data + size->storageOffset; int64_t *clusterData = DATA(cluster);
int64_t *countData = count->storage->data + count->storageOffset; real *posData = DATA(pos);
int64_t dims = THLongTensor_nElement(count); real *sizeData = DATA(size);
THLongTensor_unsqueeze1d(cluster, NULL, 1); int64_t *countData = DATA(count);
ptrdiff_t d; int64_t coef, value;
TH_TENSOR_DIM_APPLY2(int64_t, cluster, real, pos, 1, ptrdiff_t n, d; int64_t coef, value;
for (n = 0; n < THTensor_(size)(pos, 0); n++) {
coef = 1; value = 0; coef = 1; value = 0;
for (d = 0; d < dims; d++) { for (d = 0; d < THTensor_(size)(pos, 1); d++) {
value += coef * (int64_t) (*(pos_data + d * pos_stride) / sizeData[d]); value += coef * (int64_t) (*(posData + d * pos->stride[1]) / sizeData[d]);
coef *= countData[d]; coef *= countData[d];
} }
cluster_data[0] = value; posData += pos->stride[0];
) clusterData[n] = value;
THLongTensor_squeeze1d(cluster, NULL, 1); }
} }
#endif // TH_GENERIC_FILE #endif // TH_GENERIC_FILE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment