Unverified Commit f7382417 authored by duzekun's avatar duzekun Committed by GitHub
Browse files

[Refactor] Simplify the logic of sparse_conv (#2802)

parent ba8aa764
......@@ -147,12 +147,8 @@ torch::Tensor IndiceConvForwardMLUKernelLauncher(
torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse,
int64_t _subM) {
auto indice_num_cpu = indiceNum.to({torch::kCPU});
auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
int indice_num_len = indiceNum.numel();
int64_t indice_num[indice_num_len];
for (int i = 0; i < indice_num_len; ++i) {
indice_num[i] = (int64_t)(((int *)indice_num_cpu_64)[i]);
}
auto indice_num_cpu_64 = indice_num_cpu.to(torch::kInt64);
auto indice_num = indice_num_cpu_64.data_ptr<int64_t>();
// generate empty output
int C = filters.dim() == 4 ? filters.size(3) : filters.size(4);
......@@ -241,12 +237,8 @@ std::vector<torch::Tensor> IndiceConvBackwardMLUKernelLauncher(
torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse,
int64_t _subM) {
auto indice_num_cpu = indiceNum.to({torch::kCPU});
auto indice_num_cpu_64 = indice_num_cpu.data_ptr<int>();
int indice_num_len = indiceNum.numel();
int64_t indice_num[indice_num_len];
for (int i = 0; i < indice_num_len; ++i) {
indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]);
}
auto indice_num_cpu_64 = indice_num_cpu.to(torch::kInt64);
auto indice_num = indice_num_cpu_64.data_ptr<int64_t>();
// generate empty input_grad
torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment