Unverified Commit 450c5e84 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fold const on last instruction (#1626)

This is the original testcase that sparked the error with missing proper const
folding. Pushing changes up to this branch and closing out the PR #1622
parent aafaeae3
...@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -44,7 +44,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
...@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const ...@@ -54,14 +54,23 @@ void propagate_constant::apply(module& m) const
// Find instructions that can be evaluated to a literal // Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m)) for(auto i : iterator_for(m))
{ {
if(is_const(i) and i != last) const bool is_const = is_const_ins(i);
if(is_const and i != last)
continue; continue;
std::copy_if( if(i == last and is_const)
i->inputs().begin(), {
i->inputs().end(), const_instrs.insert(i);
std::inserter(const_instrs, const_instrs.begin()), }
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; }); else
{
std::copy_if(i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) {
return is_const_ins(ins) and ins->name() != "@literal";
});
}
} }
// Compute literals in parallel // Compute literals in parallel
......
...@@ -163,4 +163,26 @@ TEST_CASE(const_dot) ...@@ -163,4 +163,26 @@ TEST_CASE(const_dot)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(last_const)
{
const std::vector<float> vec = {1.0f, 2.0f, 1.0f, 2.0f};
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}};
auto l = m1.add_literal(migraphx::literal(s, vec));
m1.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), l);
}
run_pass(m1);
migraphx::module m2;
{
migraphx::shape s{migraphx::shape::half_type, {2, 2}};
auto l = m2.add_literal(migraphx::literal(s, vec));
m2.add_instruction(migraphx::make_op("identity"), l);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_gather_literal_inputs : verify_program<test_gather_literal_inputs>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape g_shape{migraphx::shape::int32_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment