"...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "60fd7a8fb6eaf67d4d3838619623d029bf93ebfe"
Commit a392d84c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bugs_for_bert

parents 5a264889 0628e570
...@@ -134,8 +134,6 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -134,8 +134,6 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.op.group > 1) if(conv.op.group > 1)
return false; return false;
if(conv.op.padding_mode != op::padding_mode_t::default_)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd) if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
return false; return false;
auto op = conv.op; auto op = conv.op;
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp> #include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -12,7 +10,7 @@ namespace gpu { ...@@ -12,7 +10,7 @@ namespace gpu {
struct context; struct context;
struct hip_convert : unary_device<hip_convert, device::convert> struct hip_convert
{ {
op::convert op; op::convert op;
...@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert> ...@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert>
return migraphx::reflect(self.op, f); return migraphx::reflect(self.op, f);
} }
hip_convert(op::convert oper) : op(oper) {} std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
inputs.pop_back(); return shapes.size() - 1;
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
} }
}; };
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <fstream>
#include <iomanip>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
...@@ -46,6 +47,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -46,6 +47,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
//common_subexpression_elimination{}, //common_subexpression_elimination{},
//dead_code_elimination{}, //dead_code_elimination{},
......
...@@ -574,23 +574,18 @@ struct tf_parser ...@@ -574,23 +574,18 @@ struct tf_parser
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; auto lens = args[0]->get_shape().lens();
// check if conditions for GlobalAvgPool are met auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector<int64_t>(), lens.size());
auto lens = args[0]->get_shape().lens();
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector(), lens.size());
if(axes == hw_axes and lens.size() == 4) if(keep_dims)
{ {
op::pooling op{"average"}; return prog.add_instruction(op::reduce_mean{axes}, args[0]);
op.lengths[0] = lens[2]; }
op.lengths[1] = lens[3]; else
auto l0 = prog.add_instruction(op, args.front()); {
if(keep_dims) auto ins = prog.add_instruction(op::reduce_mean{axes}, args[0]);
return l0; return prog.add_instruction(op::squeeze{axes}, ins);
return prog.add_instruction(
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
} }
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
instruction_ref parse_pack(const std::string&, instruction_ref parse_pack(const std::string&,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
...@@ -2028,4 +2029,39 @@ TEST_CASE(sqdiff_test) ...@@ -2028,4 +2029,39 @@ TEST_CASE(sqdiff_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(op_capture)
{
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
std::vector<float> d1(s1.elements());
std::vector<float> d2(s2.elements());
std::iota(d1.begin(), d1.end(), 0.0f);
std::iota(d2.begin(), d2.end(), 0.0f);
auto p1 = p.add_literal(s1, d1);
auto p2 = p.add_literal(s1, d1);
auto pb = p.add_literal(s2, d2);
auto pc = p.add_literal(s2, d2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
migraphx::program capture_p = p;
migraphx::capture_arguments(capture_p);
p.compile(migraphx::cpu::target{});
capture_p.compile(migraphx::cpu::target{});
auto cap_res = capture_p.eval({});
auto res = p.eval({});
std::vector<float> vec;
std::vector<float> cap_vec;
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(vec, cap_vec));
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3816,4 +3816,21 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half> ...@@ -3816,4 +3816,21 @@ struct test_reduce_mean_half : verify_program<test_reduce_mean_half>
}; };
}; };
struct test_convert : verify_program<test_convert>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {8, 24}};
migraphx::shape sb{migraphx::shape::float_type, {24, 6}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto ia = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pa);
auto ib = p.add_instruction(migraphx::op::convert{migraphx::shape::int8_type}, pb);
p.add_instruction(migraphx::op::quant_dot{}, ia, ib);
return p;
};
};
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -257,8 +257,7 @@ TEST_CASE(mean_test) ...@@ -257,8 +257,7 @@ TEST_CASE(mean_test)
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l); p.add_literal(l);
p.add_literal(l); p.add_literal(l);
migraphx::op::pooling op; migraphx::op::reduce_mean op{{2, 3}};
op.lengths = {16, 16};
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
...@@ -272,9 +271,8 @@ TEST_CASE(mean_test_nhwc) ...@@ -272,9 +271,8 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p; migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling op; migraphx::op::reduce_mean op{{2, 3}};
op.lengths = {16, 16}; auto l3 = p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = optimize_tf("mean_test_nhwc.pb", true); auto prog = optimize_tf("mean_test_nhwc.pb", true);
......
...@@ -202,4 +202,55 @@ TEST_CASE(literal_add) ...@@ -202,4 +202,55 @@ TEST_CASE(literal_add)
} }
} }
TEST_CASE(op_capture)
{
auto test_func = [&](std::size_t ins_index, const std::vector<migraphx::argument>& args) {
(void)ins_index;
(void)args;
};
auto create_program_float = [] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto ps = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::dot{}, pa, ps);
return p;
};
auto create_program_op = [&] {
migraphx::program p;
migraphx::shape s1{migraphx::shape::float_type, {3, 3}};
migraphx::shape s2{migraphx::shape::float_type, {3, 6}};
auto p1 = p.add_parameter("x", s1);
auto p2 = p.add_parameter("y", s1);
auto pb = p.add_parameter("b", s2);
auto pc = p.add_parameter("c", s2);
auto pa = p.add_instruction(migraphx::op::add{}, p1, p2);
auto opb = p.insert_instruction(std::next(pb), migraphx::op::capture{1, test_func}, pb);
auto opc = p.insert_instruction(std::next(pc), migraphx::op::capture{2, test_func}, pc);
auto opa = p.add_instruction(migraphx::op::capture{0, test_func}, pa);
auto ps = p.add_instruction(migraphx::op::dot{}, opa, opb, opc);
auto ops = p.add_instruction(migraphx::op::capture{3, test_func}, ps);
p.add_instruction(migraphx::op::dot{}, opa, ops);
return p;
};
{
auto p = create_program_float();
auto op_capture_p = create_program_op();
migraphx::capture_arguments(p);
EXPECT(p == op_capture_p);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment