Commit 4b7a267a authored by Paul's avatar Paul
Browse files

Merge from develop

parents 92803edf af00eea8
This diff is collapsed.
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_BATCHNORM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/batch_norm.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONCAT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/concat.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONTIGUOUS_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/contiguous.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
argument allocate_gpu(const shape& s, bool host = false);
argument to_gpu(const argument& arg, bool host = false);
......
......@@ -4,7 +4,7 @@
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
......
......@@ -2,7 +2,9 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_MIOPEN_HPP
#include <migraphx/manage_ptr.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <miopen/miopen.h>
#include <migraphx/config.hpp>
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/pad.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_POOLING_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......
......@@ -2,7 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/softmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/logsoftmax.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <utility>
......
......@@ -20,6 +20,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
namespace migraphx {
......@@ -36,6 +37,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
{
dead_code_elimination{},
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
fwd_conv_batchnorm_rewrite{},
dead_code_elimination{},
rewrite_rnn{},
......
......@@ -119,6 +119,7 @@ struct tf_parser
add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
add_mem_op("MaxPool", &tf_parser::parse_pooling);
add_mem_op("Mean", &tf_parser::parse_mean);
add_mem_op("Pack", &tf_parser::parse_pack);
add_mem_op("Pad", &tf_parser::parse_pad);
add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax);
......@@ -353,6 +354,23 @@ struct tf_parser
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
instruction_ref parse_pack(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
// reinterpret as unsqueeze with concat
std::vector<instruction_ref> unsqueezed_args;
int64_t axis = 0;
if(contains(attributes, "axis"))
axis = attributes.at("axis").i();
std::transform(
args.begin(),
args.end(),
std::back_inserter(unsqueezed_args),
[&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
}
instruction_ref
parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
{
......
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......
#include <migraphx/constant_propagate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment