// modelsimpl.h (1.02 KB) - last committed by Shahriar.
#ifndef MODELSIMPL_H
#define MODELSIMPL_H

#include <cmath>
#include <limits>

#include <torch/torch.h>

// Compatibility shim: if the included torch headers do not define
// TORCH_CHECK, make it an alias for AT_CHECK so code written against the
// newer name still compiles. (Presumably targets older libtorch releases
// that only provide AT_CHECK - confirm against the supported versions.)
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

Shahriar's avatar
Shahriar committed
10
11
12
13
14
15
16
17
18
19
20
21
namespace vision {
namespace models {
namespace modelsimpl {

// TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in
// torch::nn::Fuctional don't work. so keeping these for now

inline torch::Tensor& relu_(torch::Tensor x) {
  return torch::relu_(x);
}

inline torch::Tensor relu6_(torch::Tensor x) {
22
  return x.clamp_(0, 6);
Shahriar's avatar
Shahriar committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
}

inline torch::Tensor adaptive_avg_pool2d(
    torch::Tensor x,
    torch::ExpandingArray<2> output_size) {
  return torch::adaptive_avg_pool2d(x, output_size);
}

inline torch::Tensor max_pool2d(
    torch::Tensor x,
    torch::ExpandingArray<2> kernel_size,
    torch::ExpandingArray<2> stride) {
  return torch::max_pool2d(x, kernel_size, stride);
}

inline bool double_compare(double a, double b) {
  return double(std::abs(a - b)) < std::numeric_limits<double>::epsilon();
};

} // namespace modelsimpl
} // namespace models
} // namespace vision

#endif // MODELSIMPL_H