Unverified Commit 750e38f5 authored by peterjc123's avatar peterjc123 Committed by GitHub
Browse files

Fixes function signatures for torch::nn::Functional (#2463)

parent cf534fda
......@@ -14,22 +14,22 @@ namespace modelsimpl {
// TODO here torch::relu_ and torch::adaptive_avg_pool2d wrapped in
// torch::nn::Fuctional don't work. so keeping these for now
inline torch::Tensor& relu_(torch::Tensor x) {
return torch::relu_(x);
inline torch::Tensor& relu_(const torch::Tensor& x) {
return x.relu_();
}
inline torch::Tensor relu6_(torch::Tensor x) {
inline torch::Tensor& relu6_(const torch::Tensor& x) {
return x.clamp_(0, 6);
}
inline torch::Tensor adaptive_avg_pool2d(
torch::Tensor x,
const torch::Tensor& x,
torch::ExpandingArray<2> output_size) {
return torch::adaptive_avg_pool2d(x, output_size);
}
inline torch::Tensor max_pool2d(
torch::Tensor x,
const torch::Tensor& x,
torch::ExpandingArray<2> kernel_size,
torch::ExpandingArray<2> stride) {
return torch::max_pool2d(x, kernel_size, stride);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment