Commit 12258d8f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

added more gemm_impl tests; fix wrong argument in validate()

parent a82ea1d6
......@@ -220,7 +220,6 @@ struct gemm_impl
d_stride = is_3inputs ? get_batch_stride(input_shapes[3]) : c_stride;
num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and b_stride == 0))
{
// If the batch dimension of B is broadcasted, then we can
......@@ -291,7 +290,7 @@ struct gemm_impl
auto common_args = create_gemm_ex_args_common(ctx, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_standard,
rocblas_gemm_algo_solution_index,
solution_idx,
rocblas_gemm_flags_check_solution_index);
}
......
......@@ -79,6 +79,8 @@ struct rocblas_gemm
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
// When input shapes are A, B, C the GEMM equation is C  =  α AB+ β C where α, β are
// scalars
check_shapes{in_shapes, *this}.has(2, 3);
blas_shape(inputs[0]);
blas_shape(inputs[1]);
......
......@@ -94,7 +94,132 @@ TEST_CASE(gemm_tune_with_rocblas)
#endif
}
// TODO: make tests that run both strided and not-strided GEMMs. Also, try tuning with an invalid
// start value.
// GEMM tuning of a strided-batch matrix; invokes rocblas_gemm_strided_batched_ex
TEST_CASE(gemm_tune_strided)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {4, 2, 2}};
migraphx::shape sb{migraphx::shape::float_type, {4, 2, 2}};
migraphx::shape s_output{migraphx::shape::float_type, {4, 2, 2}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
auto output = mm->add_parameter("out", s_output);
auto gemm_oper = migraphx::make_op("gpu::gemm", {{"beta", 2}});
mm->add_instruction(gemm_oper, a, b, output);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
auto gemmv = gemm_op.to_value();
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef ROCBLAS_BETA_FEATURES_API
EXPECT(0 != solution_idx.to<std::size_t>());
#else
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
// GEMM tuning of a strided-batch matrix; created by lowering
TEST_CASE(gemm_tune_strided_lowered)
{
migraphx::program p;
auto* mm = p.get_main_module();
// At time of writing this test, gemm_impl considers a shape is strided if it has
// at least three dimensions and the 3rd-to-last is nonzero, invoking
// rocblas_gemm_strided_batched_ex. Also, DOT operator requires all dimensions except the last
// two to be equal.
migraphx::shape sa{migraphx::shape::float_type, {4, 2, 5}};
migraphx::shape sb{migraphx::shape::float_type, {4, 5, 3}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
migraphx::operation dot_op = migraphx::make_op("dot");
mm->add_instruction(dot_op, a, b);
// lowering adds gemm implementation for dot operator
run_lowering(p);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef ROCBLAS_BETA_FEATURES_API
EXPECT(0 != solution_idx.to<std::size_t>());
#else
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
TEST_CASE(gemm_tune_invalid_sol_index)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sa{migraphx::shape::float_type, {4, 2}};
migraphx::shape sb{migraphx::shape::float_type, {2, 3}};
migraphx::shape s_output{migraphx::shape::float_type, {4, 3}};
auto a = mm->add_parameter("a", sa);
auto b = mm->add_parameter("b", sb);
auto output = mm->add_parameter("out", s_output);
auto gemm_oper = migraphx::make_op("gpu::gemm", {{"solution_idx", 987654321}});
mm->add_instruction(gemm_oper, a, b, output);
migraphx::target gpu_t = migraphx::gpu::target{};
migraphx::compile_options options;
options.exhaustive_tune = true;
p.compile(gpu_t, options);
migraphx::value solution_idx(0);
for(auto ins : iterator_for(*p.get_main_module()))
{
if(ins->name() == "gpu::gemm")
{
auto gemm_op = migraphx::get_operation(ins);
auto gemmv = gemm_op.to_value();
// given invalid starting index, should return default 0
solution_idx = gemm_op.to_value()["solution_idx"];
break;
}
}
#ifdef ROCBLAS_BETA_FEATURES_API
EXPECT(0 == solution_idx.to<std::size_t>());
#else
EXPECT(0 != solution_idx.to<std::size_t>());
#endif
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment