Commit 6884ab18 authored by rusty1s

no warnings

parent a49a26d0
@@ -3,112 +3,116 @@
 #include <torch/torch.h>
 
 #define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
-  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
-  auto TENSOR1##_size = TENSOR1.size(DIM); \
-  auto TENSOR1##_stride = TENSOR1.stride(DIM); \
+  [&] { \
+    TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
+    auto TENSOR1##_size = TENSOR1.size(DIM); \
+    auto TENSOR1##_stride = TENSOR1.stride(DIM); \
 \
-  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
-  auto TENSOR2##_size = TENSOR2.size(DIM); \
-  auto TENSOR2##_stride = TENSOR2.stride(DIM); \
+    TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
+    auto TENSOR2##_size = TENSOR2.size(DIM); \
+    auto TENSOR2##_stride = TENSOR2.stride(DIM); \
 \
-  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
-  auto TENSOR3##_size = TENSOR3.size(DIM); \
-  auto TENSOR3##_stride = TENSOR3.stride(DIM); \
+    TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
+    auto TENSOR3##_size = TENSOR3.size(DIM); \
+    auto TENSOR3##_stride = TENSOR3.stride(DIM); \
 \
-  auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
-  auto counter = zeros.data<int64_t>(); \
-  bool has_finished = false; \
+    auto dims = TENSOR1.dim(); \
+    auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
+    auto counter = zeros.data<int64_t>(); \
+    bool has_finished = false; \
 \
-  while (!has_finished) { \
-    CODE; \
-    if (dims == 1) \
-      break; \
+    while (!has_finished) { \
+      CODE; \
+      if (dims == 1) \
+        break; \
 \
-    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
-      if (cur_dim == DIM) { \
-        if (cur_dim == dims - 1) { \
-          has_finished = true; \
-          break; \
-        } \
-        continue; \
-      } \
+      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
+        if (cur_dim == DIM) { \
+          if (cur_dim == dims - 1) { \
+            has_finished = true; \
+            break; \
+          } \
+          continue; \
+        } \
 \
-      counter[cur_dim]++; \
-      TENSOR1##_data += TENSOR1.stride(cur_dim); \
-      TENSOR2##_data += TENSOR2.stride(cur_dim); \
-      TENSOR3##_data += TENSOR3.stride(cur_dim); \
+        counter[cur_dim]++; \
+        TENSOR1##_data += TENSOR1.stride(cur_dim); \
+        TENSOR2##_data += TENSOR2.stride(cur_dim); \
+        TENSOR3##_data += TENSOR3.stride(cur_dim); \
 \
-      if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
-        if (cur_dim == dims - 1) { \
-          has_finished = true; \
-          break; \
-        } else { \
-          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
-          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
-          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
-          counter[cur_dim] = 0; \
-        } \
-      } else \
-        break; \
-    } \
-  }
+        if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
+          if (cur_dim == dims - 1) { \
+            has_finished = true; \
+            break; \
+          } else { \
+            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
+            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
+            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
+            counter[cur_dim] = 0; \
+          } \
+        } else \
+          break; \
+      } \
+    } \
+  }()
 
 #define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
                    TENSOR4, DIM, CODE) \
-  TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
-  auto TENSOR1##_size = TENSOR1.size(DIM); \
-  auto TENSOR1##_stride = TENSOR1.stride(DIM); \
+  [&] { \
+    TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
+    auto TENSOR1##_size = TENSOR1.size(DIM); \
+    auto TENSOR1##_stride = TENSOR1.stride(DIM); \
 \
-  TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
-  auto TENSOR2##_size = TENSOR2.size(DIM); \
-  auto TENSOR2##_stride = TENSOR2.stride(DIM); \
+    TYPE2 *TENSOR2##_data = TENSOR2.data<TYPE2>(); \
+    auto TENSOR2##_size = TENSOR2.size(DIM); \
+    auto TENSOR2##_stride = TENSOR2.stride(DIM); \
 \
-  TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
-  auto TENSOR3##_size = TENSOR3.size(DIM); \
-  auto TENSOR3##_stride = TENSOR3.stride(DIM); \
+    TYPE3 *TENSOR3##_data = TENSOR3.data<TYPE3>(); \
+    auto TENSOR3##_size = TENSOR3.size(DIM); \
+    auto TENSOR3##_stride = TENSOR3.stride(DIM); \
 \
-  TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \
-  auto TENSOR4##_size = TENSOR4.size(DIM); \
-  auto TENSOR4##_stride = TENSOR4.stride(DIM); \
+    TYPE4 *TENSOR4##_data = TENSOR4.data<TYPE4>(); \
+    auto TENSOR4##_size = TENSOR4.size(DIM); \
+    auto TENSOR4##_stride = TENSOR4.stride(DIM); \
 \
-  auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
-  auto counter = zeros.data<int64_t>(); \
-  bool has_finished = false; \
+    auto dims = TENSOR1.dim(); \
+    auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
+    auto counter = zeros.data<int64_t>(); \
+    bool has_finished = false; \
 \
-  while (!has_finished) { \
-    CODE; \
-    if (dims == 1) \
-      break; \
+    while (!has_finished) { \
+      CODE; \
+      if (dims == 1) \
+        break; \
 \
-    for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
-      if (cur_dim == DIM) { \
-        if (cur_dim == dims - 1) { \
-          has_finished = true; \
-          break; \
-        } \
-        continue; \
-      } \
+      for (int64_t cur_dim = 0; cur_dim < dims; cur_dim++) { \
+        if (cur_dim == DIM) { \
+          if (cur_dim == dims - 1) { \
+            has_finished = true; \
+            break; \
+          } \
+          continue; \
+        } \
 \
-      counter[cur_dim]++; \
-      TENSOR1##_data += TENSOR1.stride(cur_dim); \
-      TENSOR2##_data += TENSOR2.stride(cur_dim); \
-      TENSOR3##_data += TENSOR3.stride(cur_dim); \
-      TENSOR4##_data += TENSOR4.stride(cur_dim); \
+        counter[cur_dim]++; \
+        TENSOR1##_data += TENSOR1.stride(cur_dim); \
+        TENSOR2##_data += TENSOR2.stride(cur_dim); \
+        TENSOR3##_data += TENSOR3.stride(cur_dim); \
+        TENSOR4##_data += TENSOR4.stride(cur_dim); \
 \
-      if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
-        if (cur_dim == dims - 1) { \
-          has_finished = true; \
-          break; \
-        } else { \
-          TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
-          TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
-          TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
-          TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
-          counter[cur_dim] = 0; \
-        } \
-      } else \
-        break; \
-    } \
-  }
+        if (counter[cur_dim] == TENSOR1.size(cur_dim)) { \
+          if (cur_dim == dims - 1) { \
+            has_finished = true; \
+            break; \
+          } else { \
+            TENSOR1##_data -= counter[cur_dim] * TENSOR1.stride(cur_dim); \
+            TENSOR2##_data -= counter[cur_dim] * TENSOR2.stride(cur_dim); \
+            TENSOR3##_data -= counter[cur_dim] * TENSOR3.stride(cur_dim); \
+            TENSOR4##_data -= counter[cur_dim] * TENSOR4.stride(cur_dim); \
+            counter[cur_dim] = 0; \
+          } \
+        } else \
+          break; \
+      } \
+    } \
+  }()
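The change is the same in both macros: the body is wrapped in an immediately-invoked lambda, `[&] { ... }()`, so each expansion becomes a single expression with its own scope instead of a bare run of statements. A minimal sketch of the idiom, using a made-up TIMES2 macro that is not part of this commit:

// Hypothetical TIMES2 macro illustrating the immediately-invoked lambda
// idiom used above. Wrapping the body in [&] { ... }() confines helper
// variables (like the _size/_stride helpers in DIM_APPLY3/4) to the
// lambda's scope and makes the expansion a single expression, so every
// call site must now end with ';'. That is why the scatter.cpp hunks
// below change `})` to `});`.
#include <cstdio>

#define TIMES2(X, CODE) \
  [&] { \
    auto X##_doubled = 2 * (X); /* may go unused, just like the helpers */ \
    CODE; \
  }()

int main() {
  int n = 21;
  TIMES2(n, { printf("%d\n", n_doubled); }); /* prints 42; ';' required */
  return 0;
}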
@@ -11,7 +11,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
         idx = index_data[i * index_stride];
         out_data[idx * out_stride] *= src_data[i * src_stride];
       }
-    })
+    });
   });
 }
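For orientation, here is a hedged reconstruction of the whole scatter_mul call site after this commit. Only the loop body and the closing braces are visible in the hunk; the AT_DISPATCH_ALL_TYPES wrapper, the `int64_t dim` parameter, and the DIM_APPLY3 argument list are assumptions about the surrounding code, not part of the diff:

// Sketch only: everything outside the hunk above is assumed.
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
                 int64_t dim) {
  AT_DISPATCH_ALL_TYPES(src.type(), "scatter_mul", [&] {
    DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
      int64_t idx;
      for (int64_t i = 0; i < src_size; i++) {
        idx = index_data[i * index_stride];
        out_data[idx * out_stride] *= src_data[i * src_stride];
      }
    }); /* DIM_APPLY3 now expands to [&] { ... }(), so ';' is required */
  });
}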
@@ -24,7 +24,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
         idx = index_data[i * index_stride];
         out_data[idx * out_stride] /= src_data[i * src_stride];
       }
-    })
+    });
   });
 }
@@ -41,7 +41,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
           arg_data[idx * arg_stride] = i;
         }
       }
-    })
+    });
   });
 }
@@ -58,7 +58,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
           arg_data[idx * arg_stride] = i;
         }
      }
-    })
+    });
   });
 }
@@ -74,7 +74,7 @@ void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
           out_data[i * out_stride] = grad_data[idx * grad_stride];
         }
       }
-    })
+    });
   });
 }
@@ -4,7 +4,11 @@ from setuptools import setup, find_packages
 import torch.cuda
 from torch.utils.cpp_extension import CppExtension, CUDAExtension
 
-ext_modules = [CppExtension('scatter_cpu', ['cpu/scatter.cpp'])]
+ext_modules = [
+    CppExtension(
+        'scatter_cpu', ['cpu/scatter.cpp'],
+        extra_compile_args=['-Wno-unused-variable'])
+]
 cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 
 if torch.cuda.is_available():
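The new compiler flag matches the macro design: DIM_APPLY3/DIM_APPLY4 define `_size` and `_stride` helpers for every tensor argument even when a given CODE block never reads them, so the remaining unused-variable warnings are silenced at build time rather than at each call site. A hedged sketch of how the surrounding setup.py plausibly continues; the CUDA extension name, its source list, and the setup() call are assumptions, and only the hunk lines above are confirmed by this commit:

# Sketch only: the CUDAExtension name/sources and setup() call are assumed.
import torch.cuda
import torch.utils.cpp_extension
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CppExtension, CUDAExtension

ext_modules = [
    CppExtension(
        'scatter_cpu', ['cpu/scatter.cpp'],
        extra_compile_args=['-Wno-unused-variable'])
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

if torch.cuda.is_available():
    # assumed: a matching CUDA extension is appended when CUDA is present
    ext_modules += [CUDAExtension('scatter_cuda',
                                  ['cuda/scatter.cpp', 'cuda/scatter_kernel.cu'])]

setup(name='torch_scatter', ext_modules=ext_modules, cmdclass=cmdclass,
      packages=find_packages())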