#pragma once #include namespace vision { namespace models { namespace modelsimpl { // TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in // torch::nn::Fuctional don't work. so keeping these for now inline torch::Tensor& relu_(const torch::Tensor& x) { return x.relu_(); } inline torch::Tensor& relu6_(const torch::Tensor& x) { return x.clamp_(0, 6); } inline torch::Tensor adaptive_avg_pool2d( const torch::Tensor& x, torch::ExpandingArray<2> output_size) { return torch::adaptive_avg_pool2d(x, output_size); } inline torch::Tensor max_pool2d( const torch::Tensor& x, torch::ExpandingArray<2> kernel_size, torch::ExpandingArray<2> stride) { return torch::max_pool2d(x, kernel_size, stride); } inline bool double_compare(double a, double b) { return double(std::abs(a - b)) < std::numeric_limits::epsilon(); }; inline void deprecation_warning() { TORCH_WARN_ONCE( "The vision::models namespace is deprecated since 0.12 and will be " "removed in 0.14. We recommend using Torch Script instead: " "https://pytorch.org/tutorials/advanced/cpp_export.html"); } } // namespace modelsimpl } // namespace models } // namespace vision