"maint/vscode:/vscode.git/clone" did not exist on "11456de25f06a2fef3cf8e13c1e284bd81df5996"
Commit 5fc48e77 authored by charlie

Merge branch 'refactor_dynamic_compute' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_model_test
parents 3a4d36cf a9c0252a
@@ -3647,6 +3647,16 @@ def neg_test():
     return ([node], [x], [y])
 
 
+@onnx_test
+def neg_dynamic_test():
+    x = helper.make_tensor_value_info('0', TensorProto.INT64, [None, 3])
+    y = helper.make_tensor_value_info('1', TensorProto.INT64, [None, 3])
+
+    node = onnx.helper.make_node('Neg', inputs=['0'], outputs=['1'])
+
+    return ([node], [x], [y])
+
+
 @onnx_test
 def nms_test():
     b = helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [1, 6, 4])
@@ -5280,6 +5290,20 @@ def sinh_test():
     return ([node], [x], [y])
 
 
+@onnx_test
+def sinh_dynamic_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [None])
+
+    node = onnx.helper.make_node(
+        'Sinh',
+        inputs=['x'],
+        outputs=['y'],
+    )
+
+    return ([node], [x], [y])
+
+
 @onnx_test
 def size_float_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3, 4])
...
@@ -856,8 +856,7 @@ TEST_CASE(conv_autopad_same_test)
     auto l0 = mm->add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
     migraphx::op::convolution op;
     op.padding = {1, 1, 1, 1};
-    op.padding_mode = migraphx::op::padding_mode_t::same;
     mm->add_instruction(op, l0, l1);
 
     auto prog = optimize_onnx("conv_autopad_same_test.onnx");
@@ -1034,15 +1033,11 @@ TEST_CASE(conv_dynamic_batch_same_upper)
     auto l0 = mm->add_parameter(
         "0", {migraphx::shape::float_type, {{1, 10, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}});
     auto l1 = mm->add_parameter("1", {migraphx::shape::float_type, {1, 3, 3, 3}});
-    auto c0 =
-        mm->add_instruction(migraphx::make_op("convolution",
-                                              {{"padding", {1, 1, 1, 1}},
-                                               {"stride", {1, 1}},
-                                               {"dilation", {1, 1}},
-                                               {"padding_mode", migraphx::op::padding_mode_t::same},
-                                               {"use_dynamic_same_auto_pad", false}}),
-                            l0,
-                            l1);
+    auto c0 = mm->add_instruction(
+        migraphx::make_op("convolution",
+                          {{"padding", {1, 1, 1, 1}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
+        l0,
+        l1);
     mm->add_return({c0});
 
     migraphx::onnx_options options;
@@ -1064,8 +1059,7 @@ TEST_CASE(conv_dynamic_img_same_upper)
                           {{"padding", {0, 0}},
                            {"stride", {1, 1}},
                            {"dilation", {1, 1}},
-                           {"padding_mode", migraphx::op::padding_mode_t::same_upper},
-                           {"use_dynamic_same_auto_pad", true}}),
+                           {"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
         l0,
         l1);
     mm->add_return({c0});
@@ -1089,8 +1083,7 @@ TEST_CASE(conv_dynamic_kernel_same_lower)
                           {{"padding", {0, 0}},
                            {"stride", {1, 1}},
                            {"dilation", {1, 1}},
-                           {"padding_mode", migraphx::op::padding_mode_t::same_lower},
-                           {"use_dynamic_same_auto_pad", true}}),
+                           {"padding_mode", migraphx::op::padding_mode_t::same_lower}}),
         l0,
         l1);
     mm->add_return({c0});
@@ -3483,6 +3476,21 @@ TEST_CASE(neg_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(neg_dynamic_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape s{migraphx::shape::int64_type, {{1, 10, 0}, {3, 3, 0}}};
+    auto input = mm->add_parameter("0", s);
+    auto ret = mm->add_instruction(migraphx::make_op("neg"), input);
+    mm->add_return({ret});
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = {1, 10, 0};
+    auto prog = migraphx::parse_onnx("neg_dynamic_test.onnx", options);
+
+    EXPECT(p == prog);
+}
 
 TEST_CASE(nms_test)
 {
     migraphx::program p;
@@ -5206,6 +5214,29 @@ TEST_CASE(sinh_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(sinh_dynamic_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    migraphx::shape::dynamic_dimension dd{1, 10, 0};
+    std::vector<migraphx::shape::dynamic_dimension> dyn_dims;
+    dyn_dims.push_back(dd);
+    auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dyn_dims});
+    mm->add_instruction(migraphx::make_op("sinh"), input);
+
+    migraphx::onnx_options options;
+    options.default_dyn_dim_value = dd;
+    auto prog = parse_onnx("sinh_dynamic_test.onnx", options);
+
+    auto* mm_onnx = prog.get_main_module();
+    auto last_ins = std::prev(mm_onnx->end());
+    if(last_ins->name() == "@return")
+    {
+        mm_onnx->remove_instruction(last_ins);
+    }
+
+    EXPECT(p == prog);
+}
 
 TEST_CASE(size_float_test)
 {
     migraphx::program p;
...
@@ -261,8 +261,7 @@ TEST_CASE(convolution_shape)
         migraphx::make_op("convolution",
                           {{"stride", {1, 1}},
                            {"dilation", {1, 1}},
-                           {"padding_mode", migraphx::op::padding_mode_t::same_upper},
-                           {"use_dynamic_same_auto_pad", true}}),
+                           {"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
         input_dyn_shape,
         weights_shape);
@@ -275,8 +274,7 @@ TEST_CASE(convolution_shape)
         migraphx::make_op("convolution",
                           {{"stride", {1, 1}},
                            {"dilation", {1, 1}},
-                           {"padding_mode", migraphx::op::padding_mode_t::same_upper},
-                           {"use_dynamic_same_auto_pad", true}}),
+                           {"padding_mode", migraphx::op::padding_mode_t::same_upper}}),
         input_dyn_shape,
         weights_shape);
@@ -290,8 +288,7 @@ TEST_CASE(convolution_shape)
         migraphx::make_op("convolution",
                           {{"stride", {1, 1}},
                            {"dilation", {1, 1}},
-                           {"padding_mode", migraphx::op::padding_mode_t::same_lower},
-                           {"use_dynamic_same_auto_pad", true}}),
+                           {"padding_mode", migraphx::op::padding_mode_t::same_lower}}),
         input_dyn_shape,
         weights_shape);
 }
...
@@ -327,10 +327,9 @@ migraphx::program create_conv()
     mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
     migraphx::op::convolution op;
-    op.padding_mode = migraphx::op::padding_mode_t::same;
     op.padding = {1, 1, 1, 1};
     op.stride = {1, 1};
     op.dilation = {1, 1};
 
     auto l2 =
         mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
     mm->add_instruction(op, l0, l2);
@@ -406,11 +405,10 @@ TEST_CASE(depthwiseconv_test)
     mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
     migraphx::op::convolution op;
-    op.padding_mode = migraphx::op::padding_mode_t::same;
     op.padding = {1, 1};
     op.stride = {1, 1};
     op.dilation = {1, 1};
     op.group = 3;
 
     auto l3 =
         mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {3, 2, 0, 1}}}), l1);
     auto l4 = mm->add_instruction(migraphx::make_op("contiguous"), l3);
...
@@ -37,10 +37,7 @@ struct quant_conv_default_mode : verify_program<quant_conv_default_mode>
         auto pa = mm->add_parameter("a", a_shape);
         migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
-        mm->add_instruction(
-            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
-            pa,
-            pc);
+        mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
         return p;
     }
 };
@@ -37,10 +37,7 @@ struct quant_conv_int8x4_default : verify_program<quant_conv_int8x4_default>
         auto pa = mm->add_parameter("a", a_shape);
         migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
-        mm->add_instruction(
-            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::same},
-            pa,
-            pc);
+        mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
         return p;
     }
 };
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct quant_conv_valid_mode : verify_program<quant_conv_valid_mode>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
        auto pa = mm->add_parameter("a", a_shape);
        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
        auto pc = mm->add_parameter("c", c_shape);
        mm->add_instruction(
            migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}, migraphx::op::valid},
            pa,
            pc);
        return p;
    }
};
@@ -34,7 +34,7 @@ struct test_elu : verify_program<test_elu>
         migraphx::program p;
         auto* mm = p.get_main_module();
         auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
-        mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 1.0}}), x);
+        mm->add_instruction(migraphx::make_op("elu", {{"alpha", 0.8}}), x);
         return p;
     }
 };
@@ -34,7 +34,7 @@ struct test_leaky_relu : verify_program<test_leaky_relu>
         migraphx::program p;
         auto* mm = p.get_main_module();
         auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
-        mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.01}}), x);
+        mm->add_instruction(migraphx::make_op("leaky_relu", {{"alpha", 0.41}}), x);
         return p;
     }
 };
@@ -32,6 +32,7 @@
 #include <utility>
 #include <unordered_map>
 #include <migraphx/reflect.hpp>
+#include <migraphx/functional.hpp>
 #include <migraphx/streamutils.hpp>
 #include <migraphx/normalize_attributes.hpp>
 #include <migraphx/argument.hpp>
@@ -94,6 +95,46 @@ bool has_finalize(const operation& x);
 
 #else
 
+struct dyn_output
+{
+    // original shape from the instruction
+    shape ins_shape;
+    // shape computed at eval time using the input arguments
+    shape computed_shape;
+};
+
+/**
+ * Handles dynamic and static shapes at evaluation time.
+ * If converted to the shape type, returns the original ins_shape.
+ * If converted to the dyn_output type, computes an output shape from the input arguments.
+ */
+template <class F>
+struct compute_output_shape
+{
+    F ins_inputs;
+
+    operator dyn_output() const
+    {
+        return ins_inputs([](const auto& x, shape ins_shape, const std::vector<argument>& inputs) {
+            if(ins_shape.dynamic())
+                return dyn_output{ins_shape, compute_shape(x, to_shapes(inputs))};
+            return dyn_output{ins_shape, ins_shape};
+        });
+    }
+
+    operator shape() const
+    {
+        return ins_inputs(
+            [](const auto&, shape ins_shape, const std::vector<argument>&) { return ins_shape; });
+    }
+};
+
+template <class F>
+compute_output_shape<F> make_compute_output_shape(F f)
+{
+    return {f};
+}
+
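(Editorial note, not part of the diff.) With the two conversion operators above, an operator's compute() can declare its output parameter either as a plain shape, keeping the old static behavior, or as dyn_output, in which case dyn_out.computed_shape holds the shape computed at eval time from the actual input arguments. Below is a minimal sketch of an op written against the dyn_output interface; the name my_neg and its body are illustrative and not taken from this commit, and it assumes the migraphx shape, argument, and visit_all utilities plus <algorithm> for std::transform.

// Sketch only: a hypothetical elementwise op using the new dyn_output interface.
struct my_neg
{
    std::string name() const { return "my_neg"; }

    // The instruction-time shape may be dynamic; the output mirrors the input.
    migraphx::shape compute_shape(std::vector<migraphx::shape> inputs) const
    {
        return inputs.front();
    }

    // Declaring the first parameter as dyn_output (rather than const shape&)
    // makes the compute_op adapters pass a shape computed from the real inputs.
    migraphx::argument compute(const migraphx::dyn_output& dyn_out,
                               std::vector<migraphx::argument> args) const
    {
        // Allocate the output with the eval-time shape, then negate elementwise.
        migraphx::argument result{dyn_out.computed_shape};
        migraphx::visit_all(result, args[0])([](auto output, auto input) {
            std::transform(input.begin(), input.end(), output.begin(), [](auto e) { return -e; });
        });
        return result;
    }
};

Ops that keep the old compute(const shape&, ...) signature continue to work unchanged: the shape conversion operator simply hands back the instruction's original shape.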
 namespace detail {
 namespace operation_operators {
@@ -199,9 +240,12 @@ auto compute_op(rank<1>,
                 context& ctx,
                 const shape& output_shape,
                 const std::vector<argument>& input)
-    -> decltype(x.compute(auto_any_cast(ctx), output_shape, input))
+    -> decltype(x.compute(auto_any_cast(ctx),
+                          make_compute_output_shape(pack(x, output_shape, input)),
+                          input))
 {
-    return x.compute(auto_any_cast(ctx), output_shape, input);
+    return x.compute(
+        auto_any_cast(ctx), make_compute_output_shape(pack(x, output_shape, input)), input);
 }
 
 template <class T>
@@ -220,9 +264,9 @@ compute_op(const T& x, context& ctx, const shape& output_shape, const std::vecto
 auto compute_op(rank<1>, const T& x, const shape& output_shape, const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input))
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input))
 {
-    return x.compute(output_shape, input);
+    return x.compute(make_compute_output_shape(pack(x, output_shape, input)), input);
 }
template <class T> template <class T>
...@@ -244,9 +288,11 @@ auto compute_op(rank<1>, ...@@ -244,9 +288,11 @@ auto compute_op(rank<1>,
const shape& output, const shape& output,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
const std::vector<module_ref>& module_args, const std::vector<module_ref>& module_args,
F f) -> decltype(x.compute(output, inputs, module_args, f)) F f)
-> decltype(
x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
{ {
return x.compute(output, inputs, module_args, f); return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
} }
 template <class T, class F>
@@ -278,9 +324,17 @@ auto compute_op(rank<4>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(auto_any_cast(ctx), output, inputs, module_args, f))
+                F f) -> decltype(x.compute(auto_any_cast(ctx),
+                                           make_compute_output_shape(pack(x, output, inputs)),
+                                           inputs,
+                                           module_args,
+                                           f))
 {
-    return x.compute(auto_any_cast(ctx), output, inputs, module_args, f);
+    return x.compute(auto_any_cast(ctx),
+                     make_compute_output_shape(pack(x, output, inputs)),
+                     inputs,
+                     module_args,
+                     f);
 }
 template <class T, class F>
@@ -290,9 +344,11 @@ auto compute_op(rank<3>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>& module_args,
-                F f) -> decltype(x.compute(output, inputs, module_args, f))
+                F f)
+    -> decltype(
+        x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f))
 {
-    return x.compute(output, inputs, module_args, f);
+    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs, module_args, f);
 }
 template <class T, class F>
@@ -302,9 +358,10 @@ auto compute_op(rank<2>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>&,
-                F) -> decltype(x.compute(output, inputs))
+                F)
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs))
 {
-    return x.compute(output, inputs);
+    return x.compute(make_compute_output_shape(pack(x, output, inputs)), inputs);
 }
 template <class T, class F>
@@ -314,9 +371,12 @@ auto compute_op(rank<1>,
                 const shape& output,
                 const std::vector<argument>& inputs,
                 const std::vector<module_ref>&,
-                F) -> decltype(x.compute(auto_any_cast(ctx), output, inputs))
+                F) -> decltype(x.compute(auto_any_cast(ctx),
+                                         make_compute_output_shape(pack(x, output, inputs)),
+                                         inputs))
 {
-    return x.compute(auto_any_cast(ctx), output, inputs);
+    return x.compute(
+        auto_any_cast(ctx), make_compute_output_shape(pack(x, output, inputs)), inputs);
 }
 template <class T, class F>
@@ -348,7 +408,8 @@ auto is_context_free_op(rank<1>,
                         const T& x,
                         const shape& output_shape,
                         const std::vector<argument>& input)
-    -> decltype(x.compute(output_shape, input), std::true_type{});
+    -> decltype(x.compute(make_compute_output_shape(pack(x, output_shape, input)), input),
+                std::true_type{});
 
 template <class T>
 auto is_context_free_op(rank<0>, const T&, const shape&, const std::vector<argument>&)
...
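(Editorial note, not part of the diff.) The compute_op overloads above are ordered with MIGraphX's rank<N> tag-dispatch idiom: rank<N> derives from rank<N-1>, the call site passes the highest rank, and overload resolution walks down the ranks until the trailing decltype (SFINAE) finds a compute() signature the op actually provides. A self-contained sketch of the idiom follows; the names call, with_f, and without_f are illustrative and not from the codebase.

#include <cstddef>
#include <iostream>

// rank<N> inherits from rank<N - 1>, so a rank<1> argument prefers an
// overload taking rank<1> over one taking rank<0>.
template <std::size_t N>
struct rank : rank<N - 1>
{
};

template <>
struct rank<0>
{
};

// Preferred overload: participates only if x.f() is well-formed (SFINAE).
template <class T>
auto call(rank<1>, const T& x) -> decltype(x.f())
{
    return x.f();
}

// Fallback overload: always viable, chosen when the one above drops out.
template <class T>
int call(rank<0>, const T&)
{
    return -1;
}

struct with_f
{
    int f() const { return 42; }
};
struct without_f
{
};

int main()
{
    std::cout << call(rank<1>{}, with_f{}) << "\n";    // prints 42
    std::cout << call(rank<1>{}, without_f{}) << "\n"; // prints -1
}

This is also why each x.compute(...) expression appears twice per overload in the header: once in the trailing decltype, where it selects or removes the overload, and once in the body, where it performs the actual call.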