"include/ck/utility/sequence.hpp" did not exist on "545d9305687c083717274171fdb22c74a5b5615e"
Unverified Commit bb3a181e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into reshape_tests

parents fb0fab94 c180f601
...@@ -43,8 +43,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -43,8 +43,8 @@ void rewrite_rnn::apply(program& prog) const
auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]); auto r_reverse = prog.insert_instruction(ins, op::slice{{0}, {1}, {2}}, args[2]);
// process bias // process bias
instruction_ref bias_forward, bias_reverse; instruction_ref bias_forward = prog.end();
bias_forward = bias_reverse = prog.end(); instruction_ref bias_reverse = prog.end();
if(args.size() >= 4 && args[3]->get_operator().name() != "undefined") if(args.size() >= 4 && args[3]->get_operator().name() != "undefined")
{ {
bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]); bias_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[3]);
...@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const ...@@ -53,7 +53,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument // process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored) // or the 5th one (if the sequence len argument is ignored)
instruction_ref ih_forward, ih_reverse; instruction_ref ih_forward{};
instruction_ref ih_reverse{};
if(args.size() == 6 && args[5]->get_operator().name() != "undefined") if(args.size() == 6 && args[5]->get_operator().name() != "undefined")
{ {
ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]); ih_forward = prog.insert_instruction(ins, op::slice{{0}, {0}, {1}}, args[5]);
...@@ -210,9 +211,10 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward, ...@@ -210,9 +211,10 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b); bias = prog.insert_instruction(ins, op::broadcast{1, sih->get_shape()}, b);
} }
instruction_ref hidden_out = prog.end(), last_out; instruction_ref hidden_out = prog.end();
last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih); instruction_ref last_out{};
std::size_t seq_len = input->get_shape().lens()[0]; last_out = prog.insert_instruction(ins, op::unsqueeze{{0, 1}}, sih);
std::size_t seq_len = input->get_shape().lens()[0];
for(std::size_t i = 0; i < seq_len; i++) for(std::size_t i = 0; i < seq_len; i++)
{ {
long seq_index = is_forward ? i : (seq_len - 1 - i); long seq_index = is_forward ? i : (seq_len - 1 - i);
......
...@@ -30,6 +30,7 @@ add_library(migraphx_device ...@@ -30,6 +30,7 @@ add_library(migraphx_device
device/concat.cpp device/concat.cpp
device/pad.cpp device/pad.cpp
device/gather.cpp device/gather.cpp
device/sub.cpp
) )
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device) set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_clang_tidy_check(migraphx_device) rocm_clang_tidy_check(migraphx_device)
......
#include <migraphx/gpu/device/sub.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
nary(stream, result, arg1, arg2)([](auto x, auto y) { return y - x; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SUB_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SUB_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sub(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#define MIGRAPHX_GUARD_RTGLIB_SUB_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/sub.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_sub : binary_device<hip_sub, device::sub>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/softmax.hpp> #include <migraphx/gpu/softmax.hpp>
#include <migraphx/gpu/add.hpp> #include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/sub.hpp>
#include <migraphx/gpu/exp.hpp> #include <migraphx/gpu/exp.hpp>
#include <migraphx/gpu/log.hpp> #include <migraphx/gpu/log.hpp>
#include <migraphx/gpu/sin.hpp> #include <migraphx/gpu/sin.hpp>
...@@ -76,6 +77,7 @@ struct miopen_apply ...@@ -76,6 +77,7 @@ struct miopen_apply
add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu); add_miopen_extend_op<miopen_elu, op::elu>("elu", make_elu);
add_generic_op<hip_add>("add"); add_generic_op<hip_add>("add");
add_generic_op<hip_sub>("sub");
add_generic_op<hip_exp>("exp"); add_generic_op<hip_exp>("exp");
add_generic_op<hip_log>("log"); add_generic_op<hip_log>("log");
add_generic_op<hip_sin>("sin"); add_generic_op<hip_sin>("sin");
......
...@@ -483,6 +483,38 @@ struct test_triadd_broadcast ...@@ -483,6 +483,38 @@ struct test_triadd_broadcast
} }
}; };
struct test_sub
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", s);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, z);
return p;
}
};
struct test_sub2
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape b{migraphx::shape::float_type, {3}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto z = p.add_parameter("z", b);
auto zb = p.add_instruction(migraphx::op::broadcast{1, s}, z);
auto diff = p.add_instruction(migraphx::op::sub{}, x, y);
p.add_instruction(migraphx::op::sub{}, diff, zb);
return p;
}
};
struct test_softmax struct test_softmax
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -1503,6 +1535,8 @@ int main() ...@@ -1503,6 +1535,8 @@ int main()
verify_program<test_add_broadcast4>(); verify_program<test_add_broadcast4>();
verify_program<test_add_broadcast5>(); verify_program<test_add_broadcast5>();
verify_program<test_triadd_broadcast>(); verify_program<test_triadd_broadcast>();
verify_program<test_sub>();
verify_program<test_sub2>();
verify_program<test_softmax>(); verify_program<test_softmax>();
verify_program<test_softmax2>(); verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
......
 subtraction2:q

0
1out"Sub subtraction2Z
0




Z
1


b
out




B
\ No newline at end of file
...@@ -346,7 +346,7 @@ TEST_CASE(add_bcast_test) ...@@ -346,7 +346,7 @@ TEST_CASE(add_bcast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(implicit_bcast_test) TEST_CASE(implicit_add_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
...@@ -360,6 +360,33 @@ TEST_CASE(implicit_bcast_test) ...@@ -360,6 +360,33 @@ TEST_CASE(implicit_bcast_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::broadcast{1, l0->get_shape()}, l1);
p.add_instruction(migraphx::op::sub{}, l0, l2);
auto prog = migraphx::parse_onnx("sub_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(implicit_sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3);
auto prog = migraphx::parse_onnx("implicit_sub_bcast_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test) TEST_CASE(unknown_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -714,6 +741,20 @@ TEST_CASE(add_scalar_test) ...@@ -714,6 +741,20 @@ TEST_CASE(add_scalar_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sub_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1}}, {1}});
auto m0 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto m1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, m0, m1);
auto prog = migraphx::parse_onnx("sub_scalar_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(group_conv_test) TEST_CASE(group_conv_test)
{ {
migraphx::program p; migraphx::program p;
......
 subtraction2:
/
0
1out"Sub*
axis*
broadcast subtraction2Z
0




Z
1


b
out




B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment