Unverified Commit d45b98cb authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into onnxruntime-sync-2023-09-01

parents 74657b8b d4cfdb3e
...@@ -1446,10 +1446,13 @@ struct find_split_transpose ...@@ -1446,10 +1446,13 @@ struct find_split_transpose
{ {
return; return;
} }
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size()); std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) { std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front(); return i->outputs().front();
}); });
......
...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr) ...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr); auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess) if(status != hipSuccess)
return false; return false;
return attr.memoryType == hipMemoryTypeDevice; return attr.type == hipMemoryTypeDevice;
} }
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
......
...@@ -647,8 +647,8 @@ struct mlir_program ...@@ -647,8 +647,8 @@ struct mlir_program
void set_gpu_properties(const context& migraphx_ctx) void set_gpu_properties(const context& migraphx_ctx)
{ {
const auto& device = migraphx_ctx.get_current_device(); const auto& device = migraphx_ctx.get_current_device();
target_arch = device.get_device_name(); target_arch = device.get_device_name();
num_cu = device.get_cu_count(); num_cu = device.get_cu_count();
} }
std::pair<std::size_t, std::size_t> get_launch_params() const std::pair<std::size_t, std::size_t> get_launch_params() const
...@@ -869,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx, ...@@ -869,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx,
adjust_param_shapes(m, to_shapes(inputs)); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl; std::cout << m << std::endl;
}
mlir_program mp; mlir_program mp;
mp.set_gpu_properties(migraphx_ctx); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution); auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs); co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
......
...@@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar) ...@@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb = auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
...@@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims) ...@@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}}); auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}}); auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb = auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb); auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm() ...@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>); TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
run_pass(m1);
auto m2 = m1;
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps) TEST_CASE(reorder_slice_ins_deps)
{ {
auto create_module = [] { auto create_module = [] {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment