Commit 84a3f56e authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

bug_fix_for_gemm_copy (#378)

* fixed a bug related to removing gemm copy

* clang format

* fix review comments

* clang format

* fix unit test failure

* fix review comments

* clang format
parent 9b55685c
......@@ -198,26 +198,27 @@ struct miopen_apply
auto&& op = any_cast<Op>(ins->get_operator());
auto beta = op.beta;
std::vector<instruction_ref> refs = ins->inputs();
if((refs.size() == 2) or (refs.size() == 3 and refs.back()->outputs().size() > 1) or
(ins == last))
{
auto output = insert_allocation(ins, ins->get_shape());
if(refs.size() == 2)
{
auto output = insert_allocation(ins, ins->get_shape());
beta = 0;
refs.push_back(output);
}
else
{
auto c_alias = instruction::get_output_alias(refs.back());
if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty())
{
auto output = insert_allocation(ins, ins->get_shape());
auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output);
refs.back() = copy_out;
refs.push_back(copy_out);
}
}
else
{
refs.push_back(refs.back());
}
}
return prog->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs);
});
......
......@@ -138,6 +138,7 @@ migraphx::argument run_gpu(migraphx::program& p)
EXPECT(is_shared(ctx, p.get_context()));
p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context()));
p.eval(m);
return migraphx::gpu::from_gpu(p.eval(m));
}
......@@ -1052,6 +1053,24 @@ struct test_gemm : verify_program<test_gemm>
}
};
struct test_gemm_copy : verify_program<test_gemm_copy>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
auto pa = p.add_parameter("a", sa);
auto pb = p.add_parameter("b", sb);
auto pc = p.add_parameter("c", sc);
auto dr = p.add_instruction(migraphx::op::dot{}, pa, pb, pc);
p.add_instruction(migraphx::op::add{}, dr, dr);
return p;
}
};
struct test_gemm_ex : verify_program<test_gemm_ex>
{
migraphx::program create_program() const
......
......@@ -15,55 +15,22 @@
#include "test.hpp"
#include <migraphx/half.hpp>
TEST_CASE(target_copy)
TEST_CASE(gpu_target_copy)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
migraphx::program::parameter_map& m_in,
std::vector<float>& res) {
p.compile(t);
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m));
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};
auto create_program = [] {
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
auto p1 = p.add_parameter("x", s);
auto p2 = p.add_parameter("y", s);
p.add_instruction(migraphx::op::add{}, p1, p2);
return p;
};
{
auto p = create_program();
migraphx::program::parameter_map m;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
m["x"] = migraphx::generate_argument(s);
std::vector<float> cpu_result;
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, m, cpu_result);
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{};
run_prog(p, gpu_t, m, gpu_result);
auto cpu_arg_orig = migraphx::generate_argument(s, 0x123456L);
auto gpu_arg = gpu_t.copy_to(cpu_arg_orig);
auto cpu_arg_final = gpu_t.copy_from(gpu_arg);
EXPECT(migraphx::verify_range(cpu_result, gpu_result));
}
std::vector<int8_t> val_orig;
cpu_arg_orig.visit([&](auto v) { val_orig.assign(v.begin(), v.end()); });
std::vector<int8_t> val_final;
cpu_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify_range(val_orig, val_final));
}
TEST_CASE(int8_quantization)
......@@ -110,8 +77,10 @@ TEST_CASE(int8_quantization)
auto p = create_program();
migraphx::program::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb);
m["c"] = migraphx::generate_argument(sc);
std::vector<float> cpu_result;
migraphx::target cpu_t = migraphx::cpu::target{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment