Commit 5d240fa4 authored by Umang Yadav's avatar Umang Yadav
Browse files

quant dot

parent 7772a428
...@@ -195,7 +195,7 @@ struct gemm_impl ...@@ -195,7 +195,7 @@ struct gemm_impl
ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc; ldd = is_3inputs ? input_shapes[3].strides()[dim_0] : ldc;
arg_type = get_type(input_shapes[0].type()); arg_type = get_type(input_shapes[0].type());
output_type = arg_type; output_type = get_type(input_shapes[2].type());
if(output_type == rocblas_datatype_i8_r) if(output_type == rocblas_datatype_i8_r)
{ {
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
......
...@@ -25,26 +25,31 @@ ...@@ -25,26 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <migraphx::shape::type_t DType, migraphx::shape::type_t CType> template <typename DType, typename CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>> struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{DType, {3, 2, 2, 8}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{DType, {3, 2, 8, 7}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{CType, {3, 2, 2, 7}};
migraphx::shape m1_shape{dtype, {3, 2, 2, 8}};
migraphx::shape m2_shape{dtype, {3, 2, 8, 7}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct batch_quant_dot_2<migraphx::shape::int8_type, migraphx::shape::int32_type>; template struct batch_quant_dot_2<int8_t, int32_t>;
template struct batch_quant_dot_2<migraphx::shape::fp8e4m3fnuz_type, migraphx::shape::float_type>; template struct batch_quant_dot_2<migraphx::fp8::fp8e4m3fnuz, float>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment