Commit fc093b02 authored by Umang Yadav's avatar Umang Yadav
Browse files

formatting

parent 119a6b86
......@@ -333,7 +333,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
......@@ -412,9 +413,11 @@ struct find_mlir_standalone_op
auto gemm_based_op = r.result;
// enable only for fp32/fp16/i8/fp8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type, shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
......
......@@ -45,5 +45,3 @@ struct quant_conv_2 : verify_program<quant_conv_2<DType>>
template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -34,7 +34,7 @@ struct test_conv : verify_program<test_conv<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
......
......@@ -34,10 +34,8 @@ struct test_conv2 : verify_program<test_conv2<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
......
......@@ -34,9 +34,9 @@ struct test_conv_add : verify_program<test_conv_add<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
......
......@@ -35,12 +35,10 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto bias_literal = migraphx::literal{migraphx::shape{DType, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto bias_literal =
migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(bias_literal);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_bias = mm->add_instruction(
......
......@@ -54,7 +54,7 @@ struct test_conv_bn : verify_program<test_conv_bn<DType>>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
......
......@@ -74,7 +74,7 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>>
auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}});
auto w =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}});
auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}});
auto v =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2));
auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
......
......@@ -35,10 +35,8 @@ struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
auto weights =
mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
......
......@@ -34,7 +34,7 @@ struct test_conv_relu : verify_program<test_conv_relu<DType>>
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment