modelsimpl.h 1.19 KB
Newer Older
1
#pragma once
Shahriar's avatar
Shahriar committed
2

3
#include <algorithm>
#include <cmath>
#include <limits>

#include <torch/nn.h>
4

Shahriar's avatar
Shahriar committed
5
6
7
8
9
10
11
namespace vision {
namespace models {
namespace modelsimpl {

// TODO: torch::relu_ and torch::adaptive_avg_pool2d do not work when wrapped
// in torch::nn::Functional, so these free-function wrappers are kept for now.

12
13
/// In-place ReLU wrapper usable as a plain free function.
///
/// Forwards to the tensor's own `relu_()` member and hands back the reference
/// that call returns.
///
/// @param x Tensor to rectify. It is taken by const reference, yet the
///          in-place member is still invoked on it.
///          NOTE(review): presumably the handle is const while the underlying
///          data is mutated — confirm against the Tensor API.
/// @return The reference returned by `x.relu_()`.
inline torch::Tensor& relu_(const torch::Tensor& x) {
  return x.relu_();
}

16
/// In-place ReLU6 wrapper: clamps the tensor to the range [0, 6].
///
/// Implemented as a single `clamp_(0, 6)` call on the tensor; the reference
/// returned by that call is passed straight back to the caller.
///
/// @param x Tensor to clamp. It is taken by const reference, yet the in-place
///          member is still invoked on it.
///          NOTE(review): presumably the handle is const while the underlying
///          data is mutated — confirm against the Tensor API.
/// @return The reference returned by `x.clamp_(0, 6)`.
inline torch::Tensor& relu6_(const torch::Tensor& x) {
  return x.clamp_(0, 6);
}

/// Free-function wrapper around `torch::adaptive_avg_pool2d`.
///
/// Exists so the operation can be used where a plain function taking a
/// `torch::ExpandingArray<2>` is wanted; the arguments are forwarded
/// unchanged.
///
/// @param x           Input tensor, passed through as-is.
/// @param output_size Two-element output size handed to
///                    `torch::adaptive_avg_pool2d`.
/// @return The tensor produced by `torch::adaptive_avg_pool2d`.
inline torch::Tensor adaptive_avg_pool2d(
    const torch::Tensor& x,
    torch::ExpandingArray<2> output_size) {
  return torch::adaptive_avg_pool2d(x, output_size);
}

/// Free-function wrapper around `torch::max_pool2d`.
///
/// Forwards the input together with the kernel size and stride; every other
/// argument of `torch::max_pool2d` is left at whatever default that function
/// declares.
///
/// @param x           Input tensor, passed through as-is.
/// @param kernel_size Two-element kernel size handed to `torch::max_pool2d`.
/// @param stride      Two-element stride handed to `torch::max_pool2d`.
/// @return The tensor produced by `torch::max_pool2d`.
inline torch::Tensor max_pool2d(
    const torch::Tensor& x,
    torch::ExpandingArray<2> kernel_size,
    torch::ExpandingArray<2> stride) {
  return torch::max_pool2d(x, kernel_size, stride);
}

/// Reports whether two doubles are equal up to floating-point rounding error.
///
/// The tolerance is one machine epsilon scaled by the larger operand
/// magnitude, with a floor of 1.0. For operands of magnitude at most 1 this is
/// the plain absolute-epsilon test; for larger operands the tolerance grows
/// with the values, because an unscaled epsilon is smaller than the spacing
/// between adjacent doubles there and would only ever accept exact equality.
///
/// @param a First value.
/// @param b Second value.
/// @return true when `a` and `b` are exactly equal or differ by less than the
///         scaled epsilon; false otherwise, including when either is NaN.
inline bool double_compare(double a, double b) {
  // Exact equality also covers matching infinities, for which `a - b` is NaN.
  if (a == b) {
    return true;
  }

  const double difference = std::fabs(a - b);
  // Epsilon is the gap between 1.0 and the next double, so scale it to the
  // magnitude of the operands; the 1.0 floor keeps small inputs on the
  // original absolute tolerance.
  const double magnitude = std::max(1.0, std::max(std::fabs(a), std::fabs(b)));
  return difference < std::numeric_limits<double>::epsilon() * magnitude;
}

37
38
/// Emits the deprecation notice for the `vision::models` C++ namespace.
///
/// The message is routed through `TORCH_WARN_ONCE`.
/// NOTE(review): per the macro's name the warning is presumably reported only
/// the first time this call site runs — confirm against the c10 macro docs.
inline void deprecation_warning() {
  TORCH_WARN_ONCE(
      "The vision::models namespace is deprecated since 0.12 and will be "
      "removed in 0.14. We recommend using Torch Script instead: "
      "https://pytorch.org/tutorials/advanced/cpp_export.html");
}

Shahriar's avatar
Shahriar committed
44
45
46
} // namespace modelsimpl
} // namespace models
} // namespace vision