Commit 4cc5393d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into subwave-reduce

parents f7d97e53 fe61d940
......@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct quant_conv_1 : verify_program<quant_conv_1>
template <migraphx::shape::type_t DType>
struct quant_conv_1 : verify_program<quant_conv_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
template struct quant_conv_1<migraphx::shape::int8_type>;
template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_conv_1d : verify_program<quant_conv_1d>
template <migraphx::shape::type_t DType>
struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}};
migraphx::shape a_shape{DType, {2, 3, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution",
......@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return p;
}
};
template struct quant_conv_1d<migraphx::shape::int8_type>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>
struct quant_conv_2 : verify_program<quant_conv_2>
template <migraphx::shape::type_t DType>
struct quant_conv_2 : verify_program<quant_conv_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}};
migraphx::shape a_shape{DType, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
migraphx::shape c_shape{DType, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};
template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_conv_padding : verify_program<quant_conv_padding>
template <migraphx::shape::type_t DType>
struct quant_conv_padding : verify_program<quant_conv_padding<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
......@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return p;
}
};
template struct quant_conv_padding<migraphx::shape::int8_type>;
template struct quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
template <migraphx::shape::type_t DType>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
......@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return p;
}
};
template struct quant_conv_padding_stride<migraphx::shape::int8_type>;
template struct quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -136,15 +136,18 @@ void run_verify::validate(const migraphx::target& t,
ti.validate(p, m);
}
std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_verify::run_ref(migraphx::program p,
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const
{
migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()};
compile_check(p, t, c_opts);
return p.eval(std::move(inputs));
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
return std::make_pair(std::move(p), p.eval(std::move(inputs)));
}
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_verify::run_target(const migraphx::target& t,
migraphx::program p,
......@@ -225,7 +228,7 @@ void run_verify::verify(const std::string& name,
}
}
auto gold_f = detach_async([=] { return run_ref(p, m, c_opts); });
auto ref_f = detach_async([=] { return run_ref(p, m, c_opts); });
for(const auto& tname : target_names)
{
target_info ti = get_target_info(tname);
......@@ -234,8 +237,8 @@ void run_verify::verify(const std::string& name,
tname, detach_async([=] { return run_target(t, p, m, c_opts); }, ti.parallel));
}
assert(gold_f.valid());
auto gold = gold_f.get();
assert(ref_f.valid());
auto ref_results = ref_f.get();
for(auto&& pp : results)
{
......@@ -244,7 +247,7 @@ void run_verify::verify(const std::string& name,
auto x = pp.second.get();
auto cp = x.first;
auto result = x.second;
auto gold = ref_results.second;
bool passed = true;
passed &= (gold.size() == result.size());
std::size_t num = gold.size();
......@@ -257,7 +260,7 @@ void run_verify::verify(const std::string& name,
if(not passed or migraphx::enabled(MIGRAPHX_TRACE_TEST{}))
{
std::cout << p << std::endl;
std::cout << "ref:\n" << p << std::endl;
std::cout << "ref:\n" << ref_results.first << std::endl;
std::cout << tname << ":\n" << cp << std::endl;
std::cout << std::endl;
}
......
......@@ -39,9 +39,11 @@ struct target_info
struct run_verify
{
std::vector<migraphx::argument> run_ref(migraphx::program p,
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const;
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_ref(migraphx::program p,
migraphx::parameter_map inputs,
const migraphx::compile_options& c_opts) const;
std::pair<migraphx::program, std::vector<migraphx::argument>>
run_target(const migraphx::target& t,
migraphx::program p,
......
......@@ -27,17 +27,19 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv : verify_program<test_conv>
template <migraphx::shape::type_t DType>
struct test_conv : verify_program<test_conv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
}
};
template struct test_conv<migraphx::shape::float_type>;
template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,16 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv2 : verify_program<test_conv2>
template <migraphx::shape::type_t DType>
struct test_conv2 : verify_program<test_conv2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
......@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return p;
}
};
template struct test_conv2<migraphx::shape::float_type>;
template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_add : verify_program<test_conv_add>
template <migraphx::shape::type_t DType>
struct test_conv_add : verify_program<test_conv_add<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
auto sum = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
......@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add>
return p;
}
};
template struct test_conv_add<migraphx::shape::float_type>;
template struct test_conv_add<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,18 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
template <migraphx::shape::type_t DType>
struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}});
auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1));
auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2));
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
auto conv2 = mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
......@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_st
return p;
}
};
template struct test_conv_add_1x1_diff_strides<migraphx::shape::float_type>;
template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -28,18 +28,17 @@
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
struct test_conv_add_relu : verify_program<test_conv_add_relu>
template <migraphx::shape::type_t DType>
struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto bias_literal =
migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(bias_literal);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_bias = mm->add_instruction(
......@@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu>
return p;
}
};
template struct test_conv_add_relu<migraphx::shape::float_type>;
template struct test_conv_add_relu<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -29,26 +29,24 @@
#include <migraphx/instruction.hpp>
struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
template <migraphx::shape::type_t DType>
struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto l0 = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
{2.0f, 2.0f, 2.0f, 2.0f}};
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto l0 = migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
auto bias = mm->add_literal(l0);
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto bcast_add = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
bias);
auto bias_add = mm->add_instruction(migraphx::make_op("add"), conv, bcast_add);
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
auto min_val = mm->add_literal(migraphx::literal(DType, {0.0f}));
auto max_val = mm->add_literal(migraphx::literal(DType, {6.0f}));
min_val = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", conv->get_shape().lens()}}), min_val);
max_val = mm->add_instruction(
......@@ -57,3 +55,6 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
return p;
}
};
template struct test_conv_bias_clipped_relu<migraphx::shape::float_type>;
template struct test_conv_bias_clipped_relu<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -29,16 +29,17 @@
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
struct test_conv_bn : verify_program<test_conv_bn>
template <migraphx::shape::type_t DType>
struct test_conv_bn : verify_program<test_conv_bn<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}};
migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws);
// non-symmetrical tiling
......@@ -53,8 +54,14 @@ struct test_conv_bn : verify_program<test_conv_bn>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = mm->add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
......@@ -74,3 +81,6 @@ struct test_conv_bn : verify_program<test_conv_bn>
return p;
}
};
template struct test_conv_bn<migraphx::shape::float_type>;
template struct test_conv_bn<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -29,22 +29,27 @@
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
struct test_conv_bn_add : verify_program<test_conv_bn_add>
template <migraphx::shape::type_t DType>
struct test_conv_bn_add : verify_program<test_conv_bn_add<DType>>
{
static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
{
auto bn_lens = x->get_shape().lens();
auto c_len = bn_lens.at(1);
migraphx::shape vars{migraphx::shape::float_type, {c_len}};
migraphx::shape vars{DType, {c_len}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto rt = m.add_literal(migraphx::literal{DType, {0.5}});
auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = m.add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias);
......@@ -66,12 +71,12 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add>
auto* mm = p.get_main_module();
std::size_t ichannels = 64;
std::size_t ochannels = 256;
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
auto w = mm->add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, ichannels, 56, 56}});
auto v = mm->add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {ochannels, ichannels, 1, 1}}, 2));
auto x = mm->add_parameter("x", {DType, {1, ichannels, 56, 56}});
auto w =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 1));
auto y = mm->add_parameter("y", {DType, {1, ichannels, 56, 56}});
auto v =
mm->add_literal(migraphx::generate_literal({DType, {ochannels, ichannels, 1, 1}}, 2));
auto relu1 = mm->add_instruction(migraphx::make_op("relu"), x);
auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), relu1, w);
auto bn1 = add_bn(*mm, conv1);
......@@ -83,3 +88,6 @@ struct test_conv_bn_add : verify_program<test_conv_bn_add>
return p;
}
};
template struct test_conv_bn_add<migraphx::shape::float_type>;
template struct test_conv_bn_add<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -30,16 +30,17 @@
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape xs{migraphx::shape::float_type, {1, 3, 224, 224}};
migraphx::shape ws{migraphx::shape::float_type, {64, 3, 7, 7}};
migraphx::shape vars{migraphx::shape::float_type, {64}};
migraphx::shape xs{DType, {1, 3, 224, 224}};
migraphx::shape ws{DType, {64, 3, 7, 7}};
migraphx::shape vars{DType, {64}};
auto x = mm->add_parameter("x", xs);
auto w = mm->add_parameter("w", ws);
auto conv = mm->add_instruction(
......@@ -52,9 +53,13 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
auto mean = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 3)));
auto variance = mm->add_literal(migraphx::abs(migraphx::generate_literal(vars, 4)));
auto rt = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto rt = mm->add_literal(migraphx::literal{DType, {0.5}});
auto eps = mm->add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = mm->add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias =
......@@ -82,3 +87,6 @@ struct test_conv_bn_relu_pooling : verify_program<test_conv_bn_relu_pooling>
return p;
}
};
template struct test_conv_bn_relu_pooling<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -30,22 +30,27 @@
#include <migraphx/instruction.hpp>
#include <migraphx/common.hpp>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
template <migraphx::shape::type_t DType>
struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2<DType>>
{
static migraphx::instruction_ref add_bn(migraphx::module& m, migraphx::instruction_ref x)
{
auto bn_lens = x->get_shape().lens();
auto c_len = bn_lens.at(1);
migraphx::shape vars{migraphx::shape::float_type, {c_len}};
migraphx::shape vars{DType, {c_len}};
auto scale = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 1 + c_len)));
auto bias = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 2 + c_len)));
auto mean = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 3 + c_len)));
auto variance = m.add_literal(migraphx::abs(migraphx::generate_literal(vars, 4 + c_len)));
auto rt = m.add_literal(migraphx::literal{migraphx::shape::float_type, {0.5}});
auto eps = m.add_literal(migraphx::literal{migraphx::shape::float_type, {1e-5f}});
auto rt = m.add_literal(migraphx::literal{DType, {0.5}});
auto eps = m.add_literal(migraphx::literal{DType, {1e-5f}});
if constexpr((DType) == migraphx::shape::fp8e4m3fnuz_type)
{
// use 5e-2f for the fp8
eps = m.add_literal(migraphx::literal{DType, {5e-2f}});
}
auto usq_scale =
m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), scale);
auto usq_bias = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2}}}), bias);
......@@ -66,10 +71,10 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape xs1{migraphx::shape::float_type, {1, 512, 7, 7}};
migraphx::shape xs2{migraphx::shape::float_type, {1, 1024, 14, 14}};
migraphx::shape ws1{migraphx::shape::float_type, {2048, 512, 1, 1}};
migraphx::shape ws2{migraphx::shape::float_type, {2048, 1024, 1, 1}};
migraphx::shape xs1{DType, {1, 512, 7, 7}};
migraphx::shape xs2{DType, {1, 1024, 14, 14}};
migraphx::shape ws1{DType, {2048, 512, 1, 1}};
migraphx::shape ws2{DType, {2048, 1024, 1, 1}};
auto x1 = mm->add_parameter("x1", xs1);
auto w1 = mm->add_parameter("w1", ws1);
auto conv1 = mm->add_instruction(
......@@ -98,3 +103,6 @@ struct test_conv_bn_relu_pooling2 : verify_program<test_conv_bn_relu_pooling2>
return p;
}
};
template struct test_conv_bn_relu_pooling2<migraphx::shape::float_type>;
template struct test_conv_bn_relu_pooling2<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,16 +27,17 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_group_add : verify_program<test_conv_group_add>
template <migraphx::shape::type_t DType>
struct test_conv_group_add : verify_program<test_conv_group_add<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {1, 68, 28, 28}};
migraphx::shape s{DType, {1, 68, 28, 28}};
auto x = mm->add_parameter("x", s);
auto w = mm->add_parameter("w", {migraphx::shape::float_type, {68, 17, 1, 1}});
auto b = mm->add_parameter("b", {migraphx::shape::float_type, {68}});
auto w = mm->add_parameter("w", {DType, {68, 17, 1, 1}});
auto b = mm->add_parameter("b", {DType, {68}});
auto conv = mm->add_instruction(migraphx::make_op("convolution", {{"group", 4}}), x, w);
auto bb = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 68, 28, 28}}}), b);
......@@ -44,3 +45,6 @@ struct test_conv_group_add : verify_program<test_conv_group_add>
return p;
}
};
template struct test_conv_group_add<migraphx::shape::float_type>;
// grouped convolutions are not supported with MLIR therefore disable it
// template struct test_conv_group_add<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -28,16 +28,15 @@
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct test_conv_pooling : verify_program<test_conv_pooling>
template <migraphx::shape::type_t DType>
struct test_conv_pooling : verify_program<test_conv_pooling<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 32, 32}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 32, 32}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
auto pooling = mm->add_instruction(
migraphx::make_op("pooling", {{"mode", migraphx::op::pooling_mode::max}}), conv);
......@@ -45,3 +44,6 @@ struct test_conv_pooling : verify_program<test_conv_pooling>
return p;
}
};
template struct test_conv_pooling<migraphx::shape::float_type>;
template struct test_conv_pooling<migraphx::shape::fp8e4m3fnuz_type>;
......@@ -27,18 +27,20 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_relu : verify_program<test_conv_relu>
template <migraphx::shape::type_t DType>
struct test_conv_relu : verify_program<test_conv_relu<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
template struct test_conv_relu<migraphx::shape::float_type>;
template struct test_conv_relu<migraphx::shape::half_type>;
template struct test_conv_relu<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment