Unverified commit 6a72e8fc authored by Umang Yadav, committed by GitHub

FP8 2D forward convolution using rocMLIR (#2507)

parent a09dc502
@@ -27,6 +27,7 @@
 #include <migraphx/op/common.hpp>
 #include <migraphx/argument.hpp>
 #include <migraphx/check_shapes.hpp>
+#include <migraphx/shape.hpp>
 #include <migraphx/config.hpp>
 #include <migraphx/convolution.hpp>
 #include <migraphx/value.hpp>
@@ -87,11 +88,13 @@ struct quant_convolution
     }
-    // all input type must be int8_type and output is float_type
-    if(t != shape::int8_type)
+    // all input types must be int8_type or fp8e4m3fnuz_type
+    std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
+                                                         shape::fp8e4m3fnuz_type};
+    if(not contains(supported_types, t))
     {
-        MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
+        MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
+                       "fp8e4m3fnuz_type");
     }
-    t = shape::int32_type;
 
     std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
     auto padding_size = padding.size();
@@ -107,8 +110,11 @@ struct quant_convolution
                 stride[i] +
                 1)));
     }
-    return inputs[0].with_lens(t, output_lens);
+    if(t == shape::int8_type)
+    {
+        return inputs[0].with_lens(shape::int32_type, output_lens);
+    } // else fp8 conv
+    return inputs[0].with_lens(shape::float_type, output_lens);
 }
 
 size_t kdims() const
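To make the new type rules concrete: int8 inputs still accumulate into an int32 output, while fp8e4m3fnuz inputs now produce a float output. A minimal sketch (not part of the diff), assuming the internal `make_op`/`compute_shape` helpers behave as elsewhere in the codebase:

```cpp
#include <cassert>
#include <migraphx/make_op.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/shape.hpp>

// Illustrative check of the output-type rule introduced above.
void quant_conv_output_types()
{
    using migraphx::shape;
    auto op = migraphx::make_op("quant_convolution");

    // int8 x int8 -> int32 accumulator
    shape a_i8{shape::int8_type, {2, 3, 4, 4}};
    shape w_i8{shape::int8_type, {2, 3, 3, 3}};
    assert(migraphx::compute_shape(op, {a_i8, w_i8}).type() == shape::int32_type);

    // fp8 x fp8 -> float accumulator
    shape a_fp8{shape::fp8e4m3fnuz_type, {2, 3, 4, 4}};
    shape w_fp8{shape::fp8e4m3fnuz_type, {2, 3, 3, 3}};
    assert(migraphx::compute_shape(op, {a_fp8, w_fp8}).type() == shape::float_type);
}
```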
...
@@ -49,6 +49,12 @@ std::string get_device_name()
     return props.gcnArchName;
 }
 
+bool gfx_has_fp8_intrinsics()
+{
+    const auto device_name = trim(split_string(get_device_name(), ':').front());
+    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+}
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
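Feature flags such as `:sramecc+` or `:xnack-` are stripped before the comparison, and the remaining arch name is compared lexicographically. A self-contained sketch of the same string logic using only the standard library (the arch strings are illustrative examples, not values produced by this diff):

```cpp
#include <cassert>
#include <string>

// Re-implementation sketch of gfx_has_fp8_intrinsics' string handling.
static bool has_fp8(const std::string& gcn_arch)
{
    const auto name = gcn_arch.substr(0, gcn_arch.find(':')); // drop ":sramecc+:xnack-" etc.
    // gfx9-family and at least gfx940 (covers gfx940/gfx941/gfx942)
    return name.rfind("gfx9", 0) == 0 and name >= "gfx940";
}

int main()
{
    assert(has_fp8("gfx940"));
    assert(has_fp8("gfx942:sramecc+:xnack-"));
    assert(not has_fp8("gfx90a"));  // "gfx90a" < "gfx940" lexicographically
    assert(not has_fp8("gfx1100")); // not a gfx9-family name
    return 0;
}
```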
@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
             return false;
         if(ins->name() != "convolution" and ins->name() != "quant_convolution")
             return false;
+        auto input_arg_t = ins->inputs().front()->get_shape().type();
         value v    = ins->get_operator().to_value();
         auto group = v.at("group").to<int>();
         if(group != 1)
@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
         // Avoid MLIR assertion: Index < Length && "Invalid index!"
         if(ins->get_shape().lens().size() != 4)
             return false;
+        if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
+            return true;
+        if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
+            return true;
         if(ins->get_shape().type() == shape::int8_type)
             return true;
         if(mode == mlir_mode::int8)
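These two checks cover both forms an fp8 convolution takes after the `quant_convolution` change above: a `convolution` whose output is still fp8, and a `quant_convolution` whose fp8 inputs now yield a `float_type` output, so the result type alone is no longer enough to detect it. A hypothetical helper restating the predicate:

```cpp
// Hypothetical restatement of the two inserted conditions (not in the diff).
bool accepts_fp8_conv(migraphx::shape::type_t result_t, migraphx::shape::type_t input_t)
{
    using migraphx::shape;
    if(result_t == shape::fp8e4m3fnuz_type) // fp8 in, fp8 out
        return true;
    // fp8 quant_convolution: fp8 inputs, float accumulator output
    return result_t == shape::float_type and input_t == shape::fp8e4m3fnuz_type;
}
```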
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     const auto result_type = i.get_shape().type();
     const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                          type_t::half_type,
+                                                         type_t::fp8e4m3fnuz_type,
                                                          type_t::int8_type,
                                                          type_t::int32_type,
                                                          type_t::bool_type};
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
         "softmax",
         "tanh",
     };
-    bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
+    bool is_float =
+        contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
     if(contains(any_type_ops, name))
         return true;
     if(result_type != type_t::bool_type and contains(no_bool_ops, name))
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
     // supported.
     if(is_float and name == "convert")
     {
+        if(result_type == shape::fp8e4m3fnuz_type)
+        {
+            return false;
+        } // else
         return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
             return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
         });
@@ -404,11 +415,12 @@ struct find_mlir_standalone_op
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto gemm_based_op = r.result;
-        //
-        // enable only for fp32/fp16/i8 types
+        // enable only for fp32/fp16/i8/fp8 types
         if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
-               return not contains(
-                   {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
-                   i->get_shape().type());
+               return not contains({shape::type_t::float_type,
+                                    shape::type_t::half_type,
+                                    shape::type_t::int8_type,
+                                    shape::type_t::fp8e4m3fnuz_type},
+                                   i->get_shape().type());
            }))
             return;
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
 MIGRAPHX_GPU_EXPORT int get_device_id();
 
+MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
+
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
@@ -300,6 +300,8 @@ struct mlir_program
         result = mlirF32TypeGet(ctx.get());
     else if(as.type_enum() == shape::half_type)
         result = mlirF16TypeGet(ctx.get());
+    else if(as.type_enum() == shape::fp8e4m3fnuz_type)
+        result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
     else if(as.type_enum() == shape::double_type)
         result = mlirF64TypeGet(ctx.get());
     else if(as.is_integral())
...
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
 #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
     return false;
 #else
-    const auto device_name = trim(split_string(get_device_name(), ':').front());
-    return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
+    return gfx_has_fp8_intrinsics();
 #endif
 }
...
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
     unsupported_types.erase(shape::type_t::uint8_type);
     unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
+    // whitelist supported ops for fp8
     std::set<std::string> unsupported_fp8_ops = {};
     if(not gpu::rocblas_fp8_available())
     {
         unsupported_fp8_ops.insert("dot");
     }
+    // MIOpen doesn't have support for fp8 pooling yet.
+    unsupported_fp8_ops.insert("pooling");
+    if(not gpu::gfx_has_fp8_intrinsics())
+    {
+        unsupported_fp8_ops.insert("convolution");
+        unsupported_fp8_ops.insert("quant_convolution");
+    }
     // add all device kernels
     unsupported_fp8_ops.insert("logsoftmax");
     unsupported_fp8_ops.insert("nonzero");
...
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
                              "test_split_single_dyn_dim",
                              "test_instancenorm_large_3d<migraphx::shape::float_type>",
                              "test_instancenorm_large_3d<migraphx::shape::half_type>"});
-    rv.disable_test_for("gpu", {"test_conv_bn_add"});
     rv.run(argc, argv);
 }
...
@@ -27,17 +27,21 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct quant_conv : verify_program<quant_conv>
+template <migraphx::shape::type_t DType>
+struct quant_conv : verify_program<quant_conv<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
+        migraphx::shape a_shape{DType, {2, 3, 4, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
+        migraphx::shape c_shape{DType, {2, 3, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
         return p;
     }
 };
+
+template struct quant_conv<migraphx::shape::int8_type>;
+template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
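The remaining verify tests below follow the same pattern. The trailing explicit instantiations are what actually create the test cases: instantiating the struct instantiates `verify_program`'s static registration machinery, so each `template struct quant_conv<...>;` line adds one test per data type. A simplified sketch of that self-registration idiom (the names here are illustrative, the real harness lives in the test framework):

```cpp
#include <functional>
#include <string>
#include <migraphx/program.hpp>
#include <migraphx/type_name.hpp>

// Hypothetical registrar; the real one is part of the verify test driver.
int register_test(const std::string& name, std::function<migraphx::program()> f);

template <class Derived>
struct verify_program
{
    // Explicit instantiation of Derived instantiates this static member,
    // which registers the test as a side effect of its initialization.
    static inline int registered = register_test(
        migraphx::get_type_name<Derived>(), [] { return Derived{}.create_program(); });
};
```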
...
@@ -27,17 +27,21 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 
-struct quant_conv_1 : verify_program<quant_conv_1>
+template <migraphx::shape::type_t DType>
+struct quant_conv_1 : verify_program<quant_conv_1<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
+        migraphx::shape a_shape{DType, {2, 3, 4, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
+        migraphx::shape c_shape{DType, {2, 3, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
         return p;
     }
 };
+
+template struct quant_conv_1<migraphx::shape::int8_type>;
+template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,15 +27,16 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct quant_conv_1d : verify_program<quant_conv_1d>
+template <migraphx::shape::type_t DType>
+struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}};
+        migraphx::shape a_shape{DType, {2, 3, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}};
+        migraphx::shape c_shape{DType, {2, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(
             migraphx::make_op("quant_convolution",
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
         return p;
     }
 };
+
+template struct quant_conv_1d<migraphx::shape::int8_type>;
+// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through the MIOpen route later.
+// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,17 +27,21 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/op/quant_convolution.hpp>
 
-struct quant_conv_2 : verify_program<quant_conv_2>
+template <migraphx::shape::type_t DType>
+struct quant_conv_2 : verify_program<quant_conv_2<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}};
+        migraphx::shape a_shape{DType, {16, 16, 4, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
+        migraphx::shape c_shape{DType, {16, 16, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
         return p;
     }
 };
+
+template struct quant_conv_2<migraphx::shape::int8_type>;
+template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,15 +27,16 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct quant_conv_padding : verify_program<quant_conv_padding>
+template <migraphx::shape::type_t DType>
+struct quant_conv_padding : verify_program<quant_conv_padding<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
+        migraphx::shape a_shape{DType, {2, 3, 4, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
+        migraphx::shape c_shape{DType, {2, 3, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(
             migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
         return p;
     }
 };
+
+template struct quant_conv_padding<migraphx::shape::int8_type>;
+template struct quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,15 +27,16 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
+template <migraphx::shape::type_t DType>
+struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
+        migraphx::shape a_shape{DType, {2, 3, 4, 4}};
         auto pa = mm->add_parameter("a", a_shape);
-        migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
+        migraphx::shape c_shape{DType, {2, 3, 3, 3}};
         auto pc = mm->add_parameter("c", c_shape);
         mm->add_instruction(
             migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
         return p;
     }
 };
+
+template struct quant_conv_padding_stride<migraphx::shape::int8_type>;
+template struct quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
 {
     migraphx::target t = migraphx::make_target("ref");
     auto_print pp{p, t.name()};
-    compile_check(p, t, c_opts);
+    auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
+    compile_check(p, t, c_opts, (trace_target == "ref"));
     return p.eval(std::move(inputs));
 }
 
 std::pair<migraphx::program, std::vector<migraphx::argument>>
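With this change, `run_ref` honors the `MIGRAPHX_TRACE_TEST_COMPILE` environment variable: setting it to `ref` passes `true` as the new fourth argument to `compile_check`, which presumably enables compile tracing for the reference target the same way it already can for the tested targets.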
...
@@ -27,17 +27,19 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct test_conv : verify_program<test_conv>
+template <migraphx::shape::type_t DType>
+struct test_conv : verify_program<test_conv<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto input =
-            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
-        auto weights =
-            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
+        auto input   = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
         mm->add_instruction(migraphx::make_op("convolution"), input, weights);
         return p;
     }
 };
+
+template struct test_conv<migraphx::shape::float_type>;
+template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,16 +27,15 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct test_conv2 : verify_program<test_conv2>
+template <migraphx::shape::type_t DType>
+struct test_conv2 : verify_program<test_conv2<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto input =
-            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
-        auto weights =
-            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
+        auto input   = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
+        auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
         mm->add_instruction(
             migraphx::make_op("convolution",
                               {{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
         return p;
     }
 };
+
+template struct test_conv2<migraphx::shape::float_type>;
+template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,18 +27,17 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct test_conv_add : verify_program<test_conv_add>
+template <migraphx::shape::type_t DType>
+struct test_conv_add : verify_program<test_conv_add<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto w = mm->add_literal(
-            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 1));
-        auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto v = mm->add_literal(
-            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 3, 3}}, 2));
+        auto x = mm->add_parameter("x", {DType, {1, 8, 4, 4}});
+        auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 1));
+        auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
+        auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 3, 3}}, 2));
         auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
         auto conv2 = mm->add_instruction(migraphx::make_op("convolution"), y, v);
         auto sum   = mm->add_instruction(migraphx::make_op("add"), conv1, conv2);
@@ -46,3 +45,6 @@ struct test_conv_add : verify_program<test_conv_add>
         return p;
     }
 };
+
+template struct test_conv_add<migraphx::shape::float_type>;
+template struct test_conv_add<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -27,18 +27,17 @@
 #include <migraphx/generate.hpp>
 #include <migraphx/make_op.hpp>
 
-struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
+template <migraphx::shape::type_t DType>
+struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 8, 2, 2}});
-        auto w = mm->add_literal(
-            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 1));
-        auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 8, 4, 4}});
-        auto v = mm->add_literal(
-            migraphx::generate_literal({migraphx::shape::float_type, {2, 8, 1, 1}}, 2));
+        auto x = mm->add_parameter("x", {DType, {1, 8, 2, 2}});
+        auto w = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 1));
+        auto y = mm->add_parameter("y", {DType, {1, 8, 4, 4}});
+        auto v = mm->add_literal(migraphx::generate_literal({DType, {2, 8, 1, 1}}, 2));
         auto conv1 = mm->add_instruction(migraphx::make_op("convolution"), x, w);
         auto conv2 = mm->add_instruction(
             migraphx::make_op("convolution", {{"padding", {0, 0}}, {"stride", {2, 2}}}), y, v);
@@ -47,3 +46,6 @@ struct test_conv_add_1x1_diff_strides : verify_program<test_conv_add_1x1_diff_strides>
         return p;
     }
 };
+
+template struct test_conv_add_1x1_diff_strides<migraphx::shape::float_type>;
+template struct test_conv_add_1x1_diff_strides<migraphx::shape::fp8e4m3fnuz_type>;
...
@@ -28,18 +28,17 @@
 #include <migraphx/make_op.hpp>
 #include <migraphx/instruction.hpp>
 
-struct test_conv_add_relu : verify_program<test_conv_add_relu>
+template <migraphx::shape::type_t DType>
+struct test_conv_add_relu : verify_program<test_conv_add_relu<DType>>
 {
     migraphx::program create_program() const
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto input =
-            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
-        auto weights =
-            mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
-        auto bias_literal = migraphx::literal{migraphx::shape{migraphx::shape::float_type, {4}},
-                                              {2.0f, 2.0f, 2.0f, 2.0f}};
+        auto input   = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
+        auto bias_literal =
+            migraphx::literal{migraphx::shape{DType, {4}}, {2.0f, 2.0f, 2.0f, 2.0f}};
         auto bias = mm->add_literal(bias_literal);
         auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights);
         auto bcast_bias = mm->add_instruction(
@@ -50,3 +49,6 @@ struct test_conv_add_relu : verify_program<test_conv_add_relu>
         return p;
     }
 };
+
+template struct test_conv_add_relu<migraphx::shape::float_type>;
+template struct test_conv_add_relu<migraphx::shape::fp8e4m3fnuz_type>;