Commit cac6c759 authored by Paul
Browse files

Merge

parents 4bde67c4 a60bdb67
...@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test): ...@@ -118,9 +118,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test.exclude(r'test_convtranspose_1d_cpu') backend_test.exclude(r'test_convtranspose_1d_cpu')
backend_test.exclude(r'test_det_2d_cpu') backend_test.exclude(r'test_det_2d_cpu')
backend_test.exclude(r'test_det_nd_cpu') backend_test.exclude(r'test_det_nd_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_max_adjusted_cpu')
backend_test.exclude(r'test_dynamicquantizelinear_min_adjusted_cpu')
backend_test.exclude(r'test_edge_pad_cpu') backend_test.exclude(r'test_edge_pad_cpu')
backend_test.exclude(r'test_einsum_batch_diagonal_cpu') backend_test.exclude(r'test_einsum_batch_diagonal_cpu')
backend_test.exclude(r'test_einsum_batch_matmul_cpu') backend_test.exclude(r'test_einsum_batch_matmul_cpu')
......
...@@ -30,7 +30,7 @@ ...@@ -30,7 +30,7 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/quantization.hpp> #include <migraphx/quantization.hpp>
#include <migraphx/quantize_int8.hpp> #include <migraphx/quantize_8bits.hpp>
#include <migraphx/quantize_fp16.hpp> #include <migraphx/quantize_fp16.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
...@@ -654,7 +654,8 @@ TEST_CASE(dot_float) ...@@ -654,7 +654,8 @@ TEST_CASE(dot_float)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
EXPECT(p == qp); EXPECT(p == qp);
...@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args) ...@@ -748,7 +749,8 @@ TEST_CASE(dot_double_2args)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg) ...@@ -821,7 +823,8 @@ TEST_CASE(dot_half_1arg)
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"dot"}, {}, &param_index}});
migraphx::run_passes( migraphx::run_passes(
p, p,
{migraphx::quantize_int8_pass{{"dot"}, quant_params}, migraphx::dead_code_elimination{}}); {migraphx::quantize_8bits_pass{migraphx::shape::int8_type, quant_params},
migraphx::dead_code_elimination{}});
EXPECT(p == create_int8_quantized_prog()); EXPECT(p == create_int8_quantized_prog());
optimize_prog_int8(p); optimize_prog_int8(p);
...@@ -876,7 +879,9 @@ TEST_CASE(conv_float) ...@@ -876,7 +879,9 @@ TEST_CASE(conv_float)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw) ...@@ -901,7 +906,9 @@ TEST_CASE(conv_float_throw)
auto p = create_program(); auto p = create_program();
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
test::throws([&] { test::throws([&] {
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"add"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
}); });
} }
...@@ -952,7 +959,9 @@ TEST_CASE(conv_half) ...@@ -952,7 +959,9 @@ TEST_CASE(conv_half)
const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}}; const std::vector<std::pair<float, float>>& quant_params{{0.1f, 0.0f}, {0.1f, 0.0f}};
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}}); migraphx::run_passes(p, {migraphx::capture_arguments_pass{{"convolution"}, {}, &param_index}});
migraphx::run_passes(p, {migraphx::quantize_int8_pass{{"convolution"}, quant_params}}); migraphx::run_passes(p,
{migraphx::quantize_8bits_pass{
migraphx::shape::type_t::int8_type, quant_params}});
optimize_prog_int8(p); optimize_prog_int8(p);
auto qp = create_int8_quantized_prog(); auto qp = create_int8_quantized_prog();
...@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph) ...@@ -1231,7 +1240,9 @@ TEST_CASE(int8_subgraph)
std::size_t param_index = 0; std::size_t param_index = 0;
migraphx::run_passes( migraphx::run_passes(
p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}}); p1, {migraphx::capture_arguments_pass{{"convolution", "dot"}, {}, &param_index}});
migraphx::run_passes(p1, {migraphx::quantize_int8_pass{{"convolution", "dot"}, quant_params}}); migraphx::run_passes(p1,
{migraphx::quantize_8bits_pass{migraphx::shape::type_t::int8_type,
quant_params}});
optimize_prog_int8(p1); optimize_prog_int8(p1);
auto p2 = create_int8_program(); auto p2 = create_int8_program();
......
...@@ -24,19 +24,23 @@ ...@@ -24,19 +24,23 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> template <typename DType, typename CType>
struct batch_quant_dot_1 : verify_program<batch_quant_dot_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}}; migraphx::shape m1_shape{dtype, {3, 2, 8, 2}};
migraphx::shape m2_shape{dtype, {3, 2, 7, 8}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction( auto tl1 = mm->add_instruction(
...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1> ...@@ -45,7 +49,11 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
auto tl2 = mm->add_instruction( auto tl2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2); migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct batch_quant_dot_1<int8_t, int32_t>;
template struct batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -25,23 +25,31 @@ ...@@ -25,23 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2> template <typename DType, typename CType>
struct batch_quant_dot_2 : verify_program<batch_quant_dot_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}}; auto dtype = migraphx::shape::get_type<DType>{};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}}; auto ctype = migraphx::shape::get_type<CType>{};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
migraphx::shape m1_shape{dtype, {3, 2, 2, 8}};
migraphx::shape m2_shape{dtype, {3, 2, 8, 7}};
migraphx::shape m3_shape{ctype, {3, 2, 2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct batch_quant_dot_2<int8_t, int32_t>;
template struct batch_quant_dot_2<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> template <migraphx::shape::type_t DType>
struct batch_quant_dot_3 : verify_program<batch_quant_dot_3<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 6}}; migraphx::shape m1_shape{DType, {3, 2, 2, 6}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 6, 7}}; migraphx::shape m2_shape{DType, {3, 2, 6, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3> ...@@ -42,3 +43,5 @@ struct batch_quant_dot_3 : verify_program<batch_quant_dot_3>
return p; return p;
} }
}; };
template struct batch_quant_dot_3<migraphx::shape::int8_type>;
template struct batch_quant_dot_3<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> template <migraphx::shape::type_t DType>
struct batch_quant_dot_4 : verify_program<batch_quant_dot_4<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 4, 6, 3}}; migraphx::shape m1_shape{DType, {2, 4, 6, 3}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 2, 6, 3}}; migraphx::shape m2_shape{DType, {7, 2, 6, 3}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4> ...@@ -46,3 +47,5 @@ struct batch_quant_dot_4 : verify_program<batch_quant_dot_4>
return p; return p;
} }
}; };
template struct batch_quant_dot_4<migraphx::shape::int8_type>;
template struct batch_quant_dot_4<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,15 @@ ...@@ -27,14 +27,15 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> template <migraphx::shape::type_t DType>
struct batch_quant_dot_5 : verify_program<batch_quant_dot_5<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 7, 2}}; migraphx::shape m1_shape{DType, {3, 2, 7, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 5, 7}}; migraphx::shape m2_shape{DType, {3, 2, 5, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5> ...@@ -48,3 +49,5 @@ struct batch_quant_dot_5 : verify_program<batch_quant_dot_5>
return p; return p;
} }
}; };
template struct batch_quant_dot_5<migraphx::shape::int8_type>;
template struct batch_quant_dot_5<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -78,6 +78,16 @@ int main(int argc, const char* argv[]) ...@@ -78,6 +78,16 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim", "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>", "test_instancenorm_large_3d<migraphx::shape::half_type>",
// these tests are disabled due to the issue of lossy downcast, see issue #2517
#if defined(__GNUC__) and !defined(__clang__)
"batch_quant_dot_1<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"quant_dot_3args_4<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
"quant_dot_3args_5<migraphx::fp8::float8<migraphx::fp8::f8_type::fp8, true>, float>",
#else
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>",
#endif
"test_block_reduce_small<3, migraphx::shape::int8_type>", "test_block_reduce_small<3, migraphx::shape::int8_type>",
"test_block_reduce_small<4, migraphx::shape::int8_type>", "test_block_reduce_small<4, migraphx::shape::int8_type>",
"test_block_reduce_small<8, migraphx::shape::int8_type>", "test_block_reduce_small<8, migraphx::shape::int8_type>",
...@@ -89,5 +99,10 @@ int main(int argc, const char* argv[]) ...@@ -89,5 +99,10 @@ int main(int argc, const char* argv[])
"test_block_reduce_small<128, migraphx::shape::int8_type>", "test_block_reduce_small<128, migraphx::shape::int8_type>",
"test_block_reduce_small<129, migraphx::shape::int8_type>", "test_block_reduce_small<129, migraphx::shape::int8_type>",
}); });
rv.disable_test_for("gpu",
{// These pass on MI300 but fail on others, same issue as CPU.
"batch_quant_dot_1<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>",
"quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -25,23 +25,31 @@ ...@@ -25,23 +25,31 @@
#include "verify_program.hpp" #include "verify_program.hpp"
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/apply_alpha_beta.hpp> #include <migraphx/apply_alpha_beta.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1> template <typename DType, typename CType>
struct quant_dot_3args_1 : verify_program<quant_dot_3args_1<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), 1, 1); migraphx::add_apply_alpha_beta(
*mm, {l1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{1});
return p; return p;
} }
}; };
template struct quant_dot_3args_1<int8_t, int32_t>;
template struct quant_dot_3args_1<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,22 +28,29 @@ ...@@ -28,22 +28,29 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2> template <typename DType, typename CType>
struct quant_dot_3args_2 : verify_program<quant_dot_3args_2<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {8, 7}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {8, 7}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), 1, 3); migraphx::add_apply_alpha_beta(
*mm, {tl1, l2, l3}, migraphx::make_op("quant_dot"), CType{1}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_2<int8_t, int32_t>;
template struct quant_dot_3args_2<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,22 +28,28 @@ ...@@ -28,22 +28,28 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3> template <typename DType, typename CType>
struct quant_dot_3args_3 : verify_program<quant_dot_3args_3<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {2, 8}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {2, 8}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), 2, 3); migraphx::add_apply_alpha_beta(
*mm, {l1, tl2, l3}, migraphx::make_op("quant_dot"), CType{2}, CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_3<int8_t, int32_t>;
template struct quant_dot_3args_3<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,15 +28,18 @@ ...@@ -28,15 +28,18 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> template <typename DType, typename CType>
struct quant_dot_3args_4 : verify_program<quant_dot_3args_4<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {8, 2}}; auto ctype = migraphx::shape::get_type<CType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 8}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m3_shape{migraphx::shape::int32_type, {2, 7}}; migraphx::shape m1_shape{dtype, {8, 2}};
migraphx::shape m2_shape{dtype, {7, 8}};
migraphx::shape m3_shape{ctype, {2, 7}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4> ...@@ -45,7 +48,11 @@ struct quant_dot_3args_4 : verify_program<quant_dot_3args_4>
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
auto l3 = mm->add_parameter("c", m3_shape); auto l3 = mm->add_parameter("c", m3_shape);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), 3, 2); migraphx::add_apply_alpha_beta(
*mm, {tl1, tl2, l3}, migraphx::make_op("quant_dot"), CType{3}, CType{2});
return p; return p;
} }
}; };
template struct quant_dot_3args_4<int8_t, int32_t>;
template struct quant_dot_3args_4<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -28,14 +28,17 @@ ...@@ -28,14 +28,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5> template <typename DType, typename CType>
struct quant_dot_3args_5 : verify_program<quant_dot_3args_5<DType, CType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {6, 2}}; auto dtype = migraphx::shape::get_type<DType>();
migraphx::shape m2_shape{migraphx::shape::int8_type, {7, 6}};
migraphx::shape m1_shape{dtype, {6, 2}};
migraphx::shape m2_shape{dtype, {7, 6}};
auto l1 = mm->add_parameter("a", m1_shape); auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = auto tl1 =
...@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5> ...@@ -43,7 +46,10 @@ struct quant_dot_3args_5 : verify_program<quant_dot_3args_5>
auto l2 = mm->add_parameter("b", m2_shape); auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = auto tl2 =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), 3); migraphx::add_apply_alpha_beta(*mm, {tl1, tl2}, migraphx::make_op("quant_dot"), CType{3});
return p; return p;
} }
}; };
template struct quant_dot_3args_5<int8_t, int32_t>;
template struct quant_dot_3args_5<migraphx::fp8::fp8e4m3fnuz, float>;
...@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names) ...@@ -232,12 +232,12 @@ void quantize_fp16_with_op_names(program& prog, std::vector<std::string>& names)
struct quantize_int8_options struct quantize_int8_options
{ {
std::vector<parameter_map> calibration = {}; std::vector<parameter_map> calibration = {};
std::vector<std::string> op_names = {}; std::unordered_set<std::string> op_names = {};
}; };
void add_op_name(quantize_int8_options& options, const char* name) void add_op_name(quantize_int8_options& options, const char* name)
{ {
options.op_names.push_back(name); options.op_names.insert(name);
} }
void add_calibration_data(quantize_int8_options& options, parameter_map& data) void add_calibration_data(quantize_int8_options& options, parameter_map& data)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment