namespace detail { // Add a layer of SFINAE to support static_assert template class PtrTraits, int NewDim, bool B> struct UpcastTHCRoot { static THCDeviceTensor make(THCState* state, THCudaTensor* t); }; template class PtrTraits, int NewDim, bool B> struct UpcastTHC : UpcastTHCRoot { }; // Never instantiated SFINAE purposes only template class PtrTraits, int NewDim> struct UpcastTHC : UpcastTHCRoot { }; template class PtrTraits, int NewDim> struct UpcastTHC : UpcastTHCRoot { static THCDeviceTensor make(THCState* state, THCudaTensor* t) { thc_static_assert(NewDim > Dim); return toDeviceTensor(state, t). template upcastOuter(); } }; // Add a layer of SFINAE to support static_assert template class PtrTraits, int NewDim, bool B> struct DowncastTHCRoot { static THCDeviceTensor make(THCState* state, THCudaTensor* t); }; template class PtrTraits, int NewDim, bool B> struct DowncastTHC : DowncastTHCRoot { }; // Never instantiated SFINAE purposes only template class PtrTraits, int NewDim> struct DowncastTHC : DowncastTHCRoot { }; template class PtrTraits, int NewDim> struct DowncastTHC : DowncastTHCRoot { static THCDeviceTensor make(THCState* state, THCudaTensor* t) { thc_static_assert(NewDim < Dim); return toDeviceTensor(state, t). template downcastOuter(); } }; } // namespace detail #define SWITCH_UNROLL_CUDA_CAST_FACTORY(i) \ case i: \ if (NewDim > i) { \ return detail::UpcastTHC i)>:: \ make(state, t); \ } else if (NewDim == i) { \ return toDeviceTensor(state, t); \ } else { \ return detail::DowncastTHC:: \ make(state, t); \ } \ /* break; */ template class PtrTraits> THCDeviceTensor toDeviceTensorCast(THCState* state, THCudaTensor* t) { switch (THCudaTensor_nDimension(state, t)) { SWITCH_UNROLL_CUDA_CAST_FACTORY(1); SWITCH_UNROLL_CUDA_CAST_FACTORY(2); SWITCH_UNROLL_CUDA_CAST_FACTORY(3); SWITCH_UNROLL_CUDA_CAST_FACTORY(4); SWITCH_UNROLL_CUDA_CAST_FACTORY(5); SWITCH_UNROLL_CUDA_CAST_FACTORY(6); SWITCH_UNROLL_CUDA_CAST_FACTORY(7); SWITCH_UNROLL_CUDA_CAST_FACTORY(8); SWITCH_UNROLL_CUDA_CAST_FACTORY(9); SWITCH_UNROLL_CUDA_CAST_FACTORY(10); default: ; } // Not implemented THError("THCDeviceTensor dimension size not supported"); return NULL; /* never enters this piece, appeasing compiler warnings */ } #undef SWITCH_UNROLL_CUDA_CAST_FACTORY