Commit 6884ab18 authored by rusty1s's avatar rusty1s
Browse files

no warnings

parent a49a26d0
......@@ -3,6 +3,7 @@
#include <torch/torch.h>
#define DIM_APPLY3(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
......@@ -16,7 +17,7 @@
auto TENSOR3##_stride = TENSOR3.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \
\
......@@ -52,10 +53,12 @@
} else \
break; \
} \
}
} \
}()
#define DIM_APPLY4(TYPE1, TENSOR1, TYPE2, TENSOR2, TYPE3, TENSOR3, TYPE4, \
TENSOR4, DIM, CODE) \
[&] { \
TYPE1 *TENSOR1##_data = TENSOR1.data<TYPE1>(); \
auto TENSOR1##_size = TENSOR1.size(DIM); \
auto TENSOR1##_stride = TENSOR1.stride(DIM); \
......@@ -73,7 +76,7 @@
auto TENSOR4##_stride = TENSOR4.stride(DIM); \
\
auto dims = TENSOR1.dim(); \
auto zeros = at::zeros(torch::CPU(at::kLong), {dims}); \
auto zeros = at::zeros(dims, torch::CPU(at::kLong)); \
auto counter = zeros.data<int64_t>(); \
bool has_finished = false; \
\
......@@ -111,4 +114,5 @@
} else \
break; \
} \
}
} \
}()
......@@ -11,7 +11,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
idx = index_data[i * index_stride];
out_data[idx * out_stride] *= src_data[i * src_stride];
}
})
});
});
}
......@@ -24,7 +24,7 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
idx = index_data[i * index_stride];
out_data[idx * out_stride] /= src_data[i * src_stride];
}
})
});
});
}
......@@ -41,7 +41,7 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
arg_data[idx * arg_stride] = i;
}
}
})
});
});
}
......@@ -58,7 +58,7 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
arg_data[idx * arg_stride] = i;
}
}
})
});
});
}
......@@ -74,7 +74,7 @@ void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
out_data[i * out_stride] = grad_data[idx * grad_stride];
}
}
})
});
});
}
......
......@@ -4,7 +4,11 @@ from setuptools import setup, find_packages
import torch.cuda
from torch.utils.cpp_extension import CppExtension, CUDAExtension
ext_modules = [CppExtension('scatter_cpu', ['cpu/scatter.cpp'])]
ext_modules = [
CppExtension(
'scatter_cpu', ['cpu/scatter.cpp'],
extra_compile_args=['-Wno-unused-variable'])
]
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
if torch.cuda.is_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment