"...resnet50_tensorflow.git" did not exist on "d55a62919f070ef1d9b4bad8a118c34605ba4427"
Unverified commit b3a610df, authored by Krzysztof Drewniak, committed by GitHub

Update list of pointwise operators supported by MLIR (#1848)

Bump MLIR commit to include latest supported pointwise ops.
Expand the MLIR approve list.
Ensure that operations such as tanh() that lack integer implementations (at least in MLIR) are not used within MLIR modules.
Add additional tests.
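As a rough illustration of the approve-list idea (a simplified sketch, not the MIGraphX API; the enum, helper, and operator names here are stand-ins for the real `is_pointwise_op_supported_by_mlir` introduced below), the check boils down to gating each operator name on the result type, with float-only lowerings such as tanh rejected for integer results:

#include <algorithm>
#include <initializer_list>
#include <iostream>
#include <string>

// Simplified stand-in for shape::type_t.
enum class dtype { f32, f16, i8, i32, b1 };

template <class T, class U>
bool contains(std::initializer_list<T> list, const U& x)
{
    return std::find(list.begin(), list.end(), x) != list.end();
}

// tanh, exp, etc. have no integer lowering in MLIR, so they are approved
// only when the result type is floating point.
bool approved(const std::string& op, dtype result)
{
    const std::initializer_list<std::string> fp_only = {"tanh", "exp", "sigmoid"};
    if(contains(fp_only, op))
        return contains({dtype::f32, dtype::f16}, result);
    // Everything else in this sketch is gated only by a plain name list.
    return contains({"add", "mul", "dot"}, op);
}

int main()
{
    std::cout << approved("tanh", dtype::i32) << '\n'; // 0: rejected
    std::cout << approved("tanh", dtype::f16) << '\n'; // 1: approved
}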
@@ -191,6 +191,68 @@ struct find_mlir_op
        return {new_gemm_based_op, top_inputs};
    }
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool is_pointwise_op_supported_by_mlir(const instruction& i) const
{
using type_t = shape::type_t;
const auto& name = i.name();
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
// Preliminary type check.
if(not contains(allowed_types, result_type))
{
return false;
}
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {"convolution",
"quant_convolution",
"dot",
"quant_dot",
"add",
"clip",
"sub",
"mul",
"div",
"pow",
"where",
"quantizelinear",
"dequantizelinear",
"abs",
"neg"};
        const std::initializer_list<std::string> fp_only_ops = {"ceil",
                                                                "erf",
                                                                "exp",
                                                                "floor",
                                                                "log",
                                                                "recip",
                                                                "rsqrt",
                                                                "sigmoid",
                                                                "softmax",
                                                                "tanh"};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type && contains(no_bool_ops, name))
return true;
if(is_float && contains(fp_only_ops, name))
return true;
        // Only conversions between floating-point types are known to be
        // unambiguously supported.
if(is_float && name == "convert")
{
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
}
return false;
}
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins = r.result;
@@ -198,32 +260,12 @@ struct find_mlir_op
        auto x_ins = r.instructions["x"]; // input after contiguous
        auto* pm   = ins->module_inputs().front();
        auto names = pm->get_parameter_names();
-        // Whitelist pointwise operators
-        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-               return not contains({"@literal",
-                                    "@param",
-                                    "@return",
-                                    "convolution",
-                                    "quant_convolution",
-                                    "dot",
-                                    "quant_dot",
-                                    "add",
-                                    "relu",
-                                    "dequantizelinear",
-                                    "quantizelinear",
-                                    "mul"},
-                                   i.name());
-           }))
-            return;
-        // Only fuse with fp32/fp16/int8/int32
-        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type,
-                                    shape::type_t::half_type,
-                                    shape::type_t::int8_type,
-                                    shape::type_t::int32_type},
-                                   i->get_shape().type());
-           }))
-            return;
+        // Whitelist pointwise operators.
+        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
+               return not is_pointwise_op_supported_by_mlir(i);
+           }))
+            return;
        std::sort(names.begin(), names.end());
        module_ref mm = mpm.create_module("mlir_" + pm->name());
        mm->set_bypass();
...
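One small detail in the apply() hunk above: the lambda's capture list changes from [] to [&] because its body now calls the member function is_pointwise_op_supported_by_mlir, which requires capturing `this`. A minimal sketch of the same pattern (hypothetical names, not MIGraphX code):

#include <algorithm>
#include <iostream>
#include <iterator>

struct matcher
{
    bool supported(int x) const { return x > 0; }

    bool all_supported() const
    {
        int vals[] = {1, -2, 3};
        // An empty capture list [] would not compile here: calling a member
        // function inside the lambda needs `this`, which [&] captures implicitly.
        return std::all_of(std::begin(vals), std::end(vals),
                           [&](int v) { return supported(v); });
    }
};

int main() { std::cout << matcher{}.all_supported() << '\n'; } // prints 0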
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
}
template <class F>
migraphx::instruction_ref add_mlir(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
std::vector<std::string> arg_names,
F f)
{
assert(inputs.size() == arg_names.size() && "One interior parameter name given per input.");
auto* mm = p.get_main_module();
auto* pm = p.create_module(name);
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
for(size_t i = 0, e = inputs.size(); i < e; ++i)
{
params.push_back(pm->add_parameter(arg_names[i], inputs[i]->get_shape()));
}
auto values = f(pm, params);
auto root = std::get<0>(values);
auto r = std::get<1>(values);
pm->add_return({r});
return mm->add_instruction(
migraphx::make_op("gpu::mlir_op", {{"op", migraphx::to_value(root->get_operator())}}),
inputs,
{pm});
}
TEST_CASE(dot_add)
{
migraphx::shape s{migraphx::shape::float_type, {1, 3, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = add_pointwise(p1, "main:pointwise0", {dot, x}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s);
auto b = mm->add_parameter("b", s);
auto x = mm->add_parameter("x", s);
auto fused =
add_mlir(p2,
"mlir_main:pointwise0",
{x, a, b},
{"x1", "y0", "y1"},
[=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("dot"), inputs[1], inputs[2]);
auto add = pm->add_instruction(migraphx::make_op("add"), dot, inputs[0]);
return std::make_tuple(dot, add);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_abs)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto abs = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("abs"));
mm->add_return({abs});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto fused = add_mlir(
p2, "mlir_main:pointwise0", {a, b}, {"y0", "y1"}, [=](auto* pm, const auto& inputs) {
auto dot =
pm->add_instruction(migraphx::make_op("quant_dot"), inputs[0], inputs[1]);
auto abs = pm->add_instruction(migraphx::make_op("abs"), dot);
return std::make_tuple(dot, abs);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(int_quant_dot_tanh_fails)
{
migraphx::shape s_a{migraphx::shape::int8_type, {5, 4}};
migraphx::shape s_b{migraphx::shape::int8_type, {4, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto a = mm->add_parameter("a", s_a);
auto b = mm->add_parameter("b", s_b);
auto dot = mm->add_instruction(migraphx::make_op("quant_dot"), a, b);
auto tanh = add_pointwise(p1, "main:pointwise0", {dot}, single_pointwise("tanh"));
mm->add_return({tanh});
}
migraphx::program p2(p1);
// This pass should do nothing as int32_t tanh isn't supported.
run_pass(p1);
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
test::run(argc, argv);
return 0;
}
@@ -273,4 +273,57 @@ module {
    EXPECT(verify_mlir(m));
}
TEST_CASE(dot_convert)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
return %1 : tensor<1x5x3xf16>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto trunc = m.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), dot);
m.add_return({trunc});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
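The dot_convert case exercises the new convert rule in is_pointwise_op_supported_by_mlir: a convert is approved only when the result and every input are floating point. A standalone sketch of that rule (simplified types and names, not the MIGraphX API):

#include <algorithm>
#include <iostream>
#include <vector>

enum class dtype { f32, f16, i8, i32, b1 };

bool is_float(dtype t) { return t == dtype::f32 or t == dtype::f16; }

// Mirrors the gating logic: float result and all-float inputs.
bool convert_approved(dtype result, const std::vector<dtype>& inputs)
{
    return is_float(result) and std::all_of(inputs.begin(), inputs.end(), is_float);
}

int main()
{
    std::cout << convert_approved(dtype::f16, {dtype::f32}) << '\n'; // 1: f32 -> f16 ok
    std::cout << convert_approved(dtype::f32, {dtype::i32}) << '\n'; // 0: int source rejected
}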
TEST_CASE(dot_where)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @mlir_dot(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
%1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
return %1 : tensor<1x5x3xf32>
}
}
)__migraphx__";
migraphx::module m;
auto arg0 = m.add_parameter("arg0", {migraphx::shape::float_type, {1, 5, 4}});
auto arg1 = m.add_parameter("arg1", {migraphx::shape::float_type, {1, 4, 3}});
auto arg2 = m.add_parameter("arg2", {migraphx::shape::bool_type, {1, 5, 3}});
auto arg3 = m.add_parameter("arg3", {migraphx::shape::float_type, {1, 5, 3}});
auto dot = m.add_instruction(migraphx::make_op("dot"), arg0, arg1);
auto where = m.add_instruction(migraphx::make_op("where"), arg2, dot, arg3);
m.add_return({where});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }