Commit b3c2e278 authored by Shucai Xiao

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into op_capture

parents f41abee5 0d796941
@@ -44,8 +44,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
     rm -rf /var/lib/apt/lists/*
 # Install cget
-# RUN pip install cget
-RUN pip install https://github.com/pfultz2/cget/archive/57b3289000fcdb3b7e424c60a35ea09bc44d8538.tar.gz
+RUN pip install cget
 # Install rclone
 RUN pip install https://github.com/pfultz2/rclone/archive/master.tar.gz
...
@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
-        // skip the reshape operator for now, since there is a bug
-        // for the transpose followed by a reshape
-        if(ins->name() == "reshape")
-        {
-            continue;
-        }
         // Make a copy so we can modify it while we iterate
         auto args = ins->inputs();
         for(auto arg : ins->inputs())
...
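Note: removing this skip appears safe now because `reshape::compute_shape` requires a standard input (see the `check_shapes` change below). When this pass tentatively drops a `contiguous` feeding a reshape of a transposed tensor, shape computation fails and the `contiguous` is kept, so the old transpose-then-reshape bug can no longer be triggered.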
@@ -21,6 +21,14 @@ bool disabled(const char* name)
     return contains({"0", "disable", "disabled", "no", "false"}, e.front());
 }
 
+std::size_t value_of(const char* name)
+{
+    auto e = env(name);
+    if(e.empty())
+        return 0;
+    return std::stoul(e.front());
+}
+
 std::vector<std::string> env(const char* name)
 {
     auto p = std::getenv(name);
...
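For readers unfamiliar with these helpers: `value_of` turns an environment variable into an unsigned integer, treating an unset variable as 0 and delegating parsing to `std::stoul` (so a non-numeric value throws). A self-contained sketch of the same behavior, not the actual implementation:

```cpp
#include <cstdlib>
#include <iostream>
#include <string>

// Standalone approximation of the value_of helper added above.
std::size_t value_of_sketch(const char* name)
{
    const char* p = std::getenv(name);
    if(p == nullptr)
        return 0;         // unset -> 0, matching the empty() check above
    return std::stoul(p); // non-numeric values throw std::invalid_argument
}

int main()
{
    setenv("MIGRAPHX_TRACE_EVAL", "2", 1); // POSIX; variable name taken from this commit
    std::cout << value_of_sketch("MIGRAPHX_TRACE_EVAL") << std::endl; // prints 2
}
```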
@@ -103,6 +103,13 @@ struct check_shapes
         return *this;
     }
 
+    const check_shapes& standard_or_scalar() const
+    {
+        if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
+            MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
+        return *this;
+    }
+
     const check_shapes& packed() const
     {
         if(!this->all_of([](const shape& s) { return s.packed(); }))
...
@@ -19,6 +19,8 @@ bool enabled(const char* name);
 bool disabled(const char* name);
 std::vector<std::string> env(const char* name);
 
+std::size_t value_of(const char* name);
+
 template <class T>
 bool enabled(T)
 {
@@ -33,6 +35,13 @@ bool disabled(T)
     return result;
 }
 
+template <class T>
+std::size_t value_of(T)
+{
+    static const std::size_t result = value_of(T::value());
+    return result;
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
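The template overload caches the parsed value per environment-variable tag type, mirroring the existing `enabled`/`disabled` overloads. Any type with a static `value()` returning the variable's name works; a hypothetical tag for illustration (the repository normally declares real ones via its `MIGRAPHX_DECLARE_ENV_VAR` macro):

```cpp
// Hypothetical tag type; the string returned by value() is the variable looked up.
struct MY_TRACE_VAR
{
    static const char* value() { return "MY_TRACE_VAR"; }
};

// The first call reads and parses the environment; the static local caches the
// result, so later calls return the same value even if the environment changes.
std::size_t level = migraphx::value_of(MY_TRACE_VAR{});
```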
@@ -29,7 +29,7 @@ struct reshape
     std::string name() const { return "reshape"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1);
+        check_shapes{inputs, *this}.has(1).standard();
         auto&& idims = inputs.front().lens();
         std::vector<std::size_t> rdims(dims.begin(), dims.end());
         auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
...
@@ -29,6 +29,7 @@ struct squeeze
     std::string name() const { return "squeeze"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
+        check_shapes{inputs, *this}.has(1).standard();
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
...
@@ -29,6 +29,7 @@ struct unsqueeze
     std::string name() const { return "unsqueeze"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
+        check_shapes{inputs, *this}.has(1).standard_or_scalar();
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
...
@@ -27,7 +27,8 @@ struct raw_data : raw_data_base
     template <class Stream>
     friend Stream& operator<<(Stream& os, const Derived& d)
     {
-        d.visit([&](auto x) { os << x; });
+        if(not d.empty())
+            d.visit([&](auto x) { os << x; });
         return os;
     }
@@ -40,8 +41,11 @@ struct raw_data : raw_data_base
     template <class Visitor>
     void visit_at(Visitor v, std::size_t n = 0) const
     {
-        auto&& s      = static_cast<const Derived&>(*this).get_shape();
-        auto&& buffer = static_cast<const Derived&>(*this).data();
+        auto&& derived = static_cast<const Derived&>(*this);
+        if(derived.empty())
+            MIGRAPHX_THROW("Visiting empty data!");
+        auto&& s      = derived.get_shape();
+        auto&& buffer = derived.data();
         s.visit_type([&](auto as) { v(*(as.from(buffer) + s.index(n))); });
     }
@@ -55,8 +59,11 @@ struct raw_data : raw_data_base
     template <class Visitor>
     void visit(Visitor v) const
     {
-        auto&& s      = static_cast<const Derived&>(*this).get_shape();
-        auto&& buffer = static_cast<const Derived&>(*this).data();
+        auto&& derived = static_cast<const Derived&>(*this);
+        if(derived.empty())
+            MIGRAPHX_THROW("Visiting empty data!");
+        auto&& s      = derived.get_shape();
+        auto&& buffer = derived.data();
         s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
     }
...
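These guards change the failure mode for empty data: streaming an empty object now prints nothing, while visiting one throws instead of dereferencing an invalid buffer. A sketch of the intended behavior, assuming `migraphx::argument` is one of the `Derived` types and is empty when default-constructed:

```cpp
migraphx::argument a{};      // no shape or buffer attached, so a.empty() is true
std::cout << a << std::endl; // OK: operator<< skips the visit entirely
a.visit([](auto) {});        // throws MIGRAPHX_THROW("Visiting empty data!")
```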
@@ -12,6 +12,14 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+template <class T>
+T as_number(T x)
+{
+    return x;
+}
+inline int32_t as_number(int8_t x) { return static_cast<int32_t>(x); }
+inline uint32_t as_number(uint8_t x) { return static_cast<uint32_t>(x); }
+
 template <class T>
 struct tensor_view
 {
@@ -130,10 +138,10 @@ struct tensor_view
     {
         if(!x.empty())
         {
-            os << x.front();
+            os << as_number(x.front());
             for(std::size_t i = 1; i < x.m_shape.elements(); i++)
             {
-                os << ", " << x.m_data[x.m_shape.index(i)];
+                os << ", " << as_number(x.m_data[x.m_shape.index(i)]);
             }
         }
         return os;
...
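The `as_number` shims exist because `std::ostream` streams `int8_t`/`uint8_t` through its `char` overloads, so int8 tensor elements would print as raw characters; promoting to 32-bit first prints them as integers. A minimal standalone illustration:

```cpp
#include <cstdint>
#include <iostream>

int main()
{
    int8_t x = 65;
    std::cout << x << std::endl;                       // prints "A" (char overload)
    std::cout << static_cast<int32_t>(x) << std::endl; // prints "65", as as_number does
}
```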
@@ -437,13 +437,20 @@ argument program::eval(std::unordered_map<std::string, argument> params) const
 #else
     auto check_context = [](auto f) { return f(); };
 #endif
-    if(enabled(MIGRAPHX_TRACE_EVAL{}))
+    auto trace_level = value_of(MIGRAPHX_TRACE_EVAL{});
+    if(trace_level > 0)
     {
         return generic_eval(*this, ctx, std::move(params), [&](auto& ins, auto f) {
             ctx.finish();
             std::cout << "Run instruction: ";
             this->debug_print(ins);
-            return check_context(f);
+            auto result = check_context(f);
+            ctx.finish();
+            if(trace_level > 1 and ins->name().front() != '@' and ins->name() != "load")
+                std::cout << "Output: " << result << std::endl;
+            return result;
         });
     }
     else
...
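With this change `MIGRAPHX_TRACE_EVAL` becomes a level rather than a flag: any nonzero value (e.g. `MIGRAPHX_TRACE_EVAL=1`) prints each instruction before it runs, and a level of 2 or more also prints each instruction's output buffer, skipping `@`-prefixed internal instructions and `load`s, whose results are aliased memory rather than computed values. The extra `ctx.finish()` presumably synchronizes the context so the printed result is complete before the next instruction runs.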
@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
     // clang-format off
     static const std::unordered_set<std::string> names = {
         "reshape",
-        "contiguous"
+        "contiguous",
+        "squeeze",
+        "unsqueeze"
     };
     // clang-format on
     return contains(names, ins->name());
@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
     auto end = std::prev(p.end());
     for(auto ins : iterator_for(p))
     {
+        if(ins == end and ins->name() == "contiguous")
+            continue;
+        // Skip possible dead instructions
         if(ins->outputs().empty() and ins != end)
             continue;
         if(is_reshaper(ins))
@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
             p.replace_instruction(ins, t->inputs().front());
         }
     }
-    // Replace all reshapes with as_shape
-    for(auto ins : iterator_for(p))
-    {
-        if(ins->name() != "reshape")
-            continue;
-        p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
-    }
 }
 } // namespace MIGRAPHX_INLINE_NS
...
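Two behavioral notes on this pass change: `squeeze` and `unsqueeze` now participate in reshaper-chain simplification, and a `contiguous` that produces the program's final output is left alone, so simplification cannot change the layout of the result handed back to the caller. The final loop that rewrote every `reshape` into `as_shape` is removed entirely.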
@@ -51,7 +51,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         propagate_constant{},
         dead_code_elimination{},
         auto_contiguous{},
-        // simplify_reshapes{},
+        simplify_reshapes{},
         dead_code_elimination{},
         lowering{ctx},
         eliminate_concat{concat_gpu_optimization{}},
...
@@ -53,15 +53,16 @@ struct tf_parser
     template <class T>
     std::vector<T> parse_axes(std::vector<T> axes) const
     {
-        std::vector<T> new_axes;
         if(is_nhwc)
         {
+            std::vector<T> new_axes;
             std::transform(axes.begin(),
                            axes.end(),
                            std::back_inserter(new_axes),
                            [&](size_t axis) { return parse_axis(axis); });
+            return new_axes;
         }
-        return new_axes;
+        return axes;
     }
 
     // tf stores certain attributes such as strides, dilations, as a 4D input.
@@ -392,7 +393,9 @@ struct tf_parser
         int64_t out_channels = num_channels * multiplier;
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, weights);
+        // Make sure weights are contiguous before doing reshape
+        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
+        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
         return prog.add_instruction(op, {args[0], new_weights});
     }
@@ -426,17 +429,21 @@ struct tf_parser
     instruction_ref
     parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         auto axes      = parse_axes(args[1]->eval().get<int32_t>().to_vector());
         bool keep_dims = attributes.at("keep_dims").b();
         std::vector<int32_t> hw_axes{2, 3};
-        if(axes == hw_axes and keep_dims)
+        // check if conditions for GlobalAvgPool are met
+        auto lens = args[0]->get_shape().lens();
+        if(axes == hw_axes and lens.size() == 4)
         {
             op::pooling op{"average"};
-            std::vector<size_t> input_dims{args[0]->get_shape().lens()};
-            op.lengths[0] = input_dims[2];
-            op.lengths[1] = input_dims[3];
-            return prog.add_instruction(op, args.front());
+            op.lengths[0] = lens[2];
+            op.lengths[1] = lens[3];
+            auto l0       = prog.add_instruction(op, args.front());
+            if(keep_dims)
+                return l0;
+            return prog.add_instruction(
+                op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
         }
         MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
     }
...
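For context on the axis handling: in NHWC layout the height and width reduction axes are {1, 2}, which `parse_axes` remaps to the NCHW positions {2, 3} before the GlobalAvgPool check above. A minimal sketch of that remapping under the permutation NHWC -> NCHW (the real `parse_axis` implementation may differ):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical stand-in for tf_parser::parse_axis: map an NHWC axis index to
// its NCHW position using the permutation N,H,W,C -> N,C,H,W.
std::size_t parse_axis_sketch(std::size_t axis)
{
    static const std::vector<std::size_t> nhwc_to_nchw = {0, 2, 3, 1};
    return nhwc_to_nchw[axis];
}

int main()
{
    for(std::size_t a : {1, 2})                   // H and W in NHWC
        std::cout << parse_axis_sketch(a) << ' '; // prints "2 3": H and W in NCHW
}
```

This matches the `mean_test_nhwc` case below, where the literal axes {1, 2} are expected to become the pooling axes {2, 3}.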
@@ -1251,22 +1251,6 @@ struct test_contiguous : verify_program<test_contiguous>
     }
 };
 
-struct test_eliminate_contiguous : verify_program<test_eliminate_contiguous>
-{
-    migraphx::program create_program() const
-    {
-        migraphx::program p;
-        migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 5}};
-        auto seq = p.add_parameter("seq", s);
-        std::vector<int64_t> perm{0, 2, 1, 3};
-        auto tran_seq = p.add_instruction(migraphx::op::transpose{perm}, seq);
-        std::vector<int64_t> out_shape{0, 0, -1};
-        p.add_instruction(migraphx::op::reshape{out_shape}, tran_seq);
-        return p;
-    }
-};
-
 struct test_transpose : verify_program<test_transpose>
 {
     migraphx::program create_program() const
...
@@ -136,8 +136,9 @@ TEST_CASE(depthwiseconv_test)
     op.group = 3;
     auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
     auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
-    auto l4 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l3);
-    p.add_instruction(op, l0, l4);
+    auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
+    auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
+    p.add_instruction(op, l0, l5);
     auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
     EXPECT(p == prog);
@@ -168,6 +169,40 @@ TEST_CASE(matmul_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(mean_test)
+{
+    migraphx::program p;
+    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
+    p.add_literal(l);
+    p.add_literal(l);
+    migraphx::op::pooling op;
+    op.lengths = {16, 16};
+    auto l3 = p.add_instruction(op, l0);
+    p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
+    p.add_instruction(op, l0);
+
+    auto prog = migraphx::parse_tf("mean_test.pb", false);
+
+    EXPECT(p == prog);
+}
+
+TEST_CASE(mean_test_nhwc)
+{
+    migraphx::program p;
+    migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
+    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
+    p.add_literal(l);
+    p.add_literal(l);
+    migraphx::op::pooling op;
+    op.lengths = {16, 16};
+    auto l3 = p.add_instruction(op, l0);
+    p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
+    p.add_instruction(op, l0);
+
+    auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(mul_test)
 {
     migraphx::program p;
...