Unverified Commit 750e38f5 authored by peterjc123's avatar peterjc123 Committed by GitHub
Browse files

Fixes function signatures for torch::nn::Functional (#2463)

parent cf534fda
...@@ -14,22 +14,22 @@ namespace modelsimpl { ...@@ -14,22 +14,22 @@ namespace modelsimpl {
// TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in // TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in
// torch::nn::Fuctional don't work. so keeping these for now // torch::nn::Fuctional don't work. so keeping these for now
inline torch::Tensor& relu_(torch::Tensor x) { inline torch::Tensor& relu_(const torch::Tensor& x) {
return torch::relu_(x); return x.relu_();
} }
inline torch::Tensor relu6_(torch::Tensor x) { inline torch::Tensor& relu6_(const torch::Tensor& x) {
return x.clamp_(0, 6); return x.clamp_(0, 6);
} }
inline torch::Tensor adaptive_avg_pool2d( inline torch::Tensor adaptive_avg_pool2d(
torch::Tensor x, const torch::Tensor& x,
torch::ExpandingArray<2> output_size) { torch::ExpandingArray<2> output_size) {
return torch::adaptive_avg_pool2d(x, output_size); return torch::adaptive_avg_pool2d(x, output_size);
} }
inline torch::Tensor max_pool2d( inline torch::Tensor max_pool2d(
torch::Tensor x, const torch::Tensor& x,
torch::ExpandingArray<2> kernel_size, torch::ExpandingArray<2> kernel_size,
torch::ExpandingArray<2> stride) { torch::ExpandingArray<2> stride) {
return torch::max_pool2d(x, kernel_size, stride); return torch::max_pool2d(x, kernel_size, stride);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment