Commit fc093b02 authored by Umang Yadav's avatar Umang Yadav
Browse files

formatting

parent 119a6b86
...@@ -333,7 +333,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -333,7 +333,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax", "softmax",
"tanh", "tanh",
}; };
bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type); bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
...@@ -412,9 +413,11 @@ struct find_mlir_standalone_op ...@@ -412,9 +413,11 @@ struct find_mlir_standalone_op
auto gemm_based_op = r.result; auto gemm_based_op = r.result;
// enable only for fp32/fp16/i8/fp8 types // enable only for fp32/fp16/i8/fp8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains({shape::type_t::float_type,
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type}, shape::type_t::half_type,
i->get_shape().type()); shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
})) }))
return; return;
static size_t counter = 0; static size_t counter = 0;
......
...@@ -45,5 +45,3 @@ struct quant_conv_2 : verify_program<quant_conv_2<DType>> ...@@ -45,5 +45,3 @@ struct quant_conv_2 : verify_program<quant_conv_2<DType>>
template struct quant_conv_2<migraphx::shape::int8_type>; template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>; template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -34,7 +34,7 @@ struct test_conv : verify_program<test_conv<DType>> ...@@ -34,7 +34,7 @@ struct test_conv : verify_program<test_conv<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights); mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p; return p;
......
...@@ -34,10 +34,8 @@ struct test_conv2 : verify_program<test_conv2<DType>> ...@@ -34,10 +34,8 @@ struct test_conv2 : verify_program<test_conv2<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
mm->add_instruction( mm->add_instruction(
migraphx::make_op("convolution", migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}), {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
......
...@@ -34,9 +34,9 @@ struct test_conv_add : verify_program<test_conv_add<DType>> ...@@ -34,9 +34,9 @@ struct test_conv_add : verify_program<test_conv_add<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}}); auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1)); auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}}); auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2)); auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w); auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v); auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
......
...@@ -35,12 +35,10 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>> ...@@ -35,12 +35,10 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = auto bias_literal =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias_literal = migraphx::literal{migraphx::shape{DType, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(bias_literal); auto bias = mm->add_literal(bias_literal);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_bias = mm->add_instruction( auto bcast_bias = mm->add_instruction(
......
...@@ -54,7 +54,7 @@ struct test_conv_bn : verify_program<test_conv_bn<DType>> ...@@ -54,7 +54,7 @@ struct test_conv_bn : verify_program<test_conv_bn<DType>>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3))); auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4))); auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{DType, {0.5}}); auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}}); auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type) if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
......
...@@ -74,7 +74,7 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>> ...@@ -74,7 +74,7 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>>
auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}}); auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}});
auto w = auto w =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1)); mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}}); auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}});
auto v = auto v =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2)); mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2));
auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x); auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
......
...@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>> ...@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction( auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv); migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
......
...@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>> ...@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}}); auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}}); auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv); mm->add_instruction(migraphx::make_op("relu"), conv);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment