Unverified Commit 38287064 authored by Shucai Xiao, committed by GitHub

Turn on gemm unit tests (#997)



This PR turns on a few gemm unit tests with the int8 input data type. Before ROCm 4.4, the rocBLAS implementation required matrices with int8 inputs to have sizes of no less than 4. Because of this limitation, we had turned off a few gemm unit tests that use the int8 input data type.
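For reference, the re-enabled tests build programs roughly like the following minimal sketch (not the exact test source; the function name and the {2, 3} x {3, 2} shapes are illustrative of the small matrices that the old rocBLAS restriction would have rejected):

```cpp
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

// Sketch of the kind of int8 gemm the re-enabled tests verify: quant_dot
// multiplies int8 matrices and accumulates into an int32 result. Per the
// note above, before ROCm 4.4 rocBLAS required int8 matrix sizes of at
// least 4, so small shapes like these could not run on the GPU.
migraphx::program make_int8_gemm()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto a   = mm->add_parameter("a", {migraphx::shape::int8_type, {2, 3}});
    auto b   = mm->add_parameter("b", {migraphx::shape::int8_type, {3, 2}});
    mm->add_instruction(migraphx::make_op("quant_dot"), a, b); // int32 output
    return p;
}
```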

This limitation was removed in ROCm 4.4, so after we upgrade to ROCm 4.5 we can turn these unit tests back on. We also change the test_conv_bn_add unit test to add instructions to a module instead of to the program, following the pattern shown in the sketch below and in the diff.
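A minimal sketch of the new idiom (the function name is illustrative; the calls mirror the updated test below): instructions and parameters now go through the program's main module rather than through the program object itself.

```cpp
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

// New idiom used by test_conv_bn_add: fetch the main module from the
// program, then add parameters and instructions to the module.
migraphx::program make_relu()
{
    migraphx::program p;
    auto* mm = p.get_main_module(); // modules own the instructions now
    auto x   = mm->add_parameter("x", {migraphx::shape::float_type, {1, 64, 56, 56}});
    mm->add_instruction(migraphx::make_op("relu"), x);
    return p;
}
```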
Co-authored-by: kahmed10 <15948690+kahmed10@users.noreply.github.com>
parent fb39e5e4
```diff
@@ -45,14 +45,6 @@ int main(int argc, const char* argv[])
     run_verify rv;
     rv.add_validation_for("gpu", &validate_gpu);
     rv.disable_test_for("cpu", {"test_if_lp", "test_if_param", "test_if_literal"});
-    rv.disable_test_for("gpu",
-                        {"batch_quant_dot_2",
-                         "batch_quant_dot_3",
-                         "batch_quant_dot_5",
-                         "quant_dot_3args_1",
-                         "quant_dot_3args_2",
-                         "quant_dot_3args_3",
-                         "quant_dot_3args_4",
-                         "quant_dot_3args_5"});
+    rv.disable_test_for("gpu", {"test_conv_bn_add"});
     rv.run(argc, argv);
 }
```
```diff
@@ -2,44 +2,44 @@
 #include "verify_program.hpp"
 #include <migraphx/program.hpp>
 #include <migraphx/generate.hpp>
-#include <migraphx/op/common.hpp>
-
-// struct test_conv_bn_add : verify_program<test_conv_bn_add>
-// {
-//     static migraphx::instruction_ref add_bn(migraphx::program& p,
-//                                             migraphx::instruction_ref x,
-//                                             std::size_t channels,
-//                                             std::size_t seed = 1)
-//     {
-//         migraphx::shape vars{migraphx::shape::float_type, {channels}};
-//         auto scale = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 +
-//         seed))); auto bias = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 2
-//         + seed))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars,
-//         3 + seed))); auto variance =
-//         mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed))); return
-//         mm->add_instruction(
-//             migraphx::op::batch_norm_inference{}, x, scale, bias, mean, variance);
-//     }
-//     migraphx::program create_program() const
-//     {
-//         migraphx::program p;
-//         std::size_t ichannels = 64;
-//         std::size_t ochannels = 256;
-//         auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56,
-//         56}}); auto w = mm->add_literal(migraphx::generate_literal(
-//             {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
-//         auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56,
-//         56}}); auto v = mm->add_literal(migraphx::generate_literal(
-//             {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
-//         auto relu1 = mm->add_instruction(migraphx::op::relu{}, x);
-//         auto conv1 = mm->add_instruction(migraphx::op::convolution{}, relu1, w);
-//         auto bn1 = add_bn(p, conv1, ochannels, 1);
-//         auto relu2 = mm->add_instruction(migraphx::op::relu{}, y);
-//         auto conv2 = mm->add_instruction(migraphx::op::convolution{}, relu2, v);
-//         auto bn2 = add_bn(p, conv2, ochannels, 1);
-//         auto sum = mm->add_instruction(migraphx::op::add{}, bn1, bn2);
-//         mm->add_instruction(migraphx::op::relu{}, sum);
-//         return p;
-//     }
-// };
+#include <migraphx/make_op.hpp>
+
+struct test_conv_bn_add : verify_program<test_conv_bn_add>
+{
+    static migraphx::instruction_ref add_bn(migraphx::module& m,
+                                            migraphx::instruction_ref x,
+                                            std::size_t channels,
+                                            std::size_t seed = 1)
+    {
+        migraphx::shape vars{migraphx::shape::float_type, {channels}};
+        auto scale    = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + seed)));
+        auto bias     = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + seed)));
+        auto mean     = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + seed)));
+        auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + seed)));
+        return m.add_instruction(
+            migraphx::make_op("batch_norm_inference"), x, scale, bias, mean, variance);
+    }
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        auto* mm              = p.get_main_module();
+        std::size_t ichannels = 64;
+        std::size_t ochannels = 256;
+        auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
+        auto w = mm->add_literal(migraphx::generate_literal(
+            {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
+        auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
+        auto v = mm->add_literal(migraphx::generate_literal(
+            {migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
+        auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
+        auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
+        auto bn1   = add_bn(*mm, conv1, ochannels, 1);
+        auto relu2 = mm->add_instruction(migraphx::make_op("relu"), y);
+        auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), relu2, v);
+        auto bn2   = add_bn(*mm, conv2, ochannels, 1);
+        auto sum   = mm->add_instruction(migraphx::make_op("add"), bn1, bn2);
+        mm->add_instruction(migraphx::make_op("relu"), sum);
+        return p;
+    }
+};
```