Commit 74bd6d61 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-concat' into jit-concat-pointwise

parents b30c3408 8109aac8
......@@ -25,8 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_GPU_SYNC_DEVICE_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
......
......@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -27,13 +27,6 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -27,8 +27,6 @@
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
......
......@@ -26,16 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,7 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -24,16 +24,8 @@
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -26,15 +26,7 @@
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -27,42 +27,24 @@
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/abs.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/greater.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/less.hpp>
#include <migraphx/gpu/logical_and.hpp>
#include <migraphx/gpu/logical_or.hpp>
#include <migraphx/gpu/logical_xor.hpp>
#include <migraphx/gpu/lrn.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/unary_not.hpp>
#include <migraphx/gpu/where.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/program.hpp>
......
......@@ -23,6 +23,7 @@
*/
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
......
......@@ -41,8 +41,9 @@ struct parse_relu6 : op_parser<parse_relu6>
const tf_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto min_val = info.add_literal(0.0f);
auto max_val = info.add_literal(6.0f);
shape::type_t output_type = args[0]->get_shape().type();
auto min_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {0.0f}});
auto max_val = info.add_literal(migraphx::literal{migraphx::shape{output_type}, {6.0f}});
return info.add_common_op("clip", args[0], min_val, max_val);
}
......
......@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/dead_code_elimination.hpp"
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
......
......@@ -144,7 +144,7 @@ TEST_CASE(conv)
{
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
return %0 : tensor<1x2x2x2xf32>
}
......@@ -167,7 +167,7 @@ TEST_CASE(conv_add_relu)
{
const std::string mlir_output = R"__migraphx__(
module {
func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
%0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1], use_dynamic_same_auto_pad = 0 : i64} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
%1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
......
......@@ -21,7 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/instruction_ref.hpp"
#include <migraphx/instruction_ref.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/target.hpp>
......
......@@ -38,7 +38,6 @@
#include <migraphx/onnx.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/reshape.hpp>
......
......@@ -2077,6 +2077,55 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_len_1)
{
migraphx::module m1;
{
migraphx::shape s{migraphx::shape::float_type, {1, 128, 3}};
auto input = m1.add_parameter("input", s);
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {1}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1}}, {"ends", {2}}}), input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {2}}, {"ends", {3}}}), input);
auto c0 = m1.add_instruction(migraphx::make_op("contiguous"), slc0);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), slc2);
std::vector<int64_t> lens = {1, 128};
auto r0 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c0);
auto r1 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c1);
auto r2 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), c2);
auto sum = m1.add_instruction(migraphx::make_op("add"), r0, r1);
auto ret = m1.add_instruction(migraphx::make_op("mul"), sum, r2);
m1.add_return({ret});
};
migraphx::module m2;
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 128, 3}};
auto input = m2.add_parameter("input", s);
std::vector<int64_t> lens = {1, 384};
auto rsp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", lens}}), input);
auto slc0 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {128}}}), rsp);
auto slc1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {128}}, {"ends", {256}}}), rsp);
auto slc2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {256}}, {"ends", {384}}}), rsp);
auto sum = m2.add_instruction(migraphx::make_op("add"), slc0, slc1);
auto ret = m2.add_instruction(migraphx::make_op("mul"), sum, slc2);
m2.add_return({ret});
};
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_reshape_slice_not_apply)
{
auto create_p = [] {
......
......@@ -48,6 +48,26 @@ inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx:
return result;
}
migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
const std::vector<size_t>& mbcast_lens,
const int axis)
{
migraphx::module m;
auto s = migraphx::shape{migraphx::shape::float_type, in_lens};
auto x = m.add_parameter("x", s);
auto y = m.add_parameter("y", s);
auto z = m.add_parameter("z", s);
auto xm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), x);
auto ym =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), y);
auto zm =
m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", mbcast_lens}}), z);
auto concat = m.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), xm, ym, zm);
m.add_return({concat});
return m;
}
TEST_CASE(double_contig)
{
migraphx::program p;
......@@ -337,6 +357,87 @@ TEST_CASE(nop_convert)
EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}
TEST_CASE(concat_multibroadcasts1)
{
// Broadcasted batch dim, new axis < old axis
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
TEST_CASE(concat_multibroadcasts2)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 0);
}
TEST_CASE(concat_multibroadcasts3)
{
// Broadcasted middle dim, new axis == old axis
std::vector<std::size_t> in_lens = {3, 1, 4};
std::vector<std::size_t> mbcast_lens = {3, 2, 4};
const int axis = 2;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto out_shape = m.get_output_shapes().back();
auto n = std::distance(m.begin(), m.end());
run_pass(m);
EXPECT(m.get_output_shapes().back().lens() == out_shape.lens());
EXPECT(std::distance(m.begin(), m.end()) == n - 2);
auto new_concat =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != m.end()});
auto cd = std::distance(m.begin(), new_concat);
auto new_mb =
std::find_if(m.begin(), m.end(), [](auto ins) { return ins.name() == "multibroadcast"; });
auto md = std::distance(m.begin(), new_mb);
EXPECT(cd == md - 1);
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 2);
}
TEST_CASE(concat_multibroadcasts4)
{
// Broadcasted batch dim, axis is broadcasted dim
std::vector<std::size_t> in_lens = {3, 4};
std::vector<std::size_t> mbcast_lens = {2, 3, 4};
const int axis = 0;
auto m = make_concat_multibroadcast(in_lens, mbcast_lens, axis);
auto m1 = m;
run_pass(m);
EXPECT(m1 == m);
}
TEST_CASE(concat_transpose1)
{
migraphx::module m;
......
......@@ -495,10 +495,10 @@ def relu6_test(g1):
@tf_test
def relu6_mismatch_test(g1):
def relu6_half_test(g1):
with g1.as_default():
g1_input = tf.compat.v1.placeholder(tf.float16,
shape=(1, 3, 13, 37),
shape=(1, 3, 16, 16),
name='0')
tf.nn.relu6(g1_input, 'relu6')
......@@ -708,7 +708,7 @@ if __name__ == '__main__':
pow_test()
relu_test()
relu6_test()
relu6_mismatch_test()
relu6_half_test()
reshape_test()
rsqrt_test()
shape_test()
......
......@@ -2,7 +2,7 @@
:
0 Placeholder*
dtype0*
shape: %
shape:

relu6Relu60*
T0"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment