Commit 38a196f6 authored by charlie's avatar charlie
Browse files

Revert "Fix eliminate_contiguous pass"

This reverts commit 52113ea0.
parent 52113ea0
......@@ -42,13 +42,6 @@ static bool try_compute_shape(instruction_ref ins,
try
{
shape new_shape = ins->get_operator().compute_shape(inputs, mods);
// Cannot tell if a dynamic shape will need to be made contiguous
if(new_shape.dynamic())
{
return false;
}
// If the output shape is a standard shape, no need to try its output
if(new_shape.standard())
{
......@@ -143,12 +136,9 @@ static void remove_contiguous(const std::string& op_name, module& m, F f)
// Perform evaluations in parallel
std::vector<argument> literals(const_instructions.size());
par_for(const_instructions.size(), 1, [&](const auto i) {
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
std::vector<shape> prev_shape = {prev->get_shape()};
const std::vector<argument>& prev_eval = {prev->eval()};
auto co_shape = make_compute_output_shape(pack(c, prev_shape, prev_eval));
literals[i] = c.compute(co_shape, {prev->eval()});
auto c = op::contiguous{};
auto prev = const_instructions[i]->inputs().front();
literals[i] = c.compute(c.compute_shape({prev->get_shape()}), {prev->eval()});
});
for(size_t i = 0; i < const_instructions.size(); i++)
......
......@@ -28,7 +28,6 @@
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -43,31 +42,19 @@ namespace op {
struct contiguous
{
std::string name() const { return "contiguous"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs.front();
if(s0.dynamic())
{
return s0;
}
else
{
if(s0.standard())
{
return inputs.front();
}
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
}
check_shapes{inputs, *this}.has(1);
if(inputs.front().standard())
return inputs.front();
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return {t, lens};
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
argument compute(const shape& output_shape, std::vector<argument> args) const
{
assert(dyn_out.computed_shape.standard());
argument result{dyn_out.computed_shape};
assert(output_shape.standard());
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(output.get_shape(), [&](const auto& idx) {
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end());
......
......@@ -49,8 +49,8 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
// auto_contiguous{},
// dead_code_elimination{},
auto_contiguous{},
dead_code_elimination{},
lowering{},
dead_code_elimination{}};
}
......
......@@ -357,12 +357,6 @@ TEST_CASE(contiguous_shape)
expect_shape(single, migraphx::make_op("contiguous"), single);
}
TEST_CASE(contiguous_dyn_shape)
{
migraphx::shape s0{migraphx::shape::float_type, {{1, 4, 0}, {2, 2, 2}}};
expect_shape(s0, migraphx::make_op("contiguous"), s0);
}
TEST_CASE(contiguous_shape_scalar)
{
migraphx::shape output{migraphx::shape::float_type};
......
......@@ -918,88 +918,11 @@ TEST_CASE(contiguous_test)
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::cout << "results_vector: [";
for(auto r : results_vector)
std::cout << r << ", ";
std::cout << "]\n";
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(contiguous_param_test)
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("X", a_shape);
mm->add_instruction(migraphx::make_op("contiguous"), a);
p.compile(migraphx::ref::target{});
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(a_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::cout << "results_vector: [";
for(auto r : results_vector)
std::cout << r << ", ";
std::cout << "]\n";
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
}
TEST_CASE(contiguous_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape dyn_shape{migraphx::shape::float_type,
{{1, 1, 0}, {2, 6, 0}, {2, 2, 0}, {2, 2, 0}}};
auto input = mm->add_parameter("X", dyn_shape);
mm->add_instruction(migraphx::make_op("contiguous"), input);
p.compile(migraphx::ref::target{});
migraphx::shape static_shape{migraphx::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12);
std::iota(data.begin(), data.end(), 0);
migraphx::parameter_map params;
params["X"] = migraphx::argument(static_shape, data.data());
auto result = p.eval(params).back();
result.visit([&](auto output) {
std::vector<size_t> new_strides = {12, 4, 2, 1};
EXPECT(bool{output.get_shape().strides() == new_strides});
});
std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::cout << "results_vector: [";
for(auto r : results_vector)
std::cout << r << ", ";
std::cout << "]\n";
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(conv_dynamic_batch_test)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment