Commit 7f91c408 authored by yan.yan
Browse files

fix cuda 11 build

parent 42d92ee8
...@@ -24,6 +24,23 @@ ...@@ -24,6 +24,23 @@
#endif #endif
#include <boost/stacktrace.hpp> #include <boost/stacktrace.hpp>
#endif #endif
// Pull in cuda.h for CUDA builds so that CUDA_VERSION is visible to the
// version check below.
#ifdef TV_CUDA
#include <cuda.h>
#endif
// TV_DECLTYPE(x): drop-in replacement for decltype(x), used wherever the type
// of a generic-lambda parameter is needed.
#if defined(TV_USE_BOOST_TYPEOF) || (!defined(__clang__) && defined(CUDA_VERSION) && CUDA_VERSION >= 11000)
// Workaround for builds with CUDA 11 or newer (non-clang compilers).
// Two options exist: BOOST_TYPEOF or identity_t.
// This is an nvcc bug; msvc/gcc/clang don't have this problem.
// The BOOST_TYPEOF variant is kept below for reference but is disabled, so
// defining TV_USE_BOOST_TYPEOF currently selects the identity_t variant too.
// #include <boost/typeof/typeof.hpp>
// #define TV_DECLTYPE(x) BOOST_TYPEOF(x)
namespace tv{
// Transparent alias: identity_t<T> is exactly T.
template <typename T>
using identity_t = T;
}
// Same type as decltype(x), but routed through the identity_t alias.
#define TV_DECLTYPE(x) tv::identity_t<decltype(x)>
#else
// No workaround required: plain decltype.
#define TV_DECLTYPE(x) decltype(x)
#endif
namespace tv { namespace tv {
......
...@@ -318,8 +318,8 @@ template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) { ...@@ -318,8 +318,8 @@ template <class... Ts, typename F> bool dispatch_noexcept(DType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type"); static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true; bool notFound = true;
mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) { mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (type_v<decltype(I)> == t && notFound) { if (type_v<TV_DECLTYPE(I)> == t && notFound) {
std::forward<F>(f)(decltype(I)()); std::forward<F>(f)(TV_DECLTYPE(I)());
notFound = false; notFound = false;
} }
}); });
...@@ -330,7 +330,7 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) { ...@@ -330,7 +330,7 @@ template <class... Ts, typename F> void dispatch(DType t, F &&f) {
if (!dispatch_noexcept<Ts...>(t, std::forward<F>(f))) { if (!dispatch_noexcept<Ts...>(t, std::forward<F>(f))) {
std::stringstream ss; std::stringstream ss;
mp_for_each<mp_list<Ts...>>([=, &ss](auto I) { mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << detail::TypeToString<decltype(I)>::value << " "; ss << detail::TypeToString<TV_DECLTYPE(I)>::value << " ";
}); });
TV_THROW_RT_ERR("unknown type", detail::typeString(t), TV_THROW_RT_ERR("unknown type", detail::typeString(t),
", available:", ss.str()); ", available:", ss.str());
...@@ -359,7 +359,7 @@ template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) { ...@@ -359,7 +359,7 @@ template <int... Is, typename F> bool dispatch_int_noexcept(int idx, F &&f) {
"you need to provide at least one candidate"); "you need to provide at least one candidate");
bool notFound = true; bool notFound = true;
mp_for_each<mp_list_c<int, Is...>>([=, &notFound, &f](auto I) { mp_for_each<mp_list_c<int, Is...>>([=, &notFound, &f](auto I) {
if (decltype(I)::value == idx && notFound) { if (TV_DECLTYPE(I)::value == idx && notFound) {
std::forward<F>(f)(I); std::forward<F>(f)(I);
notFound = false; notFound = false;
} }
...@@ -373,7 +373,7 @@ bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) { ...@@ -373,7 +373,7 @@ bool dispatch_int_noexcept(int idx, BinaryPredicate p, F &&f) {
"you need to provide at least one candidate"); "you need to provide at least one candidate");
bool notFound = true; bool notFound = true;
mp_for_each<mp_list_c<int, Is...>>([=, &notFound, &f](auto I) { mp_for_each<mp_list_c<int, Is...>>([=, &notFound, &f](auto I) {
if (p(idx, decltype(I)::value) && notFound) { if (p(idx, TV_DECLTYPE(I)::value) && notFound) {
std::forward<F>(f)(I); std::forward<F>(f)(I);
notFound = false; notFound = false;
} }
...@@ -385,7 +385,7 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) { ...@@ -385,7 +385,7 @@ template <int... Is, typename F> void dispatch_int(int idx, F &&f) {
if (!dispatch_int_noexcept<Is...>(idx, std::forward<F>(f))) { if (!dispatch_int_noexcept<Is...>(idx, std::forward<F>(f))) {
std::stringstream ss; std::stringstream ss;
mp_for_each<mp_list_c<int, Is...>>( mp_for_each<mp_list_c<int, Is...>>(
[=, &ss](auto I) { ss << decltype(I)::value << " "; }); [=, &ss](auto I) { ss << TV_DECLTYPE(I)::value << " "; });
TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str()); TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str());
} }
} }
...@@ -396,7 +396,7 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) { ...@@ -396,7 +396,7 @@ void dispatch_int(int idx, BinaryPredicate p, F &&f) {
if (!dispatch_int_noexcept<Is...>(idx, p, std::forward<F>(f))) { if (!dispatch_int_noexcept<Is...>(idx, p, std::forward<F>(f))) {
std::stringstream ss; std::stringstream ss;
mp_for_each<mp_list_c<int, Is...>>( mp_for_each<mp_list_c<int, Is...>>(
[=, &ss](auto I) { ss << decltype(I)::value << " "; }); [=, &ss](auto I) { ss << TV_DECLTYPE(I)::value << " "; });
TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str()); TV_THROW_RT_ERR("unknown value", idx, ", available:", ss.str());
} }
} }
...@@ -408,7 +408,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) { ...@@ -408,7 +408,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
"you need to provide at least one candidate"); "you need to provide at least one candidate");
bool notFound = true; bool notFound = true;
mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) { mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
using val_lst_t = decltype(I); using val_lst_t = TV_DECLTYPE(I);
auto val_lst_size = mp_size<val_lst_t>::value; auto val_lst_size = mp_size<val_lst_t>::value;
bool equal = true; bool equal = true;
std::size_t count = 0; std::size_t count = 0;
...@@ -420,7 +420,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) { ...@@ -420,7 +420,7 @@ bool dispatch_container_noexcept(Iterator begin, Iterator end, F &&f) {
if (count >= val_lst_size) { if (count >= val_lst_size) {
TV_THROW_INVALID_ARG("iterator length invalid:", val_lst_size); TV_THROW_INVALID_ARG("iterator length invalid:", val_lst_size);
} }
constexpr auto c = decltype(E)::value; constexpr auto c = TV_DECLTYPE(E)::value;
if (c != *iter) { if (c != *iter) {
equal = false; equal = false;
} }
...@@ -450,8 +450,8 @@ void dispatch_container(Iterator begin, Iterator end, F &&f) { ...@@ -450,8 +450,8 @@ void dispatch_container(Iterator begin, Iterator end, F &&f) {
ss << "], available: "; ss << "], available: ";
mp_for_each<mp_list<Ts...>>([=, &ss](auto I) { mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << "["; ss << "[";
mp_for_each<decltype(I)>( mp_for_each<TV_DECLTYPE(I)>(
[=, &ss](auto E) { ss << decltype(E)::value << ","; }); [=, &ss](auto E) { ss << TV_DECLTYPE(E)::value << ","; });
ss << "]"; ss << "]";
}); });
TV_THROW_RT_ERR(ss.str()); TV_THROW_RT_ERR(ss.str());
...@@ -791,7 +791,7 @@ struct Tensor { ...@@ -791,7 +791,7 @@ struct Tensor {
writable_check(); writable_check();
TV_ASSERT_RT_ERR(device() == -1, "error"); TV_ASSERT_RT_ERR(device() == -1, "error");
Dispatch<detail::all_tensor_types_t>()(dtype_, [&](auto I) { Dispatch<detail::all_tensor_types_t>()(dtype_, [&](auto I) {
using Treal = decltype(I); using Treal = TV_DECLTYPE(I);
if (std::is_convertible<T, Treal>::value) { if (std::is_convertible<T, Treal>::value) {
auto ptr = reinterpret_cast<Treal *>(raw_data()); auto ptr = reinterpret_cast<Treal *>(raw_data());
std::fill(ptr, ptr + size(), Treal(value)); std::fill(ptr, ptr + size(), Treal(value));
...@@ -940,9 +940,9 @@ struct Tensor { ...@@ -940,9 +940,9 @@ struct Tensor {
TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now"); TV_ASSERT_INVALID_ARG(contiguous_, "only support contiguous for now");
auto tensor = Tensor(); auto tensor = Tensor();
Dispatch<detail::all_tensor_types_t>()(dtype, [&](auto Idst) { Dispatch<detail::all_tensor_types_t>()(dtype, [&](auto Idst) {
using Tdst = decltype(Idst); using Tdst = TV_DECLTYPE(Idst);
Dispatch<detail::all_tensor_types_t>()(this->dtype_, [&](auto Icur) { Dispatch<detail::all_tensor_types_t>()(this->dtype_, [&](auto Icur) {
using Tcur = decltype(Icur); using Tcur = TV_DECLTYPE(Icur);
if (std::is_convertible<Tcur, Tdst>::value) { if (std::is_convertible<Tcur, Tdst>::value) {
auto ptr = this->data<Tcur>(); auto ptr = this->data<Tcur>();
tensor = Tensor(this->shape_, this->stride_, dtype, this->device(), tensor = Tensor(this->shape_, this->stride_, dtype, this->device(),
...@@ -981,7 +981,7 @@ private: ...@@ -981,7 +981,7 @@ private:
template <typename Os> Os &operator<<(Os &os, const Tensor &tensor) { template <typename Os> Os &operator<<(Os &os, const Tensor &tensor) {
TV_ASSERT_INVALID_ARG(tensor.device() == -1, "must be cpu tensor"); TV_ASSERT_INVALID_ARG(tensor.device() == -1, "must be cpu tensor");
Dispatch<detail::all_tensor_types_t>()(tensor.dtype(), [&](auto I) { Dispatch<detail::all_tensor_types_t>()(tensor.dtype(), [&](auto I) {
using T = decltype(I); using T = TV_DECLTYPE(I);
std::stringstream ss; std::stringstream ss;
if (std::is_same<T, float>::value || std::is_same<T, double>::value) { if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
ss << std::setprecision(4); ss << std::setprecision(4);
......
...@@ -76,15 +76,15 @@ void dispatch_torch(at::ScalarType t, F &&f) { ...@@ -76,15 +76,15 @@ void dispatch_torch(at::ScalarType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type"); static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true; bool notFound = true;
tv::mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) { tv::mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (detail::TypeToTorchDtypeTraits<decltype(I)>::value == t) { if (detail::TypeToTorchDtypeTraits<TV_DECLTYPE(I)>::value == t) {
std::forward<F>(f)(decltype(I)()); std::forward<F>(f)(TV_DECLTYPE(I)());
notFound = false; notFound = false;
} }
}); });
if (notFound) { if (notFound) {
std::stringstream ss; std::stringstream ss;
tv::mp_for_each<mp_list<Ts...>>([=, &ss](auto I) { tv::mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << tv::detail::TypeToString<decltype(I)>::value << " "; ss << tv::detail::TypeToString<TV_DECLTYPE(I)>::value << " ";
}); });
TV_THROW_RT_ERR("unknown type", t, ", available:", ss.str()); TV_THROW_RT_ERR("unknown type", t, ", available:", ss.str());
} }
...@@ -101,7 +101,7 @@ struct DispatchTorch<T<Args...>> { ...@@ -101,7 +101,7 @@ struct DispatchTorch<T<Args...>> {
template <typename T> void check_torch_dtype(const torch::Tensor &tensor) { template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
DispatchTorch<detail::all_torch_types_t>()(tensor.scalar_type(), [&](auto I) { DispatchTorch<detail::all_torch_types_t>()(tensor.scalar_type(), [&](auto I) {
using Ttensor = decltype(I); using Ttensor = TV_DECLTYPE(I);
constexpr bool val = std::is_same<std::remove_cv_t<T>, Ttensor>::value; constexpr bool val = std::is_same<std::remove_cv_t<T>, Ttensor>::value;
TV_ASSERT_RT_ERR(val, "error"); TV_ASSERT_RT_ERR(val, "error");
}); });
......
...@@ -19,10 +19,15 @@ SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA") ...@@ -19,10 +19,15 @@ SPCONV_FORCE_BUILD_CUDA = os.getenv("SPCONV_FORCE_BUILD_CUDA")
PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor) PYTHON_VERSION = "{}.{}".format(sys.version_info.major, sys.version_info.minor)
remove_plus = torch.__version__.find("+") remove_plus = torch.__version__.find("+dev")
remove_dot = torch.__version__.find(".dev")
PYTORCH_VERSION = torch.__version__ PYTORCH_VERSION = torch.__version__
if remove_plus != -1: if remove_plus != -1:
PYTORCH_VERSION = torch.__version__[:remove_plus] PYTORCH_VERSION = torch.__version__[:remove_plus]
if remove_dot != -1:
PYTORCH_VERSION = torch.__version__[:remove_dot]
PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split("."))) PYTORCH_VERSION = list(map(int, PYTORCH_VERSION.split(".")))
PYTORCH_VERSION_NUMBER = PYTORCH_VERSION[0] * 10000 + PYTORCH_VERSION[1] * 100 + PYTORCH_VERSION[2] PYTORCH_VERSION_NUMBER = PYTORCH_VERSION[0] * 10000 + PYTORCH_VERSION[1] * 100 + PYTORCH_VERSION[2]
......
...@@ -268,10 +268,10 @@ int create_conv_indice_pair_cpu( ...@@ -268,10 +268,10 @@ int create_conv_indice_pair_cpu(
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) { tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) {
using Index = decltype(V); using Index = TV_DECLTYPE(V);
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) { tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
constexpr int NDim = decltype(I)::value; constexpr int NDim = TV_DECLTYPE(I)::value;
tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end()); tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end());
tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end()); tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end());
tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end()); tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end());
...@@ -308,10 +308,10 @@ int create_submconv_indice_pair_cpu( ...@@ -308,10 +308,10 @@ int create_submconv_indice_pair_cpu(
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) { tv::dispatch_torch<int32_t, int64_t>(indicesIn.scalar_type(), [&](auto V) {
using Index = decltype(V); using Index = TV_DECLTYPE(V);
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) { tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
constexpr int NDim = decltype(I)::value; constexpr int NDim = TV_DECLTYPE(I)::value;
tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end()); tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end());
tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end()); tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end());
tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end()); tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end());
......
...@@ -45,10 +45,10 @@ int create_conv_indice_pair_p1_cuda( ...@@ -45,10 +45,10 @@ int create_conv_indice_pair_p1_cuda(
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) { tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
constexpr int NDim = decltype(I)::value; constexpr int NDim = TV_DECLTYPE(I)::value;
tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end()); tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end());
tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end()); tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end());
tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end()); tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end());
...@@ -57,7 +57,7 @@ int create_conv_indice_pair_p1_cuda( ...@@ -57,7 +57,7 @@ int create_conv_indice_pair_p1_cuda(
outSpatialShape.end()); outSpatialShape.end());
tv::DispatchInt<max_kernel_vol_t>()( tv::DispatchInt<max_kernel_vol_t>()(
kernelVolume, std::less_equal<int>(), [&](auto I2) { kernelVolume, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = TV_DECLTYPE(I2)::value;
if (transpose) { if (transpose) {
prepareDeConvIndicePairsKernel<Index, NDim, MaxKernelVolume> prepareDeConvIndicePairsKernel<Index, NDim, MaxKernelVolume>
<<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS, <<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS,
...@@ -106,10 +106,10 @@ int create_conv_indice_pair_p2_cuda( ...@@ -106,10 +106,10 @@ int create_conv_indice_pair_p2_cuda(
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) { tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
constexpr int NDim = decltype(I)::value; constexpr int NDim = TV_DECLTYPE(I)::value;
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::SimpleVector<Index, NDim> ou(outSpatialShape.begin(), tv::SimpleVector<Index, NDim> ou(outSpatialShape.begin(),
outSpatialShape.end()); outSpatialShape.end());
...@@ -212,10 +212,10 @@ int create_submconv_indice_pair_cuda( ...@@ -212,10 +212,10 @@ int create_submconv_indice_pair_cuda(
if (numActIn == 0) if (numActIn == 0)
return 0; return 0;
tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) { tv::dispatch_torch<int32_t>(indicesIn.scalar_type(), [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
using IndexGrid = int32_t; using IndexGrid = int32_t;
tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) { tv::dispatch_int<2, 3, 4>(ndim, [&](auto I) {
constexpr int NDim = decltype(I)::value; constexpr int NDim = TV_DECLTYPE(I)::value;
tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end()); tv::SimpleVector<Index, NDim> ks(kernelSize.begin(), kernelSize.end());
tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end()); tv::SimpleVector<Index, NDim> st(stride.begin(), stride.end());
tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end()); tv::SimpleVector<Index, NDim> pa(padding.begin(), padding.end());
...@@ -254,7 +254,7 @@ int create_submconv_indice_pair_cuda( ...@@ -254,7 +254,7 @@ int create_submconv_indice_pair_cuda(
auto stash_count = table.get_stash_count(); auto stash_count = table.get_stash_count();
tv::DispatchInt<max_kernel_vol_t>()( tv::DispatchInt<max_kernel_vol_t>()(
kernelVolume, std::less_equal<int>(), [&](auto I2) { kernelVolume, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = TV_DECLTYPE(I2)::value;
getSubMIndicePairsHashKernel<Index, NDim, MaxKernelVolume> getSubMIndicePairsHashKernel<Index, NDim, MaxKernelVolume>
<<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS, <<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS,
0, stream>>>(tv::torch2tv<Index>(indicesIn), 0, stream>>>(tv::torch2tv<Index>(indicesIn),
...@@ -286,8 +286,8 @@ int create_submconv_indice_pair_cuda( ...@@ -286,8 +286,8 @@ int create_submconv_indice_pair_cuda(
tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[0], [&](auto K0C) { tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[0], [&](auto K0C) {
tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[1], [&](auto K1C) { tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[1], [&](auto K1C) {
constexpr int K0 = decltype(K0C)::value; constexpr int K0 = TV_DECLTYPE(K0C)::value;
constexpr int K1 = decltype(K1C)::value; constexpr int K1 = TV_DECLTYPE(K1C)::value;
found = true; found = true;
getSubMIndicePairsKernel2<Index, IndexGrid, K0, K1> getSubMIndicePairsKernel2<Index, IndexGrid, K0, K1>
<<<tv::cuda::getBlocks(numActIn), <<<tv::cuda::getBlocks(numActIn),
...@@ -306,9 +306,9 @@ int create_submconv_indice_pair_cuda( ...@@ -306,9 +306,9 @@ int create_submconv_indice_pair_cuda(
tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[1], [&](auto K1C) { tv::dispatch_int_noexcept<1, 3, 5>(kernelSize[1], [&](auto K1C) {
tv::dispatch_int_noexcept<1, 3, 5>( tv::dispatch_int_noexcept<1, 3, 5>(
kernelSize[2], [&](auto K2C) { kernelSize[2], [&](auto K2C) {
constexpr int K0 = decltype(K0C)::value; constexpr int K0 = TV_DECLTYPE(K0C)::value;
constexpr int K1 = decltype(K1C)::value; constexpr int K1 = TV_DECLTYPE(K1C)::value;
constexpr int K2 = decltype(K2C)::value; constexpr int K2 = TV_DECLTYPE(K2C)::value;
found = true; found = true;
getSubMIndicePairsKernel3<Index, IndexGrid, K0, K1, K2> getSubMIndicePairsKernel3<Index, IndexGrid, K0, K1, K2>
<<<tv::cuda::getBlocks(numActIn), <<<tv::cuda::getBlocks(numActIn),
...@@ -326,7 +326,7 @@ int create_submconv_indice_pair_cuda( ...@@ -326,7 +326,7 @@ int create_submconv_indice_pair_cuda(
if (!found) { if (!found) {
tv::DispatchInt< tv::DispatchInt<
max_kernel_vol_t>()(ndim, std::less_equal<int>(), [&](auto I2) { max_kernel_vol_t>()(ndim, std::less_equal<int>(), [&](auto I2) {
constexpr int MaxKernelVolume = decltype(I2)::value; constexpr int MaxKernelVolume = TV_DECLTYPE(I2)::value;
getSubMIndicePairsKernel<Index, IndexGrid, NDim, MaxKernelVolume> getSubMIndicePairsKernel<Index, IndexGrid, NDim, MaxKernelVolume>
<<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS, 0, <<<tv::cuda::getBlocks(numActIn), tv::cuda::CUDA_NUM_THREADS, 0,
stream>>>(tv::torch2tv<Index>(indicesIn), stream>>>(tv::torch2tv<Index>(indicesIn),
......
...@@ -29,9 +29,9 @@ void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures, ...@@ -29,9 +29,9 @@ void maxpool_fwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto dtype = inFeatures.scalar_type(); auto dtype = inFeatures.scalar_type();
auto int_dtype = indicesIn.scalar_type(); auto int_dtype = indicesIn.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
auto outFeaturesData = outFeatures.data_ptr<T>(); auto outFeaturesData = outFeatures.data_ptr<T>();
auto inFeaturesData = inFeatures.data_ptr<T>(); auto inFeaturesData = inFeatures.data_ptr<T>();
auto indicesInData = indicesIn.data_ptr<Index>(); auto indicesInData = indicesIn.data_ptr<Index>();
...@@ -58,9 +58,9 @@ void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures, ...@@ -58,9 +58,9 @@ void maxpool_bwd_cpu(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto dtype = inFeatures.scalar_type(); auto dtype = inFeatures.scalar_type();
auto int_dtype = indicesIn.scalar_type(); auto int_dtype = indicesIn.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
auto outFeaturesData = outFeatures.data_ptr<T>(); auto outFeaturesData = outFeatures.data_ptr<T>();
auto inFeaturesData = inFeatures.data_ptr<T>(); auto inFeaturesData = inFeatures.data_ptr<T>();
auto doutData = dout.data_ptr<T>(); auto doutData = dout.data_ptr<T>();
......
...@@ -320,13 +320,13 @@ void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures, ...@@ -320,13 +320,13 @@ void maxpool_fwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indicesIn, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &indicesIn,
...@@ -404,12 +404,12 @@ void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures, ...@@ -404,12 +404,12 @@ void maxpool_bwd_cuda(torch::Tensor outFeatures, torch::Tensor inFeatures,
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = using vecload_type_t =
std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>; std::conditional_t<std::is_same<T, at::Half>::value, int2, int4>;
using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>; using kernel_block_t = tv::mp_list_c<int, 64, 32, 16>;
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &dout, tv::mp_for_each<kernel_block_t>([=, &outFeatures, &inFeatures, &dout,
......
...@@ -26,9 +26,9 @@ void sparse_gather_cpu(torch::Tensor buffer, torch::Tensor features, ...@@ -26,9 +26,9 @@ void sparse_gather_cpu(torch::Tensor buffer, torch::Tensor features,
auto dtype = features.scalar_type(); auto dtype = features.scalar_type();
auto int_dtype = indices.scalar_type(); auto int_dtype = indices.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
Index *indices_data = indices.data_ptr<Index>(); Index *indices_data = indices.data_ptr<Index>();
T *buffer_data = buffer.data_ptr<T>(); T *buffer_data = buffer.data_ptr<T>();
const T *features_data = features.data_ptr<T>(); const T *features_data = features.data_ptr<T>();
...@@ -50,9 +50,9 @@ void sparse_scatter_add_cpu(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -50,9 +50,9 @@ void sparse_scatter_add_cpu(torch::Tensor buffer, torch::Tensor outFeatures,
auto int_dtype = indices.scalar_type(); auto int_dtype = indices.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(int_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
Index *indices_data = indices.data_ptr<Index>(); Index *indices_data = indices.data_ptr<Index>();
const T *buffer_data = buffer.data_ptr<T>(); const T *buffer_data = buffer.data_ptr<T>();
T *features_data = outFeatures.data_ptr<T>(); T *features_data = outFeatures.data_ptr<T>();
......
...@@ -51,10 +51,10 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -51,10 +51,10 @@ void sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
auto dtype = features.scalar_type(); auto dtype = features.scalar_type();
auto inds_dtype = indices.scalar_type(); auto inds_dtype = indices.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = typename half_vec_sadd<T>::type; using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
...@@ -140,10 +140,10 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures, ...@@ -140,10 +140,10 @@ void sparse_scatter_add_cuda(torch::Tensor buffer, torch::Tensor outFeatures,
auto inds_dtype = indices.scalar_type(); auto inds_dtype = indices.scalar_type();
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = typename half_vec_sadd<T>::type; using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = constexpr int vecloadFactor =
sizeof(vecload_type_t) / sizeof(T); // important for half. sizeof(vecload_type_t) / sizeof(T); // important for half.
...@@ -235,10 +235,10 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features, ...@@ -235,10 +235,10 @@ void batch_sparse_gather_cuda(torch::Tensor buffer, torch::Tensor features,
int inds_stride = indices.size(1); int inds_stride = indices.size(1);
int feature_stride = buffer.size(1); int feature_stride = buffer.size(1);
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = typename half_vec<T>::type; using vecload_type_t = typename half_vec<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T); constexpr int vecloadFactor = sizeof(vecload_type_t) / sizeof(T);
tv::mp_for_each<kernel_block_t>( tv::mp_for_each<kernel_block_t>(
...@@ -308,10 +308,10 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer, ...@@ -308,10 +308,10 @@ void batch_sparse_scatter_add_cuda(torch::Tensor buffer,
int feature_stride = buffer.size(1); int feature_stride = buffer.size(1);
tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) { tv::DispatchTorch<float_types_t>()(dtype, [&](auto TValue) {
using T = decltype(TValue); using T = TV_DECLTYPE(TValue);
using vecload_type_t = typename half_vec_sadd<T>::type; using vecload_type_t = typename half_vec_sadd<T>::type;
tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) { tv::DispatchTorch<int_types_t>()(inds_dtype, [&](auto IndexValue) {
using Index = decltype(IndexValue); using Index = TV_DECLTYPE(IndexValue);
bool notFound = true; bool notFound = true;
constexpr int vecloadFactor = 1; // important for half. constexpr int vecloadFactor = 1; // important for half.
......
cutlass @ c2b80ad4
Subproject commit c2b80ad4e4f8b60a65500bd04c8fecddff2ba355
mp11 @ 29764aad
Subproject commit 29764aad4881fde809af6a025c12012e47a55515
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment