// Copyright 2019 Yan Yan // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #ifdef TV_CUDA #include #endif namespace tv { #ifdef TV_CUDA struct TorchGPU : public tv::GPU { virtual cudaStream_t getStream() const override { return at::cuda::getCurrentCUDAStream(); } }; #endif template void check_torch_dtype(const torch::Tensor &tensor) { switch (tensor.scalar_type()) { case at::ScalarType::Double: { auto val = std::is_same, double>::value; TV_ASSERT_RT_ERR(val, "error"); break; } case at::ScalarType::Float: { auto val = std::is_same, float>::value; TV_ASSERT_RT_ERR(val, "error"); break; } case at::ScalarType::Int: { auto val = std::is_same, int>::value; TV_ASSERT_RT_ERR(val, "error"); break; } case at::ScalarType::Half: { auto val = std::is_same, at::Half>::value; TV_ASSERT_RT_ERR(val, "error"); break; } case at::ScalarType::Long: { auto val = std::is_same, long>::value; TV_ASSERT_RT_ERR(val, "error"); break; } default: TV_ASSERT_RT_ERR(false, "error"); } } namespace detail { template struct TypeToTorchDtypeTraits; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kInt32; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kInt64; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kFloat32; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kFloat64; }; template <> struct TypeToTorchDtypeTraits { static constexpr decltype(torch::kInt32) value = torch::kHalf; }; } // namespace detail template constexpr decltype(torch::kInt32) torch_type_v = detail::TypeToTorchDtypeTraits::value; template tv::TensorView torch2tv(const torch::Tensor &tensor) { check_torch_dtype(tensor); tv::Shape shape; for (auto i : tensor.sizes()) { shape.push_back(i); } return tv::TensorView(tensor.data_ptr>(), shape); } namespace detail { template <> struct TypeToString { static constexpr const char *value = "half"; }; } // namespace detail template void dispatch_torch(at::ScalarType t, F &&f) { static_assert(sizeof...(Ts) > 0, "you need to provide at least one type"); bool notFound = true; spconv::tv::mp_for_each>([=, ¬Found, &f](auto I) { if (torch_type_v == t) { std::forward(f)(decltype(I)()); notFound = false; } }); if (notFound) { std::stringstream ss; spconv::tv::mp_for_each>([=, &ss](auto I) { ss << tv::detail::TypeToString::value << " "; }); TV_THROW_RT_ERR("unknown type", t, ", available: ", ss.str()); } } } // namespace tv