Unverified commit d45b98cb authored by Chris Austen, committed by GitHub

Merge branch 'develop' into onnxruntime-sync-2023-09-01

parents 74657b8b d4cfdb3e
......@@ -1446,10 +1446,13 @@ struct find_split_transpose
{
return;
}
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) {
return i->outputs().size() != 1;
}))
return;
std::vector<instruction_ref> vec_trans(split_outputs.size());
std::transform(split_outputs.begin(), split_outputs.end(), vec_trans.begin(), [](auto i) {
assert(i->outputs().size() == 1);
return i->outputs().front();
});
......
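The hunk above hardens find_split_transpose: the rewrite previously relied on an assert that every split output has exactly one consumer, and the new std::any_of check turns that precondition into an early return, so graphs that violate it (such as the multi-output slice exercised by the test added at the bottom of this commit) are simply left alone. A minimal sketch of the guard pattern, using an illustrative node type rather than MIGraphX's instruction class:

```cpp
#include <algorithm>
#include <vector>

// Illustrative stand-in for an IR instruction that tracks its consumers.
struct node
{
    std::vector<node*> outputs;
};

// The pass should only rewrite when every split output has exactly one
// consumer; otherwise it leaves the graph untouched instead of asserting.
bool safe_to_rewrite(const std::vector<node*>& split_outputs)
{
    return std::none_of(split_outputs.begin(), split_outputs.end(), [](const node* n) {
        return n->outputs.size() != 1;
    });
}
```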
......@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess)
return false;
- return attr.memoryType == hipMemoryTypeDevice;
+ return attr.type == hipMemoryTypeDevice;
}
std::size_t get_available_gpu_memory()
......
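The is_device_ptr change above follows a HIP header rename: the pointer-attribute field previously spelled memoryType is exposed as type in newer HIP releases. A minimal standalone sketch of the same check, assuming a recent HIP runtime (the function name here is illustrative):

```cpp
#include <hip/hip_runtime.h>

// Reports whether HIP classifies the pointer as device memory; any lookup
// failure (e.g. a plain host pointer) is treated as "not a device pointer".
bool is_device_ptr_sketch(const void* ptr)
{
    hipPointerAttribute_t attr{};
    if(hipPointerGetAttributes(&attr, ptr) != hipSuccess)
        return false;
    return attr.type == hipMemoryTypeDevice;
}
```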
......@@ -647,8 +647,8 @@ struct mlir_program
void set_gpu_properties(const context& migraphx_ctx)
{
const auto& device = migraphx_ctx.get_current_device();
target_arch = device.get_device_name();
num_cu = device.get_cu_count();
}
std::pair<std::size_t, std::size_t> get_launch_params() const
......@@ -869,15 +869,22 @@ code_object_op compile_mlir(const context& migraphx_ctx,
adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl;
}
mlir_program mp;
mp.set_gpu_properties(migraphx_ctx);
mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front();
......
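The compile_mlir hunk above wraps both trace prints in a function-local static std::mutex so that concurrent compilations cannot interleave their MIGRAPHX_TRACE_MLIR output on std::cout. The same pattern in isolation, with an illustrative helper name and only the standard library assumed:

```cpp
#include <iostream>
#include <mutex>
#include <string>

// Serializes whole messages on std::cout so that callers running on
// different threads never see their trace output interleaved mid-message.
void trace_print(const std::string& text)
{
    static std::mutex trace_mutex;
    const std::lock_guard<std::mutex> lock(trace_mutex);
    std::cout << text << std::endl;
}
```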
......@@ -603,8 +603,8 @@ TEST_CASE(simplify_inner_broadcast_scalar)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1, 384}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 384}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
......@@ -630,8 +630,8 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {384, 768}});
auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
auto yb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {384, 768}}}), y);
auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
......@@ -3035,6 +3035,36 @@ void reorder_slice_trans_diff_perm()
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<1>);
TEST_CASE_REGISTER(reorder_slice_trans_diff_perm<4>);
TEST_CASE(reorder_slice_trans_multi_outputs)
{
migraphx::module m1;
{
auto s = migraphx::shape{migraphx::shape::float_type, {8, 128, 1920}};
auto input = m1.add_parameter("input", s);
std::vector<int64_t> perm = {0, 2, 1};
auto slc0 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {640}}}), input);
auto slc1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {640}}, {"ends", {1280}}}),
input);
auto slc2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1280}}, {"ends", {1920}}}),
input);
auto t0 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc0);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc1);
auto t2 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), slc2);
auto sum = m1.add_instruction(migraphx::make_op("add"), t0, t1);
auto dot = m1.add_instruction(migraphx::make_op("mul"), sum, t2);
auto slc_cont = m1.add_instruction(migraphx::make_op("contiguous"), slc1);
m1.add_return({slc_cont, dot});
};
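// slc1 has two consumers (a transpose and a contiguous), so find_split_transpose
// is expected to bail out and leave the module unchanged.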
auto m2 = m1;
run_pass(m1);
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(reorder_slice_ins_deps)
{
auto create_module = [] {
......