Commit 22aadbd6 authored by Umang Yadav's avatar Umang Yadav
Browse files

add some MLIR fp8 tests for convolutions

parent afe12765
...@@ -27,18 +27,17 @@ ...@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides> template <migraphx::shape::type_t DType>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}}); auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}});
auto w = mm->add_literal( auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1));
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1)); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}}); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2));
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction( auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v); migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st ...@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
return p; return p;
} }
}; };
template struct test_conv_add_1x1_diff_strides<migraphx::shape::float_type>;
template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -29,16 +29,17 @@ ...@@ -29,16 +29,17 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn : verify_program<test_conv_bn> template <migraphx::shape::type_t DType>
struct test_conv_bn : verify_program<test_conv_bn<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
// non-symmetrical tiling // non-symmetrical tiling
...@@ -53,8 +54,8 @@ struct test_conv_bn : verify_program<test_conv_bn> ...@@ -53,8 +54,8 @@ struct test_conv_bn : verify_program<test_conv_bn>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
auto usq_scale = auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -74,3 +75,6 @@ struct test_conv_bn : verify_program<test_conv_bn> ...@@ -74,3 +75,6 @@ struct test_conv_bn : verify_program<test_conv_bn>
return p; return p;
} }
}; };
template struct test_conv_bn<migraphx::shape::float_type>;
template struct test_conv_bn<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -30,16 +30,17 @@ ...@@ -30,16 +30,17 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}}; migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}}; migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}}; migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs); auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws); auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction( auto conv = mm->add_instruction(
...@@ -52,8 +53,8 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -52,8 +53,8 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
auto usq_scale = auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -81,3 +82,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling> ...@@ -81,3 +82,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
return p; return p;
} }
}; };
template struct test_conv_bn_relu_pooling<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -30,21 +30,22 @@ ...@@ -30,21 +30,22 @@
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/common.hpp> #include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DType>>
{ {
static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x) static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
{ {
auto bn_lens = x->get_shape().lens(); auto bn_lens = x->get_shape().lens();
auto c_len = bn_lens.at(1); auto c_len = bn_lens.at(1);
migraphx::shape vars{migraphx::shape::float_type, {c_len}}; migraphx::shape vars{DType, {c_len}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len))); auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len))); auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len))); auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len))); auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}}); auto rt = m.add_literal(migraphx::literal{DType, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}}); auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}});
auto usq_scale = auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale); m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
...@@ -66,10 +67,10 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -66,10 +67,10 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}}; migraphx::shape xs1{DType, {1, 512, 7, 7}};
migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}}; migraphx::shape xs2{DType, {1, 1024, 14, 14}};
migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}}; migraphx::shape ws1{DType, {2048, 512, 1, 1}};
migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}}; migraphx::shape ws2{DType, {2048, 1024, 1, 1}};
auto x1 = mm->add_parameter("x1", xs1); auto x1 = mm->add_parameter("x1", xs1);
auto w1 = mm->add_parameter("w1", ws1); auto w1 = mm->add_parameter("w1", ws1);
auto conv1 = mm->add_instruction( auto conv1 = mm->add_instruction(
...@@ -97,3 +98,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2> ...@@ -97,3 +98,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
return p; return p;
} }
}; };
template struct test_conv_bn_relu_pooling2<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment