"...utils/git@developer.sourcefind.cn:guobj/qwen_lmdeploy.git" did not exist on "1f88baa5b7a9dde22b11200fd530fe1059e1facb"
Commit 84a3f56e authored by Shucai Xiao, committed by mvermeulen
Browse files

bug_fix_for_gemm_copy (#378)

* fixed a bug related to removing gemm copy

* clang format

* fix review comments

* clang format

* fix unit test failure

* fix review comments

* clang format
parent 9b55685c
...@@ -198,25 +198,26 @@ struct miopen_apply ...@@ -198,25 +198,26 @@ struct miopen_apply
auto&& op = any_cast<Op>(ins->get_operator()); auto&& op = any_cast<Op>(ins->get_operator());
auto beta = op.beta; auto beta = op.beta;
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
if((refs.size() == 2) or (refs.size() == 3 and refs.back()->outputs().size() > 1) or if(refs.size() == 2)
(ins == last))
{ {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
if(refs.size() == 2) beta = 0;
{ refs.push_back(output);
beta = 0; }
refs.push_back(output); else
} {
else auto c_alias = instruction::get_output_alias(refs.back());
if(ins == last or refs.back()->outputs().size() > 1 or c_alias->inputs().empty())
{ {
auto output = insert_allocation(ins, ins->get_shape());
auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output); auto copy_out = prog->insert_instruction(ins, hip_copy{}, refs.back(), output);
refs.back() = copy_out; refs.back() = copy_out;
refs.push_back(copy_out); refs.push_back(copy_out);
} }
} else
else {
{ refs.push_back(refs.back());
refs.push_back(refs.back()); }
} }
return prog->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs); return prog->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs);
......
...@@ -138,6 +138,7 @@ migraphx::argument run_gpu(migraphx::program& p) ...@@ -138,6 +138,7 @@ migraphx::argument run_gpu(migraphx::program& p)
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
p.dry_run(m); p.dry_run(m);
EXPECT(is_shared(ctx, p.get_context())); EXPECT(is_shared(ctx, p.get_context()));
p.eval(m);
return migraphx::gpu::from_gpu(p.eval(m)); return migraphx::gpu::from_gpu(p.eval(m));
} }
...@@ -1052,6 +1053,24 @@ struct test_gemm : verify_program<test_gemm> ...@@ -1052,6 +1053,24 @@ struct test_gemm : verify_program<test_gemm>
} }
}; };
// Program builder derived from verify_program: a three-input dot whose third
// operand is the parameter "c", with the dot result used as both operands of
// a following add.
struct test_gemm_copy : verify_program<test_gemm_copy>
{
    // Builds and returns the program described above.
    migraphx::program create_program() const
    {
        migraphx::program prog;

        // Operand shapes: a is 2x16, b is 16x8, c is 2x8 (all float).
        migraphx::shape shape_a{migraphx::shape::float_type, {2, 16}};
        migraphx::shape shape_b{migraphx::shape::float_type, {16, 8}};
        migraphx::shape shape_c{migraphx::shape::float_type, {2, 8}};

        // Parameters are registered in the order a, b, c.
        auto param_a = prog.add_parameter("a", shape_a);
        auto param_b = prog.add_parameter("b", shape_b);
        auto param_c = prog.add_parameter("c", shape_c);

        // The dot output is consumed twice by the add instruction.
        auto dot_result = prog.add_instruction(migraphx::op::dot{}, param_a, param_b, param_c);
        prog.add_instruction(migraphx::op::add{}, dot_result, dot_result);

        return prog;
    }
};
struct test_gemm_ex : verify_program<test_gemm_ex> struct test_gemm_ex : verify_program<test_gemm_ex>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
...@@ -15,55 +15,22 @@ ...@@ -15,55 +15,22 @@
#include "test.hpp" #include "test.hpp"
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
TEST_CASE(target_copy) TEST_CASE(gpu_target_copy)
{ {
auto run_prog = [](migraphx::program p, migraphx::target gpu_t = migraphx::gpu::target{};
const migraphx::target& t, migraphx::target cpu_t = migraphx::cpu::target{};
migraphx::program::parameter_map& m_in, migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
std::vector<float>& res) {
p.compile(t);
migraphx::program::parameter_map m;
for(auto&& x : p.get_parameter_shapes())
{
if(m_in.count(x.first) > 0)
{
m[x.first] = t.copy_to(m_in[x.first]);
}
else
{
m[x.first] = t.allocate(x.second);
}
}
auto result = t.copy_from(p.eval(m)); auto cpu_arg_orig = migraphx::generate_argument(s, 0x123456L);
result.visit([&](auto v) { res.assign(v.begin(), v.end()); }); auto gpu_arg = gpu_t.copy_to(cpu_arg_orig);
}; auto cpu_arg_final = gpu_t.copy_from(gpu_arg);
auto create_program = [] { std::vector<int8_t> val_orig;
migraphx::program p; cpu_arg_orig.visit([&](auto v) { val_orig.assign(v.begin(), v.end()); });
migraphx::shape s{migraphx::shape::float_type, {3, 3}}; std::vector<int8_t> val_final;
auto p1 = p.add_parameter("x", s); cpu_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
auto p2 = p.add_parameter("y", s);
p.add_instruction(migraphx::op::add{}, p1, p2);
return p;
};
{ EXPECT(migraphx::verify_range(val_orig, val_final));
auto p = create_program();
migraphx::program::parameter_map m;
migraphx::shape s{migraphx::shape::float_type, {3, 3}};
m["x"] = migraphx::generate_argument(s);
std::vector<float> cpu_result;
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, m, cpu_result);
std::vector<float> gpu_result;
migraphx::target gpu_t = migraphx::gpu::target{};
run_prog(p, gpu_t, m, gpu_result);
EXPECT(migraphx::verify_range(cpu_result, gpu_result));
}
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -110,8 +77,10 @@ TEST_CASE(int8_quantization) ...@@ -110,8 +77,10 @@ TEST_CASE(int8_quantization)
auto p = create_program(); auto p = create_program();
migraphx::program::parameter_map m; migraphx::program::parameter_map m;
migraphx::shape sa{migraphx::shape::float_type, {2, 16}}; migraphx::shape sa{migraphx::shape::float_type, {2, 16}};
migraphx::shape sb{migraphx::shape::float_type, {16, 8}};
migraphx::shape sc{migraphx::shape::float_type, {2, 8}}; migraphx::shape sc{migraphx::shape::float_type, {2, 8}};
m["a"] = migraphx::generate_argument(sa); m["a"] = migraphx::generate_argument(sa);
m["b"] = migraphx::generate_argument(sb);
m["c"] = migraphx::generate_argument(sc); m["c"] = migraphx::generate_argument(sc);
std::vector<float> cpu_result; std::vector<float> cpu_result;
migraphx::target cpu_t = migraphx::cpu::target{}; migraphx::target cpu_t = migraphx::cpu::target{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment