Commit 19e73bbe authored by Yan Yan's avatar Yan Yan
Browse files

format code with clang-format, better c++ code

parent c336139f
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -16,14 +16,11 @@
#define POINTPILLARS_SCATTER_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace spconv
{
namespace functor
{
namespace spconv {
namespace functor {
template <typename Device, typename T, typename Index>
struct PointPillarScatter
{
void operator()(const Device& d, tv::TensorView<T> canvas,
struct PointPillarScatter {
void operator()(const Device &d, tv::TensorView<T> canvas,
tv::TensorView<const T> features,
tv::TensorView<const T> coors);
};
......
......@@ -16,8 +16,8 @@
#define PILLAR_SCATTER_OP_H_
#include <spconv/pillar_scatter_functor.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h>
namespace spconv {
......@@ -42,9 +42,10 @@ torch::Tensor pointPillarScatter(torch::Tensor features, torch::Tensor coors,
torch::zeros({shapeData[0], shapeData[1], shapeData[2], shapeData[3]},
features.options());
TV_ASSERT_RT_ERR(shapeData[1] == features.size(1), "error");
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
functor::PointPillarScatter<tv::GPU, T, int> ftor;
ftor(tv::TorchGPU(), tv::torch2tv<T>(canvas), tv::torch2tv<const T>(features.squeeze()),
ftor(tv::TorchGPU(), tv::torch2tv<T>(canvas),
tv::torch2tv<const T>(features.squeeze()),
tv::torch2tv<const T>(coors.squeeze()));
#endif
return canvas;
......
......@@ -29,7 +29,8 @@ using namespace pybind11::literals;
template <typename DType, int NDim>
int points_to_voxel_3d_np(py::array_t<DType> points, py::array_t<DType> voxels,
py::array_t<DType> voxel_point_mask, py::array_t<int> coors,
py::array_t<DType> voxel_point_mask,
py::array_t<int> coors,
py::array_t<int> num_points_per_voxel,
py::array_t<int> coor_to_voxelidx,
std::vector<DType> voxel_size,
......@@ -94,14 +95,12 @@ int points_to_voxel_3d_np(py::array_t<DType> points, py::array_t<DType> voxels,
}
template <typename DType, int NDim>
int points_to_voxel_3d_np_mean(py::array_t<DType> points,
py::array_t<DType> voxel_point_mask, py::array_t<DType> voxels,
py::array_t<DType> means, py::array_t<int> coors,
py::array_t<int> num_points_per_voxel,
py::array_t<int> coor_to_voxelidx,
std::vector<DType> voxel_size,
std::vector<DType> coors_range, int max_points,
int max_voxels) {
int points_to_voxel_3d_np_mean(
py::array_t<DType> points, py::array_t<DType> voxel_point_mask,
py::array_t<DType> voxels, py::array_t<DType> means, py::array_t<int> coors,
py::array_t<int> num_points_per_voxel, py::array_t<int> coor_to_voxelidx,
std::vector<DType> voxel_size, std::vector<DType> coors_range,
int max_points, int max_voxels) {
auto points_rw = points.template mutable_unchecked<2>();
auto means_rw = means.template mutable_unchecked<2>();
auto voxels_rw = voxels.template mutable_unchecked<3>();
......@@ -174,8 +173,8 @@ int points_to_voxel_3d_np_mean(py::array_t<DType> points,
template <typename DType, int NDim>
int points_to_voxel_3d_with_filtering(
py::array_t<DType> points, py::array_t<DType> voxels,
py::array_t<DType> voxel_point_mask, py::array_t<int> voxel_mask, py::array_t<DType> mins,
py::array_t<DType> maxs, py::array_t<int> coors,
py::array_t<DType> voxel_point_mask, py::array_t<int> voxel_mask,
py::array_t<DType> mins, py::array_t<DType> maxs, py::array_t<int> coors,
py::array_t<int> num_points_per_voxel, py::array_t<int> coor_to_voxelidx,
std::vector<DType> voxel_size, std::vector<DType> coors_range,
int max_points, int max_voxels, int block_factor, int block_size,
......
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -16,14 +16,14 @@
#define SPARSE_POOL_OP_H_
#include <spconv/maxpool.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h>
namespace spconv {
template <typename T>
torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
torch::Tensor indiceNum, int64_t numAct) {
torch::Tensor indiceNum, int64_t numAct) {
auto device = features.device().type();
auto kernelVolume = indicePairs.size(0);
auto numInPlanes = features.size(1);
......@@ -43,8 +43,8 @@ torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
forwardFtor(tv::CPU(), tv::torch2tv<T>(output),
tv::torch2tv<const T>(features),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
}
#ifdef SPCONV_CUDA
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
functor::SparseMaxPoolForwardFunctor<tv::GPU, T, int> forwardFtor;
forwardFtor(tv::TorchGPU(), tv::torch2tv<T>(output),
......@@ -53,7 +53,7 @@ torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
TV_CHECK_CUDA_ERR();
}
#endif
else{
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
// totalTime += timer.report() / 1000.0;
......@@ -63,17 +63,17 @@ torch::Tensor indiceMaxPool(torch::Tensor features, torch::Tensor indicePairs,
}
template <typename T>
torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
torch::Tensor outFeatures,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum) {
torch::Tensor
indiceMaxPoolBackward(torch::Tensor features, torch::Tensor outFeatures,
torch::Tensor outGrad, torch::Tensor indicePairs,
torch::Tensor indiceNum) {
auto device = features.device().type();
auto numInPlanes = features.size(1);
auto indicePairNumCpu = indiceNum.to({torch::kCPU});
auto options =
torch::TensorOptions().dtype(features.dtype()).device(features.device());
torch::Tensor inputGrad = torch::zeros(features.sizes(), options);
auto kernelVolume = indicePairs.size(0);
auto kernelVolume = indicePairs.size(0);
for (int i = 0; i < kernelVolume; ++i) {
auto nHot = indicePairNumCpu.data_ptr<int>()[i];
if (nHot <= 0) {
......@@ -85,8 +85,8 @@ torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
tv::torch2tv<const T>(features),
tv::torch2tv<const T>(outGrad), tv::torch2tv<T>(inputGrad),
tv::torch2tv<const int>(indicePairs).subview(i), nHot);
}
#ifdef SPCONV_CUDA
}
#ifdef TV_CUDA
else if (device == torch::kCUDA) {
functor::SparseMaxPoolBackwardFunctor<tv::GPU, T, int> backwardFtor;
backwardFtor(tv::TorchGPU(), tv::torch2tv<const T>(outFeatures),
......@@ -96,10 +96,9 @@ torch::Tensor indiceMaxPoolBackward(torch::Tensor features,
TV_CHECK_CUDA_ERR();
}
#endif
else{
else {
TV_ASSERT_INVALID_ARG(false, "unknown device type");
}
}
return inputGrad;
}
......
......@@ -14,7 +14,7 @@
#ifndef REORDERING_CU_H_
#define REORDERING_CU_H_
#include <tensorview/helper_kernel.cu.h>
#include <tensorview/kernel_utils.h>
// see http://www.nvidia.com/content/GTC-2010/pdfs/2238_GTC2010.pdf.
namespace spconv {
......
// Copyright 2019 Yan Yan
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//
// http://www.apache.org/licenses/LICENSE-2.0
//
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -16,23 +16,21 @@
#define SPARSE_REORDERING_FUNCTOR_H_
#include <tensorview/tensorview.h>
namespace spconv
{
namespace functor
{
namespace spconv {
namespace functor {
template <typename Device, typename T, typename Index>
struct SparseGatherFunctor
{
void operator()(const Device& d, tv::TensorView<T> buffer, tv::TensorView<const T> features,
tv::TensorView<const Index> indices, int size);
struct SparseGatherFunctor {
void operator()(const Device &d, tv::TensorView<T> buffer,
tv::TensorView<const T> features,
tv::TensorView<const Index> indices, int size);
};
template <typename Device, typename T, typename Index>
struct SparseScatterAddFunctor
{
void operator()(const Device& d, tv::TensorView<T> out_features,
tv::TensorView<const T> buffer, tv::TensorView<const Index> indices,
int size, bool stable=false);
struct SparseScatterAddFunctor {
void operator()(const Device &d, tv::TensorView<T> out_features,
tv::TensorView<const T> buffer,
tv::TensorView<const Index> indices, int size,
bool stable = false);
};
} // namespace functor
} // namespace spconv
......
......@@ -17,8 +17,8 @@
#include <spconv/indice.h>
#include <spconv/reordering.h>
#include <tensorview/torch_utils.h>
#include <torch/script.h>
#include <torch_utils.h>
#include <utility/timer.h>
namespace spconv {
......@@ -101,7 +101,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize,
tv::torch2tv<int>(indiceNum), kernelSize32, stride32, padding32,
dilation32, outSpatialShape32, transpose, false, useHash);
}
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
else if (indices.device().type() == torch::kCUDA) {
auto getIndicePairFtor =
functor::CreateSubMIndicePairFunctor<tv::GPU, int, int, NDim>();
......@@ -149,7 +149,7 @@ getIndicePair(torch::Tensor indices, int64_t batchSize,
kernelSize32, stride32, padding32, dilation32, outSpatialShape32,
transpose);
}
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
else if (indices.device().type() == torch::kCUDA) {
auto getIndicePairFtorP1 =
functor::CreateConvIndicePairFunctorP1<tv::GPU, int, int, NDim>();
......@@ -269,7 +269,7 @@ std::vector<torch::Tensor> getIndicePairPreGrid(
dilation32, outSpatialShape32, transpose);
gridOut.fill_(-1);
}
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
else if (indices.device().type() == torch::kCUDA) {
auto getIndicePairFtor =
functor::CreateSubMIndicePairFunctor<tv::GPU, int, int, NDim>();
......@@ -299,7 +299,7 @@ std::vector<torch::Tensor> getIndicePairPreGrid(
transpose, true);
gridOut.fill_(-1);
}
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
else if (indices.device().type() == torch::kCUDA) {
auto getIndicePairFtorP1 =
functor::CreateConvIndicePairFunctorP1<tv::GPU, int, int, NDim>();
......
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <sstream>
#ifdef TV_USE_STACKTRACE
#if defined(WIN32) || defined(_WIN32) || \
defined(__WIN32) && !defined(__CYGWIN__)
#define BOOST_STACKTRACE_USE_WINDBG
#else
// require linking with -ldl and -lbacktrace in linux
#define BOOST_STACKTRACE_USE_BACKTRACE
#endif
#include <boost/stacktrace.hpp>
#endif
namespace tv {
template <class SStream, class T> void sstream_print(SStream &ss, T val) {
ss << val;
}
template <class SStream, class T, class... TArgs>
void sstream_print(SStream &ss, T val, TArgs... args) {
ss << val << " ";
sstream_print(ss, args...);
}
template <class... TArgs> void ssprint(TArgs... args) {
std::stringstream ss;
sstream_print(ss, args...);
std::cout << ss.str() << std::endl;
}
#ifdef TV_USE_STACKTRACE
#define TV_BACKTRACE_PRINT(ss) \
ss << std::endl << boost::stacktrace::stacktrace();
#else
#define TV_BACKTRACE_PRINT(ss)
#endif
#define TV_THROW_RT_ERR(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
}
#define TV_THROW_INVALID_ARG(...) \
{ \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::invalid_argument(__macro_s.str()); \
}
#define TV_ASSERT_RT_ERR(expr, ...) \
{ \
if (!(expr)) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << #expr << " assert faild. "; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::runtime_error(__macro_s.str()); \
} \
}
#define TV_ASSERT_INVALID_ARG(expr, ...) \
{ \
if (!(expr)) { \
std::stringstream __macro_s; \
__macro_s << __FILE__ << " " << __LINE__ << "\n"; \
__macro_s << #expr << " assert faild. "; \
tv::sstream_print(__macro_s, __VA_ARGS__); \
TV_BACKTRACE_PRINT(__macro_s); \
throw std::invalid_argument(__macro_s.str()); \
} \
}
} // namespace tv
\ No newline at end of file
#pragma once
// from pytorch.aten
#include "tensorview.h"
#include <type_traits>
namespace tv {
namespace cuda {
template <typename T1, typename T2> inline int DivUp(const T1 a, const T2 b) {
return (a + b - 1) / b;
}
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int getNumThreads(const int N) {
if (N > CUDA_NUM_THREADS) {
return CUDA_NUM_THREADS;
}
return DivUp(N, 32) * 32;
}
inline int getBlocks(const int N) {
TV_ASSERT_RT_ERR(N > 0,
"CUDA kernel launch blocks must be positive, but got N=", N);
return DivUp(N, getNumThreads(N));
}
} // namespace cuda
} // namespace tv
\ No newline at end of file
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <eigen3/Eigen/Dense>
namespace tv {
template <typename T, int Row = Eigen::Dynamic, int Col = Eigen::Dynamic>
Eigen::Map<Eigen::Matrix<T, Row, Col, Eigen::RowMajor>>
tv2eigen(TensorView<T> view) {
TV_ASSERT_INVALID_ARG(view.ndim() <= 2 && view.ndim() > 0, "error");
if (Row != Eigen::Dynamic) {
TV_ASSERT_INVALID_ARG(view.dim(0) == Row, "error");
}
if (Col != Eigen::Dynamic) {
TV_ASSERT_INVALID_ARG(view.dim(1) == Col, "error");
}
int row = 1;
if (view.ndim() == 2) {
row = view.dim(0);
}
Eigen::Map<Eigen::Matrix<T, Row, Col, Eigen::RowMajor>> eigen_map(
view.data(), row, view.dim(1));
return eigen_map;
}
} // namespace tv
#pragma once
// from pytorch.aten
#include "tensorview.h"
namespace tv
{
namespace launch
{
template <typename T1, typename T2>
inline int DivUp(const T1 a, const T2 b) { return (a + b - 1) / b; }
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int getBlocks(const int N)
{
TV_ASSERT_RT_ERR(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
return DivUp(N, CUDA_NUM_THREADS);
}
} // namespace launch
} // namespace tv
\ No newline at end of file
#pragma once
// from tensorflow
namespace tv
{
namespace detail
{
namespace tv {
namespace detail {
template <typename T>
class KernelLoop
{
struct Iterator
{
__forceinline__ __device__ Iterator(T index, T delta) : index_(index), delta_(delta) {}
template <typename T> class KernelLoop {
struct Iterator {
__forceinline__ __device__ Iterator(T index, T delta)
: index_(index), delta_(delta) {}
__forceinline__ __device__ T operator*() const { return index_; }
__forceinline__ __device__ Iterator &operator++()
{
__forceinline__ __device__ Iterator &operator++() {
index_ += delta_;
return *this;
}
__forceinline__ __device__ bool operator!=(const Iterator &other) const
{
__forceinline__ __device__ bool operator!=(const Iterator &other) const {
bool greater = index_ > other.index_;
bool less = index_ < other.index_;
// Anything past an end iterator (delta_ == 0) is equal.
// In range-based for loops, this optimizes to 'return less'.
if (!other.delta_)
{
if (!other.delta_) {
return less;
}
if (!delta_)
{
if (!delta_) {
return greater;
}
return less || greater;
......@@ -43,7 +35,9 @@ public:
__forceinline__ __device__ KernelLoop(T begin, T delta, T end)
: begin_(begin), delta_(delta), end_(end) {}
__forceinline__ __device__ Iterator begin() const { return Iterator{begin_, delta_}; }
__forceinline__ __device__ Iterator begin() const {
return Iterator{begin_, delta_};
}
__forceinline__ __device__ Iterator end() const { return Iterator{end_, 0}; }
private:
......@@ -53,29 +47,26 @@ private:
};
} // namespace detail
template <typename T, int NumILP=1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopX(T count)
{
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopX(T count) {
return detail::KernelLoop<T>(blockIdx.x * blockDim.x + threadIdx.x,
gridDim.x * blockDim.x * NumILP, count);
gridDim.x * blockDim.x * NumILP, count);
}
// Helper to visit indices in the range 0 <= i < count using the y-coordinate.
// Usage: for(int i : KernelLoopY(count)) { visit(i); }
template <typename T, int NumILP=1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopY(T count)
{
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopY(T count) {
return detail::KernelLoop<T>(blockIdx.y * blockDim.y + threadIdx.y,
gridDim.y * blockDim.y * NumILP, count);
gridDim.y * blockDim.y * NumILP, count);
}
// Helper to visit indices in the range 0 <= i < count using the z-coordinate.
// Usage: for(int i : KernelLoopZ(count)) { visit(i); }
template <typename T, int NumILP=1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopZ(T count)
{
template <typename T, int NumILP = 1>
__forceinline__ __device__ detail::KernelLoop<T> KernelLoopZ(T count) {
return detail::KernelLoop<T>(blockIdx.z * blockDim.z + threadIdx.z,
gridDim.z * blockDim.z * NumILP, count);
gridDim.z * blockDim.z * NumILP, count);
}
} // namespace tv
\ No newline at end of file
......@@ -3,7 +3,7 @@
#include <type_traits>
#include <utility>
namespace spconv {
namespace tv {
template <class... T> struct mp_list {};
template <class T, T... I>
......@@ -11,9 +11,10 @@ using mp_list_c = mp_list<std::integral_constant<T, I>...>;
namespace detail {
template <class... T, class F>
constexpr F mp_for_each_impl(mp_list<T...>, F &&f) {
return std::initializer_list<int>{(f(T()), 0)...}, std::forward<F>(f);
template <class... Ts, class F>
constexpr F mp_for_each_impl(mp_list<Ts...>, F &&f) {
return (void)(std::initializer_list<int>{(f(Ts()), 0)...}),
std::forward<F>(f);
}
template <class F> constexpr F mp_for_each_impl(mp_list<>, F &&f) {
......@@ -42,6 +43,6 @@ using mp_rename = typename detail::mp_rename_impl<A, B>::type;
template <class L, class F> constexpr F mp_for_each(F &&f) {
return detail::mp_for_each_impl(mp_rename<L, mp_list>(), std::forward<F>(f));
}
} // namespace spconv
} // namespace tv
#endif
\ No newline at end of file
// Copyright Louis Delacroix 2010 - 2014.
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)
//
// A pretty printing library for C++
//
// Usage:
// Include this header, and operator<< will "just work".
#ifndef H_PRETTY_PRINT
#define H_PRETTY_PRINT
#include <cstddef>
#include <iterator>
#include <memory>
#include <ostream>
#include <set>
#include <tuple>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <valarray>
namespace pretty_print {
namespace detail {
// SFINAE type trait to detect whether T::const_iterator exists.
struct sfinae_base {
using yes = char;
using no = yes[2];
};
template <typename T> struct has_const_iterator : private sfinae_base {
private:
template <typename C> static yes &test(typename C::const_iterator *);
template <typename C> static no &test(...);
public:
static const bool value = sizeof(test<T>(nullptr)) == sizeof(yes);
using type = T;
};
template <typename T> struct has_begin_end : private sfinae_base {
private:
template <typename C>
static yes &
f(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator (C::*)()
const>(&C::begin)),
typename C::const_iterator (C::*)() const>::value>::type *);
template <typename C> static no &f(...);
template <typename C>
static yes &
g(typename std::enable_if<
std::is_same<decltype(static_cast<typename C::const_iterator (C::*)()
const>(&C::end)),
typename C::const_iterator (C::*)() const>::value,
void>::type *);
template <typename C> static no &g(...);
public:
static bool const beg_value = sizeof(f<T>(nullptr)) == sizeof(yes);
static bool const end_value = sizeof(g<T>(nullptr)) == sizeof(yes);
};
} // namespace detail
// Holds the delimiter values for a specific character type
template <typename TChar> struct delimiters_values {
using char_type = TChar;
const char_type *prefix;
const char_type *delimiter;
const char_type *postfix;
};
// Defines the delimiter values for a specific container and character type
template <typename T, typename TChar> struct delimiters {
using type = delimiters_values<TChar>;
static const type values;
};
// Functor to print containers. You can use this directly if you want
// to specificy a non-default delimiters type. The printing logic can
// be customized by specializing the nested template.
template <typename T, typename TChar = char,
typename TCharTraits = ::std::char_traits<TChar>,
typename TDelimiters = delimiters<T, TChar>>
struct print_container_helper {
using delimiters_type = TDelimiters;
using ostream_type = std::basic_ostream<TChar, TCharTraits>;
template <typename U> struct printer {
static void print_body(const U &c, ostream_type &stream) {
using std::begin;
using std::end;
auto it = begin(c);
const auto the_end = end(c);
if (it != the_end) {
for (;;) {
stream << *it;
if (++it == the_end)
break;
if (delimiters_type::values.delimiter != NULL)
stream << delimiters_type::values.delimiter;
}
}
}
};
print_container_helper(const T &container) : container_(container) {}
inline void operator()(ostream_type &stream) const {
if (delimiters_type::values.prefix != NULL)
stream << delimiters_type::values.prefix;
printer<T>::print_body(container_, stream);
if (delimiters_type::values.postfix != NULL)
stream << delimiters_type::values.postfix;
}
private:
const T &container_;
};
// Specialization for pairs
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
template <typename T1, typename T2>
struct print_container_helper<T, TChar, TCharTraits,
TDelimiters>::printer<std::pair<T1, T2>> {
using ostream_type =
typename print_container_helper<T, TChar, TCharTraits,
TDelimiters>::ostream_type;
static void print_body(const std::pair<T1, T2> &c, ostream_type &stream) {
stream << c.first;
if (print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter;
stream << c.second;
}
};
// Specialization for tuples
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
template <typename... Args>
struct print_container_helper<T, TChar, TCharTraits,
TDelimiters>::printer<std::tuple<Args...>> {
using ostream_type =
typename print_container_helper<T, TChar, TCharTraits,
TDelimiters>::ostream_type;
using element_type = std::tuple<Args...>;
template <std::size_t I> struct Int {};
static void print_body(const element_type &c, ostream_type &stream) {
tuple_print(c, stream, Int<0>());
}
static void tuple_print(const element_type &, ostream_type &,
Int<sizeof...(Args)>) {}
static void
tuple_print(const element_type &c, ostream_type &stream,
typename std::conditional<sizeof...(Args) != 0, Int<0>,
std::nullptr_t>::type) {
stream << std::get<0>(c);
tuple_print(c, stream, Int<1>());
}
template <std::size_t N>
static void tuple_print(const element_type &c, ostream_type &stream, Int<N>) {
if (print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter != NULL)
stream << print_container_helper<T, TChar, TCharTraits,
TDelimiters>::delimiters_type::values
.delimiter;
stream << std::get<N>(c);
tuple_print(c, stream, Int<N + 1>());
}
};
// Prints a print_container_helper to the specified stream.
template <typename T, typename TChar, typename TCharTraits,
typename TDelimiters>
inline std::basic_ostream<TChar, TCharTraits> &operator<<(
std::basic_ostream<TChar, TCharTraits> &stream,
const print_container_helper<T, TChar, TCharTraits, TDelimiters> &helper) {
helper(stream);
return stream;
}
// Basic is_container template; specialize to derive from std::true_type for all
// desired container types
template <typename T>
struct is_container
: public std::integral_constant<bool,
detail::has_const_iterator<T>::value &&
detail::has_begin_end<T>::beg_value &&
detail::has_begin_end<T>::end_value> {};
template <typename T, std::size_t N>
struct is_container<T[N]> : std::true_type {};
template <std::size_t N> struct is_container<char[N]> : std::false_type {};
template <typename T> struct is_container<std::valarray<T>> : std::true_type {};
template <typename T1, typename T2>
struct is_container<std::pair<T1, T2>> : std::true_type {};
template <typename... Args>
struct is_container<std::tuple<Args...>> : std::true_type {};
// Default delimiters
template <typename T> struct delimiters<T, char> {
static const delimiters_values<char> values;
};
template <typename T>
const delimiters_values<char> delimiters<T, char>::values = {"[", ", ", "]"};
template <typename T> struct delimiters<T, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T>
const delimiters_values<wchar_t> delimiters<T, wchar_t>::values = {L"[", L", ",
L"]"};
// Delimiters for (multi)set and unordered_(multi)set
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char>
delimiters<::std::set<T, TComp, TAllocator>, char>::values = {"{", ", ",
"}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::set<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::set<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<char> delimiters<::std::multiset<T, TComp, TAllocator>,
char>::values = {"{", ", ", "}"};
template <typename T, typename TComp, typename TAllocator>
struct delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename TComp, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::multiset<T, TComp, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t> delimiters<
::std::unordered_set<T, THash, TEqual, TAllocator>, wchar_t>::values = {
L"{", L", ", L"}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
char> {
static const delimiters_values<char> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<char> delimiters<
::std::unordered_multiset<T, THash, TEqual, TAllocator>, char>::values = {
"{", ", ", "}"};
template <typename T, typename THash, typename TEqual, typename TAllocator>
struct delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T, typename THash, typename TEqual, typename TAllocator>
const delimiters_values<wchar_t>
delimiters<::std::unordered_multiset<T, THash, TEqual, TAllocator>,
wchar_t>::values = {L"{", L", ", L"}"};
// Delimiters for pair and tuple
template <typename T1, typename T2> struct delimiters<std::pair<T1, T2>, char> {
static const delimiters_values<char> values;
};
template <typename T1, typename T2>
const delimiters_values<char> delimiters<std::pair<T1, T2>, char>::values = {
"(", ", ", ")"};
template <typename T1, typename T2>
struct delimiters<::std::pair<T1, T2>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename T1, typename T2>
const delimiters_values<wchar_t>
delimiters<::std::pair<T1, T2>, wchar_t>::values = {L"(", L", ", L")"};
template <typename... Args> struct delimiters<std::tuple<Args...>, char> {
static const delimiters_values<char> values;
};
template <typename... Args>
const delimiters_values<char> delimiters<std::tuple<Args...>, char>::values = {
"(", ", ", ")"};
template <typename... Args> struct delimiters<::std::tuple<Args...>, wchar_t> {
static const delimiters_values<wchar_t> values;
};
template <typename... Args>
const delimiters_values<wchar_t>
delimiters<::std::tuple<Args...>, wchar_t>::values = {L"(", L", ", L")"};
// Type-erasing helper class for easy use of custom delimiters.
// Requires TCharTraits = std::char_traits<TChar> and TChar = char or wchar_t,
// and MyDelims needs to be defined for TChar. Usage: "cout <<
// pretty_print::custom_delims<MyDelims>(x)".
struct custom_delims_base {
virtual ~custom_delims_base() {}
virtual std::ostream &stream(::std::ostream &) = 0;
virtual std::wostream &stream(::std::wostream &) = 0;
};
template <typename T, typename Delims>
struct custom_delims_wrapper : custom_delims_base {
custom_delims_wrapper(const T &t_) : t(t_) {}
std::ostream &stream(std::ostream &s) {
return s << print_container_helper<T, char, std::char_traits<char>, Delims>(
t);
}
std::wostream &stream(std::wostream &s) {
return s << print_container_helper<T, wchar_t, std::char_traits<wchar_t>,
Delims>(t);
}
private:
const T &t;
};
template <typename Delims> struct custom_delims {
template <typename Container>
custom_delims(const Container &c)
: base(new custom_delims_wrapper<Container, Delims>(c)) {}
std::unique_ptr<custom_delims_base> base;
};
template <typename TChar, typename TCharTraits, typename Delims>
inline std::basic_ostream<TChar, TCharTraits> &
operator<<(std::basic_ostream<TChar, TCharTraits> &s,
const custom_delims<Delims> &p) {
return p.base->stream(s);
}
// A wrapper for a C-style array given as pointer-plus-size.
// Usage: std::cout << pretty_print_array(arr, n) << std::endl;
template <typename T> struct array_wrapper_n {
typedef const T *const_iterator;
typedef T value_type;
array_wrapper_n(const T *const a, size_t n) : _array(a), _n(n) {}
inline const_iterator begin() const { return _array; }
inline const_iterator end() const { return _array + _n; }
private:
const T *const _array;
size_t _n;
};
// A wrapper for hash-table based containers that offer local iterators to each
// bucket. Usage: std::cout << bucket_print(m, 4) << std::endl; (Prints bucket
// 5 of container m.)
template <typename T> struct bucket_print_wrapper {
typedef typename T::const_local_iterator const_iterator;
typedef typename T::size_type size_type;
const_iterator begin() const { return m_map.cbegin(n); }
const_iterator end() const { return m_map.cend(n); }
bucket_print_wrapper(const T &m, size_type bucket) : m_map(m), n(bucket) {}
private:
const T &m_map;
const size_type n;
};
} // namespace pretty_print
// Global accessor functions for the convenience wrappers
template <typename T>
inline pretty_print::array_wrapper_n<T> pretty_print_array(const T *const a,
size_t n) {
return pretty_print::array_wrapper_n<T>(a, n);
}
template <typename T>
pretty_print::bucket_print_wrapper<T> bucket_print(const T &m,
typename T::size_type n) {
return pretty_print::bucket_print_wrapper<T>(m, n);
}
// Main magic entry point: An overload snuck into namespace std.
// Can we do better?
namespace std {
// Prints a container to the stream using default delimiters
template <typename T, typename TChar, typename TCharTraits>
inline typename enable_if<::pretty_print::is_container<T>::value,
basic_ostream<TChar, TCharTraits> &>::type
operator<<(basic_ostream<TChar, TCharTraits> &stream, const T &container) {
return stream
<< ::pretty_print::print_container_helper<T, TChar, TCharTraits>(
container);
}
} // namespace std
#endif // H_PRETTY_PRINT
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "tensor.h"
#include "tensorview.h"
#include <algorithm>
#include <array>
#include <iostream>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace tv {
template <typename Tarr> bool is_c_stype(const Tarr &arr) {
return bool(arr.flags() & py::array::c_style);
}
template <typename T, int Rank = -1>
TensorView<T, Rank> arrayt2tv(py::array_t<T> arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array");
Shape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
if (Rank >= 0) {
TV_ASSERT_INVALID_ARG(shape.ndim() == Rank, "error");
}
return TensorView<T, Rank>(arr.mutable_data(), shape);
}
template <typename T, int Rank = -1>
TensorView<const T> carrayt2tv(py::array_t<T> arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array");
Shape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
if (Rank >= 0) {
TV_ASSERT_INVALID_ARG(shape.ndim() == Rank, "error");
}
return TensorView<const T, Rank>(arr.data(), shape);
}
template <typename Tarr> tv::DType get_array_tv_dtype(const Tarr &arr) {
switch (arr.dtype().kind()) {
case 'b':
return tv::bool_;
case 'i': {
switch (arr.itemsize()) {
case 1:
return tv::int8;
case 2:
return tv::int16;
case 4:
return tv::int32;
case 8:
return tv::int64;
default:
break;
}
}
case 'u': {
switch (arr.itemsize()) {
case 1:
return tv::uint8;
case 2:
return tv::uint16;
case 4:
return tv::uint32;
case 8:
return tv::uint64;
default:
break;
}
}
case 'f': {
switch (arr.itemsize()) {
case 2:
return tv::float16;
case 4:
return tv::float32;
case 8:
return tv::float64;
default:
break;
}
}
}
TV_THROW_RT_ERR("unknown dtype", arr.dtype().kind(), arr.itemsize());
}
template <typename Tarr> Tensor array2tensor(Tarr &arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array");
TensorShape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
return tv::from_blob(arr.mutable_data(), shape, get_array_tv_dtype(arr), -1);
}
template <typename T> Tensor arrayt2tensor(py::array_t<T> &arr) {
TV_ASSERT_INVALID_ARG(is_c_stype(arr), "array must be c-contiguous array");
TensorShape shape;
for (int i = 0; i < arr.ndim(); ++i) {
shape.push_back(arr.shape(i));
}
return tv::from_blob(arr.mutable_data(), shape, tv::type_v<T>, -1);
}
template <typename TDType> py::dtype tv_dtype_to_py(TDType d) {
switch (d) {
case float32:
return py::dtype("float32");
case float64:
return py::dtype("float64");
case float16:
return py::dtype("float16");
case int32:
return py::dtype("int32");
case int16:
return py::dtype("int16");
case int8:
return py::dtype("int8");
case int64:
return py::dtype("int64");
case uint32:
return py::dtype("uint32");
case uint16:
return py::dtype("uint16");
case uint8:
return py::dtype("uint8");
case uint64:
return py::dtype("uint64");
case bool_:
return py::dtype("bool_");
default:;
}
TV_THROW_INVALID_ARG("unknown dtype", d);
}
// add template to define function in header
template <typename Ttensor> py::array tensor2array(Ttensor &tensor) {
// you cant call this function during GIL released.
TV_ASSERT_INVALID_ARG(tensor.device() == -1, "must be cpu tensor");
auto shape = tensor.shape();
std::vector<int> shape_vec(shape.begin(), shape.end());
auto dtype = tv_dtype_to_py(tensor.dtype());
// construct py::array will copy content from ptr.
// its expected because we can't transfer ownership from
// c++ tv::Tensor to numpy array when c++ object is deleted.
return py::array(dtype, shape_vec, {}, tensor.raw_data());
}
} // namespace tv
This diff is collapsed.
This diff is collapsed.
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#ifdef TV_CUDA
#include <cuda_runtime_api.h>
#endif
#include <iostream>
namespace tv {
#ifdef TV_CUDA
template <typename TimeT = std::chrono::microseconds> struct CudaContextTimer {
CudaContextTimer() {
cudaDeviceSynchronize();
mCurTime = std::chrono::steady_clock::now();
}
typename TimeT::rep report() {
cudaDeviceSynchronize();
auto duration = std::chrono::duration_cast<TimeT>(
std::chrono::steady_clock::now() - mCurTime);
auto res = duration.count();
mCurTime = std::chrono::steady_clock::now();
return res;
}
private:
std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
#endif
template <typename TimeT = std::chrono::microseconds> struct CPUTimer {
CPUTimer() { mCurTime = std::chrono::steady_clock::now(); }
typename TimeT::rep report() {
auto duration = std::chrono::duration_cast<TimeT>(
std::chrono::steady_clock::now() - mCurTime);
auto res = duration.count();
mCurTime = std::chrono::steady_clock::now();
return res;
}
private:
std::chrono::time_point<std::chrono::steady_clock> mCurTime;
};
} // namespace tv
// Copyright 2019-2020 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mp_helper.h"
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace tv {
#ifdef TV_CUDA
struct TorchGPU : public tv::GPU {
virtual cudaStream_t getStream() const override {
return at::cuda::getCurrentCUDAStream();
}
};
#endif
namespace detail {
template <typename T> struct TypeToTorchDtypeTraits;
template <> struct TypeToTorchDtypeTraits<int32_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt32;
};
template <> struct TypeToTorchDtypeTraits<int16_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt16;
};
template <> struct TypeToTorchDtypeTraits<int8_t> {
static constexpr decltype(torch::kInt8) value = torch::kInt8;
};
template <> struct TypeToTorchDtypeTraits<int64_t> {
static constexpr decltype(torch::kInt32) value = torch::kInt64;
};
template <> struct TypeToTorchDtypeTraits<uint8_t> {
static constexpr decltype(torch::kInt32) value = torch::kUInt8;
};
template <> struct TypeToTorchDtypeTraits<bool> {
static constexpr decltype(torch::kInt32) value = torch::kBool;
};
template <> struct TypeToTorchDtypeTraits<float> {
static constexpr decltype(torch::kInt32) value = torch::kFloat32;
};
template <> struct TypeToTorchDtypeTraits<double> {
static constexpr decltype(torch::kInt32) value = torch::kFloat64;
};
template <> struct TypeToTorchDtypeTraits<at::Half> {
static constexpr decltype(torch::kInt32) value = torch::kHalf;
};
using all_torch_types_t = std::tuple<float, double, int8_t, int16_t, int32_t,
int64_t, uint8_t, bool, at::Half>;
} // namespace detail
template <typename T>
constexpr decltype(torch::kInt32) torch_type_v =
detail::TypeToTorchDtypeTraits<T>::value;
template <class... Ts, typename F>
void dispatch_torch(at::ScalarType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true;
tv::mp_for_each<mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (detail::TypeToTorchDtypeTraits<decltype(I)>::value == t) {
std::forward<F>(f)(decltype(I)());
notFound = false;
}
});
if (notFound) {
std::stringstream ss;
tv::mp_for_each<mp_list<Ts...>>([=, &ss](auto I) {
ss << tv::detail::TypeToString<decltype(I)>::value << " ";
});
TV_THROW_RT_ERR("unknown type", t, ", available:", ss.str());
}
}
template <class T> struct DispatchTorch;
template <template <class...> class T, class... Args>
struct DispatchTorch<T<Args...>> {
template <typename F> inline void operator()(at::ScalarType t, F &&f) {
return dispatch_torch<Args...>(t, std::forward<F>(f));
}
};
template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
DispatchTorch<detail::all_torch_types_t>()(tensor.scalar_type(), [&](auto I) {
using Ttensor = decltype(I);
constexpr bool val = std::is_same<std::remove_cv_t<T>, Ttensor>::value;
TV_ASSERT_RT_ERR(val, "error");
});
}
template <typename T, int Rank = -1,
template <class> class PtrTraits = DefaultPtrTraits,
typename Tindex = int>
TensorView<T, Rank, PtrTraits, Tindex> torch2tv(const torch::Tensor &tensor) {
using tv_shape_t =
typename TensorView<T, Rank, PtrTraits, Tindex>::tv_shape_t;
check_torch_dtype<T>(tensor);
// TODO stride
if (Rank > 0) {
TV_ASSERT_INVALID_ARG(tensor.dim() == Rank, "error");
}
tv_shape_t shape;
for (auto i : tensor.sizes()) {
shape.push_back(i);
}
return tv::TensorView<T, Rank, PtrTraits, Tindex>(
tensor.data_ptr<std::remove_const_t<T>>(), shape);
}
namespace detail {
template <> struct TypeToString<at::Half> {
static constexpr const char *value = "half";
};
} // namespace detail
} // namespace tv
\ No newline at end of file
......@@ -13,18 +13,18 @@
// limitations under the License.
#pragma once
#include <spconv/mp_helper.h>
#include <tensorview/mp_helper.h>
#include <tensorview/tensorview.h>
#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif
namespace tv {
#ifdef SPCONV_CUDA
#ifdef TV_CUDA
struct TorchGPU : public tv::GPU {
virtual cudaStream_t getStream() const override {
return at::cuda::getCurrentCUDAStream();
......@@ -103,10 +103,10 @@ template <> struct TypeToString<at::Half> {
};
} // namespace detail
template <class... Ts, typename F>
void torch_dispatch(at::ScalarType t, F &&f) {
void dispatch_torch(at::ScalarType t, F &&f) {
static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
bool notFound = true;
spconv::mp_for_each<spconv::mp_list<Ts...>>([=, &notFound, &f](auto I) {
spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &notFound, &f](auto I) {
if (torch_type_v<decltype(I)> == t) {
std::forward<F>(f)(decltype(I)());
notFound = false;
......@@ -114,7 +114,7 @@ void torch_dispatch(at::ScalarType t, F &&f) {
});
if (notFound) {
std::stringstream ss;
spconv::mp_for_each<spconv::mp_list<Ts...>>([=, &ss](auto I) {
spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &ss](auto I) {
ss << tv::detail::TypeToString<decltype(I)>::value << " ";
});
TV_THROW_RT_ERR("unknown type", t, ", available: ", ss.str());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment