// Copyright 2019-2020 Yan Yan // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "mp_helper.h" #include #include #include #include #ifdef TV_CUDA #include #endif namespace tv { #ifdef TV_CUDA struct TorchGPU : public tv::GPU { virtual cudaStream_t getStream() const override { return at::cuda::getCurrentCUDAStream(); } }; #endif namespace detail { template struct TypeToTorchDtypeTraits; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kInt32; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kInt16; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt8) value = torch::kInt8; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kInt64; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kUInt8; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kBool; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kFloat32; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kFloat64; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kHalf; }; using all_torch_types_t = std::tuple; } // namespace detail template constexpr decltype(torch::kInt32) torch_type_v = detail::TypeToTorchDtypeTraits::value; template void dispatch_torch(at::ScalarType t, F &&f) { static_assert(sizeof...(Ts) > 0, "you need to provide at least one type"); bool notFound = true; tv::mp_for_each>([=, ¬Found, &f](auto I) { if (detail::TypeToTorchDtypeTraits::value == t) { std::forward(f)(decltype(I)()); notFound = false; } }); if (notFound) { std::stringstream ss; tv::mp_for_each>([=, &ss](auto I) { ss << tv::detail::TypeToString::value << " "; }); TV_THROW_RT_ERR("unknown type", t, ", available:", ss.str()); } } template struct DispatchTorch; template