Unverified commit 3ec62e53, authored by mvermeulen, committed by GitHub

Merge pull request #325 from ROCmSoftwarePlatform/op_capture

add capture operator
parents de1d5919 47a00c6a
@@ -30,23 +30,29 @@ struct binary : op_name<Derived>
        argument result{output_shape};
        auto s1 = args[0].get_shape();
        auto s2 = args[1].get_shape();
-        visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
-            if(s1 == s2 and input1.get_shape().packed() and input2.get_shape().packed())
-            {
-                std::transform(input1.begin(),
-                               input1.end(),
-                               input2.begin(),
-                               output.begin(),
-                               static_cast<const Derived&>(*this).apply());
-            }
-            else
-            {
-                shape_for_each(output.get_shape(), [&](const auto& idx) {
-                    output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
-                        input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
-                });
-            }
-        });
+        if(s1 == s2 and s1.packed())
+        {
+            shape std_shape{s1.type(), s1.lens()};
+            argument std_result{std_shape, result.data()};
+            argument std_arg0{std_shape, args[0].data()};
+            argument std_arg1{std_shape, args[1].data()};
+            visit_all(std_result, std_arg0, std_arg1)([&](auto output, auto input1, auto input2) {
+                std::transform(input1.begin(),
+                               input1.end(),
+                               input2.begin(),
+                               output.begin(),
+                               static_cast<const Derived&>(*this).apply());
+            });
+        }
+        else
+        {
+            visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
+                shape_for_each(output.get_shape(), [&](const auto& idx) {
+                    output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
+                        input1(idx.begin(), idx.end()), input2(idx.begin(), idx.end()));
+                });
+            });
+        }
        return result;
    }
......
#ifndef MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#define MIGRAPHX_GUARD_OPERATORS_CAPTURE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct capture
{
std::size_t ins_index;
std::function<void(std::size_t ins_index, std::vector<argument>)> f{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.ins_index, "ins_index"));
}
std::string name() const { return "capture"; }
shape compute_shape(std::vector<shape> inputs) const { return inputs.front(); }
argument compute(const shape&, std::vector<argument> args) const
{
if(f)
{
f(ins_index, args);
}
else
{
MIGRAPHX_THROW("CAPTURE: callback function is not callable!");
}
return args.front();
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
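The capture operator above is effectively an identity: compute_shape returns the input shape unchanged, and compute forwards args.front() after handing the runtime arguments to the user-supplied callback. A minimal usage sketch follows (illustration only, not part of this commit; it mirrors the tests further down and assumes the standard MIGraphX headers and the CPU target are available):

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/cpu/target.hpp>
#include <iostream>
#include <vector>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    std::vector<float> data = {1.0f, -2.0f, 3.0f, -4.0f};
    auto lit = p.add_literal(s, data);

    // The callback receives the index given at insertion time plus the runtime
    // arguments of the capture instruction; the instruction itself just forwards its input.
    auto cb = [](std::size_t idx, std::vector<migraphx::argument> args) {
        std::cout << "capture #" << idx << " saw " << args.front().get_shape().elements()
                  << " elements\n";
    };
    p.add_instruction(migraphx::op::capture{0, cb}, lit);

    p.compile(migraphx::cpu::target{});
    p.eval({}); // prints "capture #0 saw 4 elements"; the result equals the literal
    return 0;
}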
@@ -27,26 +27,34 @@ struct unary : op_name<Derived>
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
-        result.visit([&](auto output) {
-            args[0].visit([&](auto input) {
-                if(input.get_shape().packed())
-                {
-                    std::transform(input.begin(),
-                                   input.end(),
-                                   output.begin(),
-                                   static_cast<const Derived&>(*this).apply());
-                    return result;
-                }
-                shape_for_each(output.get_shape(), [&](const auto& idx) {
-                    output(idx.begin(), idx.end()) =
-                        static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
-                });
-                return result;
-            });
-        });
+        auto in_shape = args[0].get_shape();
+        if(in_shape.packed())
+        {
+            shape std_in_shape{in_shape.type(), in_shape.lens()};
+            shape std_out_shape{output_shape.type(), output_shape.lens()};
+            argument arg_in{std_in_shape, args[0].data()};
+            argument arg_out{std_out_shape, result.data()};
+            arg_out.visit([&](auto output) {
+                arg_in.visit([&](auto input) {
+                    std::transform(input.begin(),
+                                   input.end(),
+                                   output.begin(),
+                                   static_cast<const Derived&>(*this).apply());
+                });
+            });
+        }
+        else
+        {
+            result.visit([&](auto output) {
+                args[0].visit([&](auto input) {
+                    shape_for_each(output.get_shape(), [&](const auto& idx) {
+                        output(idx.begin(), idx.end()) = static_cast<const Derived&>(*this).apply()(
+                            input(idx.begin(), idx.end()));
+                    });
+                });
+            });
+        }
        return result;
    }
......
@@ -13,6 +13,7 @@
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/broadcast.hpp>
+#include <migraphx/op/capture.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
......
@@ -126,6 +126,9 @@ struct program
    friend bool operator==(const program& x, const program& y);
    friend bool operator!=(const program& x, const program& y) { return !(x == y); }

+    std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
+        std::make_shared<std::vector<std::pair<float, float>>>();
+
    private:
    void assign(const program& p);
......
@@ -15,6 +15,14 @@ struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog);

+// insert the capture operator for the inputs of each operator to be quantized
+// to int8
+void capture_arguments(program& prog,
+                       const std::vector<std::string>& ins_names,
+                       const std::function<void(std::size_t, std::vector<argument>)>& func);
+
+void capture_arguments(program& prog, const std::vector<std::string>& ins_names);
+
+void capture_arguments(program& prog);
+
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
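To make the new declarations concrete, here is a hedged sketch of the intended call pattern (not part of the diff; the name record is mine): capture_arguments wraps every input of the listed operators, and only "dot" and "convolution" are accepted, in a capture instruction, and the callback fires with the runtime arguments when the program is evaluated.

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <iostream>
#include <vector>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {3, 3}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(migraphx::op::dot{}, x, y);

    // Custom callback: report which capture index fired and how many values it saw.
    auto record = [](std::size_t idx, std::vector<migraphx::argument> args) {
        std::cout << "capture " << idx << ": " << args.front().get_shape().elements()
                  << " values\n";
    };
    migraphx::capture_arguments(p, {"dot"}, record);

    p.compile(migraphx::cpu::target{});
    migraphx::program::parameter_map params;
    params["x"] = migraphx::generate_argument(s);
    params["y"] = migraphx::generate_argument(s);
    p.eval(params); // the callback runs once for each captured dot input
    return 0;
}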
@@ -112,7 +112,8 @@ void program::assign(const program& p)
    {
        impl->instructions.clear();
    }
    impl->ctx = p.impl->ctx;
+    int8_quant_params = p.int8_quant_params;

    std::unordered_map<instruction_ref, instruction_ref> ins_map;
    for(auto ins : iterator_for(p))
......
@@ -156,6 +156,7 @@ PYBIND11_MODULE(migraphx, m)
    py::class_<migraphx::target>(m, "target");
    py::class_<migraphx::program>(m, "program")
+        .def("clone", [](migraphx::program& p) { return *(new migraphx::program(p)); })
        .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
        .def("get_shape", &migraphx::program::get_shape)
        .def("compile", [](migraphx::program& p, const migraphx::target& t) { p.compile(t); })

@@ -186,6 +187,11 @@
        migraphx::quantize(p, ins_names);
    });
    m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
+    m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
+        migraphx::capture_arguments(p, ins_names);
+    });
+    m.def("capture_arguments", [](migraphx::program& p) { migraphx::capture_arguments(p); });
#ifdef HAVE_GPU
    m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
@@ -3,32 +3,53 @@
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
+#include <migraphx/op/dot.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/add.hpp>
+#include <migraphx/op/quant_dot.hpp>
+#include <migraphx/op/capture.hpp>
+#include <migraphx/op/convolution.hpp>
+#include <migraphx/op/quant_convolution.hpp>
+#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <utility>
+#include <iomanip>
+#include <fstream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

-instruction_ref insert_fp16(program& prog,
-                            instruction_ref& ins,
-                            shape::type_t type,
-                            std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
+instruction_ref insert_quant_ins(program& prog,
+                                 instruction_ref& ins,
+                                 shape::type_t type,
+                                 std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
-    if(map_fp16.count(ins) > 0)
+    if(map_ins.count(ins) > 0)
    {
-        return map_fp16[ins];
+        return map_ins[ins];
    }

+    if(ins->name() == "undefined")
+    {
+        return ins;
+    }
+
    assert(ins->get_shape().type() == shape::float_type ||
-           ins->get_shape().type() == shape::double_type);
-    instruction_ref ins_fp16{};
-    ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
-    map_fp16[ins] = ins_fp16;
+           ins->get_shape().type() == shape::double_type ||
+           ins->get_shape().type() == shape::int32_type);
+    instruction_ref quant_ins{};
+    quant_ins = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
+    map_ins[ins] = quant_ins;

-    return ins_fp16;
+    return quant_ins;
}

+// This function converts the instructions specified in the input from double or
+// float to float16 by inserting a convert operator. The conversion can overflow,
+// but that is rare in deep learning, so the input is simply truncated to fp16.
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
    std::unordered_map<instruction_ref, instruction_ref> map_fp16;
@@ -59,7 +80,7 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
            }
            else
            {
-                input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
+                input_fp16 = insert_quant_ins(prog, input, shape::half_type, map_fp16);
            }
            converted_inputs.push_back(input_fp16);
        }
@@ -79,21 +100,13 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
        auto ins_shape = compute_shape(op, converted_inputs);
        if(ins_shape.type() != orig_type)
        {
-            // insert another convert instruction to convert it back
-            if(ins == std::prev(prog.end()))
-            {
-                prog.add_instruction(op::convert{orig_type}, ins);
-            }
-            else
-            {
-                // check the dead code case to avoid assert
-                bool output_empty = ins->outputs().empty();
-                auto ins_orig_type =
-                    prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
-                if(!output_empty)
-                {
-                    prog.replace_instruction(ins, ins_orig_type);
-                }
-            }
+            // check the dead code case to avoid assert
+            bool output_empty = ins->outputs().empty();
+            auto ins_orig_type =
+                prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
+            if(!output_empty)
+            {
+                prog.replace_instruction(ins, ins_orig_type);
+            }
        }
@@ -103,5 +116,80 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)

void quantize(program& prog) { quantize(prog, {"all"}); }
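As a quick illustration of the fp16 path described in the comment above (a sketch only, not part of this change): quantize() walks the program and brackets the selected instructions with convert operators, so a float add ends up computed in half precision and converted back.

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/quantization.hpp>
#include <iostream>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {4, 4}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(migraphx::op::add{}, x, y);

    migraphx::quantize(p); // same as quantize(p, {"all"})
    std::cout << p << std::endl; // the printout now shows convert instructions around the add
    return 0;
}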
// For each input argument of the operators to be quantized, insert a
// capture operator so the scale and shift can be computed from its runtime values
void capture_arguments(program& prog,
const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func)
{
size_t num_quant_params = 0;
// the int8 quantization only supports dot and convolution
std::vector<std::string> op_names = {"dot", "convolution"};
if(!std::all_of(ins_names.begin(), ins_names.end(), [&](auto name) {
return std::find(op_names.begin(), op_names.end(), name) != op_names.end();
}))
{
MIGRAPHX_THROW("CAPTURE_ARGUMENTS: input operator is not supported");
}
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(prog))
{
if(not contains(ins_names, ins->name()))
{
continue;
}
auto inputs = ins->inputs();
std::vector<instruction_ref> new_args;
for(auto input : inputs)
{
instruction_ref new_ins{};
if(ins_map.count(input) > 0)
{
new_ins = ins_map[input];
}
else
{
new_ins = prog.insert_instruction(
std::next(input), op::capture{num_quant_params++, func}, input);
ins_map[input] = new_ins;
}
new_args.push_back(new_ins);
}
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
}
// set one pair of parameters for each captured argument
prog.int8_quant_params->resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
}
void capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{
auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args) {
std::pair<float, float> param_pair{1.0f, 0.0f};
// scale and shift are needed only for the int8 type, and we do not
// consider shift, so it is set to 0
std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
param_pair.first = 127.0f / max_abs;
(*prog.int8_quant_params)[ins_index] = param_pair;
};
capture_arguments(prog, ins_names, calc_quant_params);
}
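Worked example of the default callback above (illustration only, with made-up numbers): the shift stays 0 and the scale maps the largest-magnitude captured value onto the int8 limit of 127.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    // Values captured for one argument (made-up for illustration).
    std::vector<float> vals = {-3.2f, 0.5f, 6.35f};
    float max_val = *std::max_element(vals.begin(), vals.end());          // 6.35
    float min_val = *std::min_element(vals.begin(), vals.end());          // -3.2
    float max_abs = std::max(std::fabs(max_val), std::fabs(min_val));     // 6.35
    float scale   = 127.0f / max_abs;                                     // 20, since 6.35 * 20 = 127
    std::cout << "scale = " << scale << ", shift = 0\n";
    return 0;
}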
void capture_arguments(program& prog)
{
std::vector<std::string> ins_names = {"dot", "convolution"};
capture_arguments(prog, ins_names);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -261,7 +261,8 @@ struct cpu_quant_convolution
                            const auto in_ch = group_id * wei_c + k;
                            if(in_x >= 0 && in_x < in_h && in_y >= 0 && in_y < in_w)
                            {
-                                acc += input(o, in_ch, in_x, in_y) * weights(w, k, x, y);
+                                acc += static_cast<int32_t>(input(o, in_ch, in_x, in_y)) *
+                                       weights(w, k, x, y);
                            }
                        });
                        output(o, w, i, j) = acc;

@@ -576,8 +577,7 @@ struct cpu_quant_gemm
        }

        // 2 input arguments
-        int32_t beta = 0;
-        migemm(result, arg_0, arg_1, op.alpha, beta);
+        migemm(result, arg_0, arg_1, op.alpha, int32_t{0});

        return result;
    }
......
@@ -82,6 +82,7 @@ add_library(migraphx_gpu
    elu.cpp
    pad.cpp
    gather.cpp
+    convert.cpp
    lrn.cpp
    schedule_model.cpp
    adjust_allocation.cpp
......
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_convert::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
argument hip_convert::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::convert(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -69,6 +69,8 @@ void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument
    });
}

+void sync_stream(hipStream_t stream) { hipStreamSynchronize(stream); }
+
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
@@ -3,8 +3,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp>
-#include <migraphx/gpu/oper.hpp>
-#include <migraphx/gpu/device/convert.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

@@ -12,7 +10,7 @@ namespace gpu {
struct context;

-struct hip_convert : unary_device<hip_convert, device::convert>
+struct hip_convert
{
    op::convert op;

@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert>
        return migraphx::reflect(self.op, f);
    }

-    hip_convert(op::convert oper) : op(oper) {}
-
-    shape compute_shape(std::vector<shape> inputs) const
-    {
-        inputs.pop_back();
-        check_shapes{inputs}.packed();
-        return op.compute_shape(inputs);
-    }
+    std::string name() const { return "gpu::convert"; }
+    shape compute_shape(std::vector<shape> inputs) const;
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
};
......
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
#include <fstream>
#include <iomanip>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp>
+#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"

@@ -2028,4 +2029,39 @@ TEST_CASE(sqdiff_test)
    EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(op_capture)
{
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
auto p1 = p.add_literal(s1, d1);
auto p2 = p.add_literal(s1, d1);
auto pb = p.add_literal(s2, d2);
auto pc = p.add_literal(s2, d2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p;
migraphx::capture_arguments(capture_p);
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
auto cap_res = capture_p.eval({});
auto res = p.eval({});
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -3816,4 +3816,21 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
    };
};
struct test_convert : verify_program<test_convert>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {8, 24}};
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -202,4 +202,55 @@ TEST_CASE(literal_add)
    }
}
TEST_CASE(op_capture)
{
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
};
auto create_program_float = [] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
return p;
};
auto create_program_op = [&] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto opb = p.insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
auto opc = p.insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
auto opa = p.add_instruction(migraphx::op::capture{0, test_func}, pa);
auto ps = p.add_instruction(migraphx::op::dot{}, opa, opb, opc);
auto ops = p.add_instruction(migraphx::op::capture{3, test_func}, ps);
p.add_instruction(migraphx::op::dot{}, opa, ops);
return p;
};
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::capture_arguments(p);
EXPECT(p == op_capture_p);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }