Commit 4cc5393d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into subwave-reduce

parents f7d97e53 fe61d940
...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]}) ...@@ -58,6 +58,10 @@ Set the default dynamic dimension (format {min:x, max:y, optimals:[o1,o2,...]})
Optimize when reading Optimize when reading
.. option:: --apply-pass, -p
Passes to apply to model
.. option:: --graphviz, -g .. option:: --graphviz, -g
Print out a graphviz representation. Print out a graphviz representation.
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
add_executable(driver add_executable(driver
main.cpp main.cpp
verify.cpp verify.cpp
passes.cpp
perf.cpp perf.cpp
resnet50.cpp resnet50.cpp
inceptionv3.cpp inceptionv3.cpp
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "argument_parser.hpp" #include "argument_parser.hpp"
#include "command.hpp" #include "command.hpp"
#include "precision.hpp" #include "precision.hpp"
#include "passes.hpp"
#include "perf.hpp" #include "perf.hpp"
#include "models.hpp" #include "models.hpp"
#include "marker_roctx.hpp" #include "marker_roctx.hpp"
...@@ -83,6 +84,7 @@ struct loader ...@@ -83,6 +84,7 @@ struct loader
std::vector<std::string> param_dims; std::vector<std::string> param_dims;
std::vector<std::string> dyn_param_dims; std::vector<std::string> dyn_param_dims;
std::vector<std::string> output_names; std::vector<std::string> output_names;
std::vector<std::string> passes;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -130,6 +132,7 @@ struct loader ...@@ -130,6 +132,7 @@ struct loader
ap.append(), ap.append(),
ap.nargs(2)); ap.nargs(2));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
ap(passes, {"--apply-pass", "-p"}, ap.help("Passes to apply to model"), ap.append());
ap(output_type, ap(output_type,
{"--graphviz", "-g"}, {"--graphviz", "-g"},
ap.help("Print out a graphviz representation."), ap.help("Print out a graphviz representation."),
...@@ -337,6 +340,8 @@ struct loader ...@@ -337,6 +340,8 @@ struct loader
migraphx::dead_code_elimination{}, migraphx::dead_code_elimination{},
}); });
} }
if(not passes.empty())
migraphx::run_passes(*p.get_main_module(), get_passes(passes));
return p; return p;
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "passes.hpp"
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_allocation.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/optimize_module.hpp>
#include <migraphx/promote_literals.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_dyn_ops.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
std::unordered_map<std::string, pass> create_passes_lookup()
{
std::unordered_map<std::string, pass> result;
// clang-format off
std::initializer_list<pass> passes = {
auto_contiguous{},
dead_code_elimination{},
eliminate_allocation{},
eliminate_common_subexpression{},
eliminate_concat{},
eliminate_contiguous{},
eliminate_data_type{},
eliminate_identity{},
eliminate_pad{},
inline_module{},
insert_pad{},
normalize_ops{},
optimize_module{},
promote_literals{},
propagate_constant{},
rewrite_gelu{},
rewrite_pooling{},
rewrite_quantization{},
rewrite_rnn{},
simplify_algebra{},
simplify_dyn_ops{},
simplify_qdq{},
simplify_reshapes{},
};
// clang-format on
for(const auto& pass : passes)
result[pass.name()] = pass;
result["eliminate_dead_code"] = dead_code_elimination{};
return result;
}
std::vector<pass> get_passes(const std::vector<std::string>& names)
{
std::vector<pass> result;
static const std::unordered_map<std::string, pass> lookup = create_passes_lookup();
std::transform(
names.begin(), names.end(), std::back_inserter(result), [](const std::string& name) {
if(not contains(lookup, name))
MIGRAPHX_THROW("Unknown pass: " + name);
return lookup.at(name);
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
} // namespace migraphx
...@@ -21,24 +21,20 @@ ...@@ -21,24 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#ifndef MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#define MIGRAPHX_GUARD_DRIVER_PASSES_HPP
#include "verify_program.hpp" #include <migraphx/pass.hpp>
#include <migraphx/program.hpp> #include <vector>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_conv_relu_half : verify_program<test_conv_relu_half> namespace migraphx {
{ namespace driver {
migraphx::program create_program() const inline namespace MIGRAPHX_INLINE_NS {
{
migraphx::program p; std::vector<pass> get_passes(const std::vector<std::string>& names);
auto* mm = p.get_main_module();
auto input = } // namespace MIGRAPHX_INLINE_NS
mm->add_parameter("x", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}}); } // namespace driver
auto weights = } // namespace migraphx
mm->add_parameter("w", migraphx::shape{migraphx::shape::half_type, {4, 3, 3, 3}});
auto conv = mm->add_instruction(migraphx::make_op("convolution"), input, weights); #endif
mm->add_instruction(migraphx::make_op("relu"), conv);
return p;
}
};
...@@ -72,8 +72,8 @@ struct dequantizelinear ...@@ -72,8 +72,8 @@ struct dequantizelinear
visit_all(x, x_zero_point)([&](auto input, auto zero_pts) { visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
visit_all(result, x_scale)([&](auto output, auto scales) { visit_all(result, x_scale)([&](auto output, auto scales) {
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
output[i] = static_cast<double>(static_cast<int64_t>(input[i]) - output[i] = static_cast<double>(static_cast<double>(input[i]) -
static_cast<int64_t>(zero_pts[i])) * static_cast<double>(zero_pts[i])) *
scales[i]; scales[i];
}); });
}); });
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp> #include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/convolution.hpp> #include <migraphx/convolution.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
...@@ -87,11 +88,13 @@ struct quant_convolution ...@@ -87,11 +88,13 @@ struct quant_convolution
} }
// all input type must be int8_type and output is float_type // all input type must be int8_type and output is float_type
if(t != shape::int8_type) std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t"); MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type");
} }
t = shape::int32_type;
std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]}; std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size(); auto padding_size = padding.size();
...@@ -107,8 +110,11 @@ struct quant_convolution ...@@ -107,8 +110,11 @@ struct quant_convolution
stride[i] + stride[i] +
1))); 1)));
} }
if(t == shape::int8_type)
return inputs[0].with_lens(t, output_lens); {
return inputs[0].with_lens(shape::int32_type, output_lens);
} // else fp8 conv
return inputs[0].with_lens(shape::float_type, output_lens);
} }
size_t kdims() const size_t kdims() const
......
...@@ -80,10 +80,10 @@ struct quantizelinear ...@@ -80,10 +80,10 @@ struct quantizelinear
auto min_value = std::numeric_limits<quant_type>::min(); auto min_value = std::numeric_limits<quant_type>::min();
auto max_value = std::numeric_limits<quant_type>::max(); auto max_value = std::numeric_limits<quant_type>::max();
par_for(output_shape.elements(), [&](auto i) { par_for(output_shape.elements(), [&](auto i) {
int64_t quantized = static_cast<int64_t>(std::nearbyint(input[i] / scales[i])) + double quantized = static_cast<double>(std::nearbyint(input[i] / scales[i])) +
static_cast<int64_t>(zero_pts[i]); static_cast<double>(zero_pts[i]);
output[i] = std::max(static_cast<int64_t>(min_value), output[i] = std::max(static_cast<double>(min_value),
std::min(static_cast<int64_t>(max_value), quantized)); std::min(static_cast<double>(max_value), quantized));
}); });
}); });
}); });
......
...@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype) ...@@ -625,7 +625,11 @@ shape::type_t get_type(int dtype)
case 11: return shape::double_type; case 11: return shape::double_type;
case 12: return shape::uint32_type; case 12: return shape::uint32_type;
case 13: return shape::uint64_type; case 13: return shape::uint64_type;
case 18: return shape::fp8e4m3fnuz_type; case 18: {
std::cout << "[Warning] : MIGraphX has BETA support for FP8. Using FP8 may result in "
"incorrect final outputs\n";
return shape::fp8e4m3fnuz_type;
}
case 14: case 14:
case 15: case 15:
case 16: case 16:
......
...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -58,8 +58,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point); add_zero_point = m.insert_instruction(ins, make_op("add"), add_zero_point, zero_point);
} }
int64_t max_quant = 0; double max_quant = 0;
int64_t min_quant = 0; double min_quant = 0;
ins->get_shape().visit_type([&](auto qt) { ins->get_shape().visit_type([&](auto qt) {
max_quant = qt.max(); max_quant = qt.max();
min_quant = qt.min(); min_quant = qt.min();
...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins) ...@@ -70,8 +70,8 @@ void apply_quantizelinear(module& m, instruction_ref ins)
if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{})) if(enabled(MIGRAPHX_ENABLE_CK_WORKAROUNDS{}))
{ {
std::vector<int> min_data(s.elements(), min_quant); std::vector<double> min_data(s.elements(), min_quant);
std::vector<int> max_data(s.elements(), max_quant); std::vector<double> max_data(s.elements(), max_quant);
min_arg = m.add_literal(literal(s, min_data)); min_arg = m.add_literal(literal(s, min_data));
max_arg = m.add_literal(literal(s, max_data)); max_arg = m.add_literal(literal(s, max_data));
} }
......
...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops ...@@ -82,18 +82,21 @@ struct match_find_quantizable_ops
// Helper function to insert quantized versions of any broadcasts and transpose ops that // Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op // occur between dequantizelinear and the quantized op
static auto static auto
propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop) propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop_arg)
{ {
auto qinp = dqins->inputs().front(); auto prev_ins = qop_arg;
auto next_ins = dqins; std::vector<instruction_ref> ins_inbetween;
// matcher skips continguous, multi/broadcasts and transposes, collect all those
while(next_ins != qop) // instructions
{ while(prev_ins != dqins)
if(next_ins->name() != "dequantizelinear")
{ {
qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp); ins_inbetween.push_back(prev_ins);
prev_ins = prev_ins->inputs().front();
} }
next_ins = next_ins->outputs().front(); auto qinp = dqins->inputs().front();
for(auto ins : reverse_iterator_for(ins_inbetween))
{
qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp});
} }
return qinp; return qinp;
} }
...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops ...@@ -124,10 +127,11 @@ struct match_find_quantizable_ops
auto scale2 = r.instructions["scale2"]; auto scale2 = r.instructions["scale2"];
auto zp1 = r.instructions["zp1"]; auto zp1 = r.instructions["zp1"];
auto zp2 = r.instructions["zp2"]; auto zp2 = r.instructions["zp2"];
// Only INT8 or FP8 type currently supported
// Only INT8 type currently supported std::set<migraphx::shape::type_t> supported_types = {migraphx::shape::fp8e4m3fnuz_type,
if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or migraphx::shape::int8_type};
dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type) if(not contains(supported_types, dq1->inputs().front()->get_shape().type()) or
not contains(supported_types, dq2->inputs().front()->get_shape().type()))
return; return;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed) // Only symmetric quantization supported (ie. non-zero zero_points not allowed)
...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops ...@@ -140,8 +144,8 @@ struct match_find_quantizable_ops
// Propagate q1 and q2 through any broadcasts and transposes before qop // Propagate q1 and q2 through any broadcasts and transposes before qop
auto qop_args = qop->inputs(); auto qop_args = qop->inputs();
qop_args.at(0) = propagate_quantized_ins(m, dq1, qop); qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0]);
qop_args.at(1) = propagate_quantized_ins(m, dq2, qop); qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1]);
instruction_ref dq; instruction_ref dq;
instruction_ref out_scale; instruction_ref out_scale;
instruction_ref zero_point; instruction_ref zero_point;
......
...@@ -49,6 +49,12 @@ std::string get_device_name() ...@@ -49,6 +49,12 @@ std::string get_device_name()
return props.gcnArchName; return props.gcnArchName;
} }
bool gfx_has_fp8_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -218,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return false; return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution") if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false; return false;
auto input_arg_t = ins->inputs().front()->get_shape().type();
value v = ins->get_operator().to_value(); value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>(); auto group = v.at("group").to<int>();
if(group != 1) if(group != 1)
...@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode) ...@@ -225,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!" // Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4) if(ins->get_shape().lens().size() != 4)
return false; return false;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::int8_type) if(ins->get_shape().type() == shape::int8_type)
return true; return true;
if(mode == mlir_mode::int8) if(mode == mlir_mode::int8)
...@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type(); const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type, const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type, type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type, type_t::int8_type,
type_t::int32_type, type_t::int32_type,
type_t::bool_type}; type_t::bool_type};
...@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax", "softmax",
"tanh", "tanh",
}; };
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
...@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) ...@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported. // supported.
if(is_float and name == "convert") if(is_float and name == "convert")
{ {
if(result_type == shape::fp8e4m3fnuz_type)
{
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
}); });
...@@ -404,11 +415,12 @@ struct find_mlir_standalone_op ...@@ -404,11 +415,12 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto gemm_based_op = r.result; auto gemm_based_op = r.result;
// // enable only for fp32/fp16/i8/fp8 types
// enable only for fp32/fp16/i8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) { if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains( return not contains({shape::type_t::float_type,
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type}, shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type()); i->get_shape().type());
})) }))
return; return;
...@@ -531,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const ...@@ -531,7 +543,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
match::find_matches( match::find_matches(
mpm, mpm,
find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)}, find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::fast)},
find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)}); find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else #else
(void)mpm; (void)mpm;
......
...@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name(); ...@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id(); MIGRAPHX_GPU_EXPORT int get_device_id();
MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -300,6 +300,8 @@ struct mlir_program ...@@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get()); result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type) else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get()); result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type) else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get()); result = mlirF64TypeGet(ctx.get());
else if(as.is_integral()) else if(as.is_integral())
......
...@@ -58,8 +58,7 @@ bool rocblas_fp8_available() ...@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API #ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false; return false;
#else #else
const auto device_name = trim(split_string(get_device_name(), ':').front()); return gfx_has_fp8_intrinsics();
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
#endif #endif
} }
......
...@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type); unsupported_types.erase(shape::type_t::tuple_type);
// whiltelist supported Ops for the FP8
std::set<std::string> unsupported_fp8_ops = {}; std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available()) if(not gpu::rocblas_fp8_available())
{ {
unsupported_fp8_ops.insert("dot"); unsupported_fp8_ops.insert("dot");
} }
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
{
unsupported_fp8_ops.insert("convolution");
unsupported_fp8_ops.insert("quant_convolution");
}
// add all device kernels // add all device kernels
unsupported_fp8_ops.insert("logsoftmax"); unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero"); unsupported_fp8_ops.insert("nonzero");
......
...@@ -527,6 +527,62 @@ TEST_CASE(dot_add) ...@@ -527,6 +527,62 @@ TEST_CASE(dot_add)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(dot_add_multiple_dq_use)
{
migraphx::shape sh1{migraphx::shape::float_type, {32, 1}};
migraphx::shape sh2{migraphx::shape::float_type, {32, 32}};
migraphx::module m1;
{
auto t1 = m1.add_parameter("t1", sh1);
auto t2 = m1.add_parameter("t2", sh2);
auto scale = m1.add_literal(0.5f);
auto zero = m1.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero);
auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero);
auto d1_t =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d1);
auto d1_tmb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), d1_t);
auto d1_tmbc = m1.add_instruction(migraphx::make_op("contiguous"), d1_tmb);
auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero);
auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero);
auto dot_1 = m1.add_instruction(migraphx::make_op("dot"), d1_tmbc, d2);
auto q3 = add_quantize_op(m1, "quantizelinear", dot_1, scale, zero);
auto d3 = add_quantize_op(m1, "dequantizelinear", q3, scale, zero);
auto dot_2 = m1.add_instruction(migraphx::make_op("dot"), d3, d1);
auto add = m1.add_instruction(migraphx::make_op("add"), {dot_2, d1});
m1.add_return({add});
}
migraphx::module m2;
{
auto t1 = m2.add_parameter("t1", sh1);
auto t2 = m2.add_parameter("t2", sh2);
auto scale = m2.add_literal(0.5f);
auto zero = m2.add_literal(std::int8_t{0});
auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero);
auto q1_t =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q1);
auto q1_tmb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {32, 32}}}), q1_t);
auto q1_tmbc = m2.add_instruction(migraphx::make_op("contiguous"), q1_tmb);
auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero);
auto dot_1 = m2.add_instruction(migraphx::make_op("quant_dot"), q1_tmbc, q2);
auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot_1->get_shape().lens());
auto d3 = add_quantize_op(m2, "dequantizelinear", dot_1, out_scale);
auto d3_q = add_quantize_op(m2, "quantizelinear", d3, scale, zero);
auto dot_2 = m2.add_instruction(migraphx::make_op("quant_dot"), d3_q, q1);
auto out_scale_2 = add_scale_mul(m2, scale, scale, 1, 1, dot_2->get_shape().lens());
auto d4 = add_quantize_op(m2, "dequantizelinear", dot_2, out_scale_2);
auto add = m2.add_instruction(migraphx::make_op("add"), d4, t1);
m2.add_return({add});
}
run_pass(m1);
EXPECT(m1 == m2);
}
TEST_CASE(conv) TEST_CASE(conv)
{ {
migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet) ...@@ -919,7 +975,6 @@ TEST_CASE(mobilenet_snippet)
auto mod1 = create_module(); auto mod1 = create_module();
auto mod2 = create_module(); auto mod2 = create_module();
run_pass(mod2); run_pass(mod2);
auto match_qdq = migraphx::match::name("dequantizelinear")( auto match_qdq = migraphx::match::name("dequantizelinear")(
......
...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[]) ...@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim", "test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>", "test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"}); "test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv); rv.run(argc, argv);
} }
...@@ -27,17 +27,21 @@ ...@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct quant_conv : verify_program<quant_conv> template <migraphx::shape::type_t DType>
struct quant_conv : verify_program<quant_conv<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}}; migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape); auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}}; migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape); auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc); mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p; return p;
} }
}; };
template struct quant_conv<migraphx::shape::int8_type>;
template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment