Commit 68dd3bb4 authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 8d7a8a6c 7e53592e
...@@ -89,7 +89,7 @@ requests==2.28.2 ...@@ -89,7 +89,7 @@ requests==2.28.2
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==0.30.0 rocm-docs-core==0.30.1
# via -r requirements.in # via -r requirements.in
smmap==5.0.0 smmap==5.0.0
# via gitdb # via gitdb
......
...@@ -72,8 +72,8 @@ struct dequantizelinear ...@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) { visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) { visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) - output[i] = static_cast<double>(static_cast<double>(input[i]) -
static_cast<int64_t>(zero_pts[i])) * static_cast<double>(zero_pts[i])) *
scales[i]; scales[i];
}); });
}); });
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/convolution.hpp> #include <migraphx/convolution.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
...@@ -87,11 +88,13 @@ struct quant_convolution ...@@ -87,11 +88,13 @@ struct quant_convolution
} }
// all input type must be int8_type and output is float_type // all input type must be int8_type and output is float_type
if(t != shape::int8_type) std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t"); MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type");
} }
t = shape::int32_type;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]}; std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size(); auto padding_size = padding.size();
...@@ -107,8 +110,11 @@ struct quant_convolution ...@@ -107,8 +110,11 @@ struct quant_convolution
stride[i] + stride[i] +
1))); 1)));
} }
if(t == shape::int8_type)
return inputs[0].with_lens(t, output_lens); {
return inputs[0].with_lens(shape::int32_type, output_lens);
} // else fp8 conv
return inputs[0].with_lens(shape::float_type, output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -80,10 +80,10 @@ struct quantizelinear ...@@ -80,10 +80,10 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) + double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<double>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<double>(max_value), quantized));
}); });
}); });
}); });
......
...@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts) ...@@ -669,6 +669,15 @@ void module::finalize(std::vector<context>& contexts)
smod->finalize(contexts); smod->finalize(contexts);
} }
} }
#ifndef BUILD_DEV
if(std::any_of(this->begin(), this->end(), [](const auto i) {
return i.get_shape().type() == migraphx::shape::fp8e4m3fnuz_type;
}))
{
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
}
#endif
// Warn when an instruction is not normalized // Warn when an instruction is not normalized
auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); }); auto ins = std::find_if(begin(), end(), [](auto& i) { return i.need_normalization(); });
......
...@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype) ...@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type; case 18: {
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
return shape::fp8e4m3fnuz_type;
}
case 14: case 14:
case 15: case 15:
case 16: case 16:
......
...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
int64_t max_quant = 0; double max_quant = 0;
int64_t min_quant = 0; double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) { ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{ {
std::vector<int> min_data(s.elements(), min_quant); std::vector<double> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant); std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data)); min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data)); max_arg = m.add_literal(literal(s, max_data));
} }
......
...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops ...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that // Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op // occur between dequantizelinear and the quantized op
static auto static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop) propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{ {
auto qinp = dqins->inputs().front(); auto prev_ins = qop_arg;
auto next_ins = dqins; std::vector<instruction_ref> ins_inbetween;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
while(next_ins != qop) // instructions
while(prev_ins != dqins)
{ {
if(next_ins->name() != "dequantizelinear") ins_inbetween.push_back(prev_ins);
{ prev_ins = prev_ins->inputs().front();
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp); }
} auto qinp = dqins->inputs().front();
next_ins = next_ins->outputs().front(); for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
} }
return qinp; return qinp;
} }
...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops ...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"]; auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"]; auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
// Only INT8 type currently supported std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or migraphx::shape::int8_type};
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type) if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return; return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed) // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops ...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop // Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop); qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop); qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq; instruction_ref dq;
instruction_ref out_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
......
...@@ -49,6 +49,12 @@ std::string get_device_name() ...@@ -49,6 +49,12 @@ std::string get_device_name()
return props.gcnArchName; return props.gcnArchName;
} }
bool gfx_has_fp8_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return false; return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution") if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false; return false;
auto input_arg_t = ins->inputs().front()->get_shape().type();
value v = ins->get_operator().to_value(); value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>(); auto group = v.at("group").to<int>();
if(group != 1) if(group != 1)
...@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!" // Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4) if(ins->get_shape().lens().size() != 4)
return false; return false;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::int8_type) if(ins->get_shape().type() == shape::int8_type)
return true; return true;
if(mode == mlir_mode::int8) if(mode == mlir_mode::int8)
...@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type(); const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type, const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type, type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type, type_t::int8_type,
type_t::int32_type, type_t::int32_type,
type_t::bool_type}; type_t::bool_type};
...@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax", "softmax",
"tanh", "tanh",
}; };
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
...@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported. // supported.
if(is_float and name == "convert") if(is_float and name == "convert")
{ {
if(result_type == shape::fp8e4m3fnuz_type)
{
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
}); });
...@@ -404,12 +415,13 @@ struct find_mlir_standalone_op ...@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto gemm_based_op = r.result; auto gemm_based_op = r.result;
// // enable only for fp32/fp16/i8/fp8 types
// enable only for fp32/fp16/i8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains({shape::type_t::float_type,
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, shape::type_t::half_type,
i->get_shape().type()); shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
})) }))
return; return;
static size_t counter = 0; static size_t counter = 0;
......
...@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name(); ...@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id(); MIGRAPHX_GPU_EXPORT int get_device_id();
MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -300,6 +300,8 @@ struct mlir_program ...@@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get()); result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type) else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get()); result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type) else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get()); result = mlirF64TypeGet(ctx.get());
else if(as.is_integral()) else if(as.is_integral())
......
...@@ -58,8 +58,7 @@ bool rocblas_fp8_available() ...@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false; return false;
#else #else
const auto device_name = trim(split_string(get_device_name(), ':').front()); return gfx_has_fp8_intrinsics();
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif #endif
} }
......
...@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type); unsupported_types.erase(shape::type_t::tuple_type);
// whiltelist supported Ops for the FP8
std::set<std::string> unsupported_fp8_ops = {}; std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available()) if(not gpu::rocblas_fp8_available())
{ {
unsupported_fp8_ops.insert("dot"); unsupported_fp8_ops.insert("dot");
} }
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
{
unsupported_fp8_ops.insert("convolution");
unsupported_fp8_ops.insert("quant_convolution");
}
// add all device kernels // add all device kernels
unsupported_fp8_ops.insert("logsoftmax"); unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero"); unsupported_fp8_ops.insert("nonzero");
......
...@@ -527,6 +527,62 @@ TEST_CASE(dot_add) ...@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_multiple_dq_use)
{
migraphx::shape sh1{migraphx::shape::float_type, {32, 1}};
migraphx::shape sh2{migraphx::shape::float_type, {32, 32}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto d1_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d1);
auto d1_tmb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), d1_t);
auto d1_tmbc = m1.add_instruction(migraphx::make_op("contiguous"), d1_tmb);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot_1 = m1.add_instruction(migraphx::make_op("dot"), d1_tmbc, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot_1, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto dot_2 = m1.add_instruction(migraphx::make_op("dot"), d3, d1);
auto add = m1.add_instruction(migraphx::make_op("add"), {dot_2, d1});
m1.add_return({add});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q1_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q1);
auto q1_tmb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), q1_t);
auto q1_tmbc = m2.add_instruction(migraphx::make_op("contiguous"), q1_tmb);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot_1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale);
auto d3_q = add_quantize_op(m2, "quantizelinear", d3, scale, zero);
auto dot_2 = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1);
auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens());
auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2);
auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1);
m2.add_return({add});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv) TEST_CASE(conv)
{ {
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet) ...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto mod1 = create_module(); auto mod1 = create_module();
auto mod2 = create_module(); auto mod2 = create_module();
run_pass(mod2); run_pass(mod2);
auto match_qdq = migraphx::match::name("dequantizelinear")( auto match_qdq = migraphx::match::name("dequantizelinear")(
......
...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[]) ...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim", "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"}); "test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv : verify_program<quant_conv> template <migraphx::shape::type_t DType>
struct quant_conv : verify_program<quant_conv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p; return p;
} }
}; };
template struct quant_conv<migraphx::shape::int8_type>;
template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_1 : verify_program<quant_conv_1> template <migraphx::shape::type_t DType>
struct quant_conv_1 : verify_program<quant_conv_1<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p; return p;
} }
}; };
template struct quant_conv_1<migraphx::shape::int8_type>;
template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,16 @@ ...@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv_1d : verify_program<quant_conv_1d> template <migraphx::shape::type_t DType>
struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}}; migraphx::shape a_shape{DType, {2, 3, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction( mm->add_instruction(
migraphx::make_op("quant_convolution", migraphx::make_op("quant_convolution",
...@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d> ...@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return p; return p;
} }
}; };
template struct quant_conv_1d<migraphx::shape::int8_type>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp> #include <migraphx/op/quant_convolution.hpp>
struct quant_conv_2 : verify_program<quant_conv_2> template <migraphx::shape::type_t DType>
struct quant_conv_2 : verify_program<quant_conv_2<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}}; migraphx::shape a_shape{DType, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}}; migraphx::shape c_shape{DType, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc); mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p; return p;
} }
}; };
template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment