// Copyright 2019 Yan Yan
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <cstdint>
#include <sstream>
#include <type_traits>
#include <utility>

#include <tensorview/mp_helper.h>
#include <tensorview/tensorview.h>

#include <ATen/ATen.h>
#include <torch/script.h>
#ifdef TV_CUDA
#include <ATen/cuda/CUDAContext.h>
#endif

namespace tv {
#ifdef TV_CUDA
/// tv::GPU context that reports PyTorch's current CUDA stream
/// (at::cuda::getCurrentCUDAStream()) as its stream.
struct TorchGPU : public tv::GPU {
  cudaStream_t getStream() const override {
    // The torch stream wrapper converts to the raw cudaStream_t handle.
    return static_cast<cudaStream_t>(at::cuda::getCurrentCUDAStream());
  }
};
#endif

/// Checks at runtime that the dtype of `tensor` matches the C++ element type
/// `T`. Top-level const on `T` is ignored.
///
/// Accepted pairs: Double/double, Float/float, Int/int32_t, Half/at::Half,
/// Long/int64_t. Any other dtype, or a mismatch, fails TV_ASSERT_RT_ERR.
/// The failure message names the tensor's actual dtype.
template <typename T> void check_torch_dtype(const torch::Tensor &tensor) {
  using Elem = std::remove_const_t<T>;
  const auto dtype = tensor.scalar_type();
  bool matches = false;
  switch (dtype) {
  case at::ScalarType::Double:
    matches = std::is_same<Elem, double>::value;
    break;
  case at::ScalarType::Float:
    matches = std::is_same<Elem, float>::value;
    break;
  case at::ScalarType::Int:
    matches = std::is_same<Elem, int32_t>::value;
    break;
  case at::ScalarType::Half:
    matches = std::is_same<Elem, at::Half>::value;
    break;
  case at::ScalarType::Long:
    // kInt64 is a fixed 64-bit type. `long` is only 32 bits on LLP64
    // platforms (Windows), where int64_t is `long long`. Compare against
    // int64_t, the same type TypeToTorchDtypeTraits maps to kInt64.
    matches = std::is_same<Elem, int64_t>::value;
    break;
  default:
    // Unsupported dtype: `matches` stays false and the assertion below fires.
    break;
  }
  TV_ASSERT_RT_ERR(matches, "tensor dtype", dtype,
                   "does not match the requested element type");
}

namespace detail {
template <typename T> struct TypeToTorchDtypeTraits;
traveller59's avatar
traveller59 committed
67

tusimple's avatar
tusimple committed
68
69
70
template <> struct TypeToTorchDtypeTraits<int32_t> {
  static constexpr decltype(torch::kInt32) value = torch::kInt32;
};
71

tusimple's avatar
tusimple committed
72
73
74
template <> struct TypeToTorchDtypeTraits<int64_t> {
  static constexpr decltype(torch::kInt32) value = torch::kInt64;
};
75

tusimple's avatar
tusimple committed
76
77
78
79
80
81
82
83
84
template <> struct TypeToTorchDtypeTraits<float> {
  static constexpr decltype(torch::kInt32) value = torch::kFloat32;
};
template <> struct TypeToTorchDtypeTraits<double> {
  static constexpr decltype(torch::kInt32) value = torch::kFloat64;
};
template <> struct TypeToTorchDtypeTraits<at::Half> {
  static constexpr decltype(torch::kInt32) value = torch::kHalf;
};
85

tusimple's avatar
tusimple committed
86
} // namespace detail
87

traveller59's avatar
traveller59 committed
88
template <typename T>
tusimple's avatar
tusimple committed
89
90
91
92
constexpr decltype(torch::kInt32) torch_type_v =
    detail::TypeToTorchDtypeTraits<T>::value;

/// Wraps the data of a torch tensor in a tv::TensorView<T>.
///
/// Requirements on `tensor`:
/// - `T` must match its dtype (enforced by check_torch_dtype).
/// - It must be contiguous. The view is built from the raw data pointer and
///   the sizes only, and no strides are passed. Otherwise TV_ASSERT_RT_ERR
///   fails.
///
/// NOTE(review): the view is constructed from a raw pointer. Presumably it
/// does not keep `tensor` alive, so the caller must — confirm against
/// tensorview.
template <typename T> tv::TensorView<T> torch2tv(const torch::Tensor &tensor) {
  check_torch_dtype<T>(tensor);
  // Without strides, a strided tensor would be read with the wrong element
  // addresses. Strided here means transposed, expanded, or sliced along an
  // inner dimension. An expanded tensor may also be read past the end of its
  // storage.
  TV_ASSERT_RT_ERR(tensor.is_contiguous(),
                   "torch2tv requires a contiguous tensor, call .contiguous() first");
  tv::Shape shape;
  for (const auto dim : tensor.sizes()) {
    shape.push_back(dim);
  }
  return tv::TensorView<T>(tensor.data_ptr<std::remove_const_t<T>>(), shape);
}

namespace detail {
// Display name for at::Half. dispatch_torch reads it through
// tv::detail::TypeToString when listing the available types in its error.
template <> struct TypeToString<at::Half> {
  static constexpr auto value = "half";
};
} // namespace detail
template <class... Ts, typename F>
106
void dispatch_torch(at::ScalarType t, F &&f) {
tusimple's avatar
tusimple committed
107
108
  static_assert(sizeof...(Ts) > 0, "you need to provide at least one type");
  bool notFound = true;
109
  spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &notFound, &f](auto I) {
tusimple's avatar
tusimple committed
110
111
112
113
114
115
116
    if (torch_type_v<decltype(I)> == t) {
      std::forward<F>(f)(decltype(I)());
      notFound = false;
    }
  });
  if (notFound) {
    std::stringstream ss;
117
    spconv::tv::mp_for_each<spconv::mp_list<Ts...>>([=, &ss](auto I) {
tusimple's avatar
tusimple committed
118
119
120
121
122
123
      ss << tv::detail::TypeToString<decltype(I)>::value << " ";
    });
    TV_THROW_RT_ERR("unknown type", t, ", available: ", ss.str());
  }
}

traveller59's avatar
traveller59 committed
124
} // namespace tv