"testing/python/vscode:/vscode.git/clone" did not exist on "cb37bfef8f12e156ddffd3009f69c3b818cc05c7"
Commit 4fffcdd5 authored by Khalique Ahmed

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into rocblas_api_opt

parents d45bd3ba 48585bad
@@ -15,7 +15,7 @@ p = parse_onnx(input_file, options);
 ```
 ## Saving
-An instantiated migraphx::program object can then be serialized to MessagePack (.msgpack) format and saved so that it can be loaded for future uses.
+An instantiated migraphx::program object can then be serialized to MessagePack (.mxr) format and saved so that it can be loaded for future uses.
 A program can be saved with either of the following:
 ```
...
@@ -77,7 +77,7 @@ int main(int argc, char** argv)
     std::cout << "Saving program..." << std::endl;
     std::string output_file;
     output_file = save_arg == nullptr ? "out" : save_arg;
-    output_file.append(".msgpack");
+    output_file.append(".mxr");
     migraphx::file_options options;
     options.set_file_format("msgpack");
...
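A note on the API exercised by the two hunks above: the commit standardizes on the `.mxr` extension while the on-disk format remains MessagePack. A minimal save/load round trip through the public C++ API (assuming `parse_onnx`, `save`, `load`, and `file_options` from `migraphx/migraphx.hpp`, which the context lines reference; `model.onnx` is a placeholder):

```
#include <migraphx/migraphx.hpp>

int main()
{
    // Parse an ONNX model into a migraphx::program.
    migraphx::program p = migraphx::parse_onnx("model.onnx");

    // Serialize to MessagePack, written under the new .mxr extension.
    migraphx::file_options options;
    options.set_file_format("msgpack");
    migraphx::save(p, "model.mxr", options);

    // Later runs can load the saved program instead of re-parsing.
    migraphx::program p2 = migraphx::load("model.mxr", options);
    return 0;
}
```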
+# Modifications Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved
 # Copyright 2018 The Google AI Language Team Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -50,10 +50,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "if not os.path.exists(\"yolov4_fp16.msgpack\"):\n",
-    " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.msgpack\n",
-    "if not os.path.exists(\"yolov4.msgpack\"):\n",
-    " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.msgpack"
+    "if not os.path.exists(\"yolov4_fp16.mxr\"):\n",
+    " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --fp16ref --binary -o yolov4_fp16.mxr\n",
+    "if not os.path.exists(\"yolov4.mxr\"):\n",
+    " !/opt/rocm/bin/migraphx-driver compile ./utilities/yolov4.onnx --gpu --enable-offload-copy --binary -o yolov4.mxr"
    ]
   },
   {
@@ -115,8 +115,8 @@
    "outputs": [],
    "source": [
     "# Load serialized model (either single- or half-precision)\n",
-    "model = migraphx.load(\"yolov4.msgpack\", format=\"msgpack\")\n",
-    "#model = migraphx.load(\"yolov4_fp16.msgpack\", format=\"msgpack\")\n",
+    "model = migraphx.load(\"yolov4.mxr\", format=\"msgpack\")\n",
+    "#model = migraphx.load(\"yolov4_fp16.mxr\", format=\"msgpack\")\n",
     "\n",
     "# Get the name of the input parameter and convert image data to an MIGraphX argument\n",
     "input_name = next(iter(model.get_parameter_shapes()))\n",
@@ -192,4 +192,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
\ No newline at end of file
+}
@@ -88,6 +88,7 @@ struct cpp_generator_impl
     std::stringstream fs{};
     std::size_t function_count = 0;
     std::function<std::string(std::string)> fmap = nullptr;
+    std::function<std::string(shape)> fresult = nullptr;
     std::unordered_map<std::string, std::string> point_op_map = {};
 };
 cpp_generator::cpp_generator() : impl(std::make_unique<cpp_generator_impl>()) {}
@@ -104,6 +105,8 @@ cpp_generator::~cpp_generator() noexcept = default;
 void cpp_generator::fmap(const std::function<std::string(std::string)>& f) { impl->fmap = f; }
 
+void cpp_generator::fresult(const std::function<std::string(shape)>& f) { impl->fresult = f; }
+
 void cpp_generator::add_point_op(const std::string& op_name, const std::string& code)
 {
     impl->point_op_map[op_name] = code;
@@ -174,7 +177,12 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
                        ins->inputs().end(),
                        std::back_inserter(args),
                        [&](auto i) { return names.at(i); });
-        return this->generate_point_op(ins->get_operator(), args);
+        auto s = this->generate_point_op(ins->get_operator(), args);
+        if(impl->fresult)
+            return impl->fresult(ins->get_shape()) + '(' + s + ')';
+        else
+            return s;
     });
     return f;
 }
...
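The `fresult` hook added above lets a client post-process every pointwise expression the generator emits: the callback maps the instruction's result shape to the name of a wrapper function, which is then applied around the generated expression. A stripped-down sketch of that pattern, with `std::string` stand-ins for MIGraphX's `shape` and generator types:

```
#include <functional>
#include <iostream>
#include <string>

// Hypothetical generator mirroring cpp_generator::fresult: the hook
// returns a wrapper-function name for a given result type, and emit()
// applies it as a call around each generated expression.
struct generator
{
    std::function<std::string(const std::string&)> fresult = nullptr;

    std::string emit(const std::string& result_type, const std::string& expr) const
    {
        if(fresult)
            return fresult(result_type) + "(" + expr + ")";
        return expr;
    }
};

int main()
{
    generator g;
    std::cout << g.emit("float", "add(x, y)") << "\n"; // add(x, y)

    g.fresult = [](const std::string& t) { return "convert<" + t + ">"; };
    std::cout << g.emit("float", "add(x, y)") << "\n"; // convert<float>(add(x, y))
}
```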
@@ -32,7 +32,12 @@ void cse_range(module& p, Range&& r)
             continue;
         p.replace_instruction(ins, eq);
         processed_ins.emplace(ins);
-        auto outputs = eq->outputs();
+        std::vector<instruction_ref> outputs;
+        std::copy_if(eq->outputs().begin(),
+                     eq->outputs().end(),
+                     std::back_inserter(outputs),
+                     [&](auto x) { return p.has_instruction(x); });
         std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
             return std::distance(eq, x) < std::distance(eq, y);
         });
...
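The `copy_if` introduced above keeps only users that still belong to the module being processed, so the `std::distance`-based sort below it never touches references from other submodules. A self-contained sketch of the same filter-then-sort step, with `int` ids standing in for `instruction_ref`:

```
#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

int main()
{
    // Instructions known to this module (stand-in for module::has_instruction).
    std::set<int> module_instructions = {1, 2, 3, 5};
    // Users of the replaced instruction; 4 lives in another submodule.
    std::vector<int> users = {5, 4, 2};

    std::vector<int> outputs;
    std::copy_if(users.begin(),
                 users.end(),
                 std::back_inserter(outputs),
                 [&](int x) { return module_instructions.count(x) > 0; });

    // Only valid members are ordered and processed further.
    std::sort(outputs.begin(), outputs.end());
    for(int x : outputs)
        std::cout << x << " "; // 2 5
}
```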
@@ -68,6 +68,8 @@ struct cpp_generator
     void fmap(const std::function<std::string(std::string)>& f);
 
+    void fresult(const std::function<std::string(shape)>& f);
+
     void add_point_op(const std::string& op_name, const std::string& code);
     std::string generate_point_op(const operation& op, const std::vector<std::string>& args);
...
@@ -37,43 +37,49 @@ struct squeeze
     std::string name() const { return "squeeze"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
+        check_shapes{inputs, *this}.has(1);
         auto input_shape = inputs[0];
         auto type        = input_shape.type();
         auto old_lens    = input_shape.lens();
+        auto old_strides = input_shape.strides();
         if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
         {
             MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
         }
         std::vector<std::size_t> new_lens;
+        std::vector<std::size_t> new_strides;
         if(axes.empty())
         {
-            std::copy_if(old_lens.begin(),
-                         old_lens.end(),
-                         std::back_inserter(new_lens),
-                         [](auto len) { return len != 1; });
+            for(auto i : range(old_lens.size()))
+            {
+                if(old_lens[i] != 1)
+                {
+                    new_lens.push_back(old_lens[i]);
+                    new_strides.push_back(old_strides[i]);
+                }
+            }
         }
         else
         {
-            for(std::size_t i = 0; i < old_lens.size(); i++)
+            for(auto i : range(old_lens.size()))
             {
                 if(std::find(axes.begin(), axes.end(), i) == axes.end())
                 {
                     new_lens.push_back(old_lens[i]);
+                    new_strides.push_back(old_strides[i]);
                 }
             }
         }
         if(new_lens.empty())
         {
             return shape{type};
         }
         else
         {
-            return shape{type, new_lens};
+            return shape{type, new_lens, new_strides};
         }
     }
 
     argument compute(shape output_shape, std::vector<argument> args) const
     {
         return args[0].reshape(output_shape);
...
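Dropping the `.standard()` requirement and copying strides alongside lengths means `squeeze` now works directly on non-packed or transposed inputs: each surviving dimension keeps its original stride, so the result views the same memory without a copy. A standalone worked example of the stride-preserving loop:

```
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // A non-standard layout: dims of extent 1 carry arbitrary strides.
    std::vector<std::size_t> lens    = {1, 3, 1, 4};
    std::vector<std::size_t> strides = {12, 4, 4, 1};

    // Drop every extent-1 dimension and keep the matching stride.
    std::vector<std::size_t> new_lens, new_strides;
    for(std::size_t i = 0; i < lens.size(); i++)
    {
        if(lens[i] != 1)
        {
            new_lens.push_back(lens[i]);
            new_strides.push_back(strides[i]);
        }
    }

    for(std::size_t i = 0; i < new_lens.size(); i++)
        std::cout << new_lens[i] << ":" << new_strides[i] << " "; // 3:4 4:1
}
```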
@@ -78,8 +78,8 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
 template <class F>
 void par_for(std::size_t n, std::size_t min_grain, F f)
 {
-    const auto threadsize =
-        std::min<std::size_t>(std::thread::hardware_concurrency(), n / min_grain);
+    const auto threadsize = std::min<std::size_t>(std::thread::hardware_concurrency(),
+                                                  n / std::max<std::size_t>(1, min_grain));
     par_for_impl(n, threadsize, f);
 }
...
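The rewritten expression guards the grain divisor: passing `min_grain == 0` previously divided by zero, while `std::max<std::size_t>(1, min_grain)` clamps it and keeps the thread count well defined. The computation in isolation:

```
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <thread>

// Same guarded sizing as par_for above: cap at hardware concurrency,
// and treat a zero grain as a grain of one instead of dividing by zero.
std::size_t thread_count(std::size_t n, std::size_t min_grain)
{
    return std::min<std::size_t>(std::thread::hardware_concurrency(),
                                 n / std::max<std::size_t>(1, min_grain));
}

int main()
{
    std::cout << thread_count(1024, 0) << "\n";  // safe: grain treated as 1
    std::cout << thread_count(1024, 64) << "\n"; // at most 16 threads
}
```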
@@ -12,6 +12,8 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
+
 struct precompile_op
 {
     operation op = op::identity{};
@@ -70,6 +72,14 @@ struct compiled_result
     instruction_ref ins;
 };
 
+template <class F>
+void par_compile(std::size_t n, F f)
+{
+    if(n == 0)
+        return;
+    par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
+}
+
 void compile_ops::apply(module& m) const
 {
     auto compilers = make_compilers(pointwise_compiler{});
@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
         compiles.emplace_back([=]() -> compiled_result { return {c(*ctx, ins, preop), ins}; });
     }
     std::vector<compiled_result> results(compiles.size());
-    par_for(compiles.size(), 1, [&](auto i) { results[i] = compiles[i](); });
+    par_compile(compiles.size(), [&](auto i) { results[i] = compiles[i](); });
     for(const auto& cr : results)
     {
         m.replace_instruction(cr.ins, cr.op, cr.ins->inputs());
...
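`par_compile` appears to bound concurrent kernel compilations through the new environment variable: reading the call site, `value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n)` presumably yields the variable's value with `n` as the fallback, so a setting of k gives `par_for` a grain of `n / k`, i.e. roughly k workers, and leaving it unset allows full parallelism. A plain-C++ sketch of that sizing (the `read_env` helper is hypothetical):

```
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for value_of(ENV{}, fallback).
std::size_t read_env(const char* name, std::size_t fallback)
{
    const char* v = std::getenv(name);
    return v == nullptr ? fallback : std::stoul(v);
}

int main()
{
    std::size_t n = 32; // kernels to compile
    std::size_t k = read_env("MIGRAPHX_GPU_COMPILE_PARALLEL", n);
    // A grain of n / k hands the parallel loop about k chunks.
    std::size_t grain = n / std::max<std::size_t>(1, k);
    std::cout << "grain per worker: " << grain << "\n";
}
```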
@@ -70,6 +70,9 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
     g.add_point_op("less", "migraphx::abs(${0} < ${1})");
     g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
     g.add_point_op("not", "migraphx::abs(not ${0})");
+    // Add explicit conversions
+    g.fresult(
+        [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
     auto name =
         g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
     return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
...
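With the hook installed, every generated pointwise expression is wrapped in `migraphx::convert<T>` for the instruction's declared output type, so ops like `less` (which compute a 0/1 value) are cast explicitly instead of relying on implicit promotion. A tiny stand-in showing the effect (this `convert` is illustrative, not MIGraphX's device function):

```
#include <iostream>

// Illustrative wrapper: an explicit cast to the declared result type,
// mirroring what the generated migraphx::convert<T>(...) call does.
template <class To, class From>
constexpr To convert(From x)
{
    return static_cast<To>(x);
}

int main()
{
    // A comparison yields a bool-ish 0/1, but the instruction's output
    // type may be float; the wrapper makes that conversion explicit.
    float r = convert<float>(2 < 3);
    std::cout << r << "\n"; // 1
}
```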
@@ -587,6 +587,11 @@ struct miopen_fusion
         return pack(f(self.ops, "ops"));
     }
 
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+
     value compile(context& ctx, const shape&, std::vector<shape> inputs)
     {
         // Compensate for allocation
...
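Reporting `shapes.size() - 1` follows the convention used by the GPU ops in this codebase: the preallocated output buffer arrives as the last argument, so declaring the result an alias of that input lets the memory planner reuse the buffer rather than allocate a new one. A minimal sketch of the convention (types simplified):

```
#include <cstddef>
#include <iostream>
#include <vector>

// The op's result aliases its final argument (the output buffer), so
// output_alias reports that argument's index to the planner.
std::ptrdiff_t output_alias(const std::vector<std::size_t>& arg_sizes)
{
    return static_cast<std::ptrdiff_t>(arg_sizes.size()) - 1;
}

int main()
{
    std::vector<std::size_t> args = {64, 64, 64}; // two inputs + output buffer
    std::cout << output_alias(args) << "\n";      // 2
}
```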
@@ -6,15 +6,32 @@
 namespace migraphx {
 
+template <class T>
+struct remove_vec_impl
+{
+    using type = T;
+};
+
+template <class T, index_int N>
+struct remove_vec_impl<vec<T, N>>
+{
+    using type = T;
+};
+
+template <class T>
+using remove_vec = typename remove_vec_impl<T>::type;
+
 template <class T, class... Shapes>
 constexpr auto traverse_preload(Shapes... ss)
 {
     return [=](auto f, auto... g) {
         index_int offset = 0;
         auto each        = [&](auto x) {
+            using type          = remove_vec<typename decltype(x)::type>;
             constexpr auto s    = decltype(x.get_shape()){};
             constexpr auto size = _c<s.element_space()>;
-            if constexpr(not s.broadcasted() or (s.elements() - size) < 64)
+            if constexpr(not s.broadcasted() or (s.elements() - size) < 64 or
+                         not is_same<T, type>{})
                 return f(x, offset, false_type{});
             else
             {
@@ -78,23 +95,23 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
         invoke);
 }
 
-template <class T>
-struct remove_vec
-{
-    using type = T;
-};
-
-template <class T, index_int N>
-struct remove_vec<vec<T, N>>
-{
-    using type = T;
-};
+template <class T, class Shape>
+struct shape_type : Shape
+{
+    using type = T;
+};
+
+template <class T>
+constexpr auto make_shape_type(T)
+{
+    return shape_type<typename T::type, typename T::shape_type>{};
+}
 
 template <class T, class... Ts>
 __device__ auto preload(index idx, Ts... xs)
 {
-    using type          = typename remove_vec<T>::type;
-    constexpr auto size = decltype(compute_preload_size<type>(xs.get_shape()...)){};
+    using type          = remove_vec<T>;
+    constexpr auto size = decltype(compute_preload_size<type>(make_shape_type(xs)...)){};
     const index_int max_size = 512 * sizeof(type);
     return [=](auto f) {
         if constexpr(size > 0 and size < max_size)
...
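`remove_vec` is hoisted to the top of the header (and expressed via an `_impl` class so it can be a plain alias) because `traverse_preload` now needs the scalar element type to evaluate `is_same<T, type>{}` and skip preloading tensors of a different type. A standalone sketch of the trait, with an ordinary struct standing in for MIGraphX's `vec<T, N>`:

```
#include <type_traits>

// Stand-in for the GPU vector type (the real one is a builtin vector).
template <class T, int N>
struct vec
{
    T data[N];
};

// The primary template passes scalars through...
template <class T>
struct remove_vec_impl
{
    using type = T;
};

// ...and the partial specialization peels a vector to its element type.
template <class T, int N>
struct remove_vec_impl<vec<T, N>>
{
    using type = T;
};

template <class T>
using remove_vec = typename remove_vec_impl<T>::type;

static_assert(std::is_same<remove_vec<float>, float>::value, "scalars unchanged");
static_assert(std::is_same<remove_vec<vec<float, 4>>, float>::value, "vec peeled");

int main() {}
```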
@@ -9,7 +9,8 @@ namespace migraphx {
 template <class T, class Shape>
 struct tensor_view
 {
     using type = T;
+    using shape_type = Shape;
 
     constexpr Shape get_shape() const { return Shape{}; }
     constexpr index_int size() const { return get_shape().elements(); }
...
@@ -25,6 +25,16 @@ struct is_convertible : bool_constant<__is_convertible(From, To)>
 {
 };
 
+template <class T, class U>
+struct is_same : false_type
+{
+};
+
+template <class T>
+struct is_same<T, T> : true_type
+{
+};
+
 #define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
 
 } // namespace migraphx
...
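Device kernels cannot pull in `<type_traits>`, so the trait is reimplemented locally: the primary template derives from `false_type` and the `<T, T>` partial specialization from `true_type`. A self-contained usage sketch with minimal stand-ins for those bases:

```
// Minimal stand-ins for the kernel library's bool constants.
struct false_type
{
    constexpr operator bool() const { return false; }
};
struct true_type
{
    constexpr operator bool() const { return true; }
};

template <class T, class U>
struct is_same : false_type
{
};

template <class T>
struct is_same<T, T> : true_type
{
};

static_assert(is_same<int, int>{}, "identical types match");
static_assert(not is_same<int, float>{}, "distinct types do not");

int main() {}
```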
@@ -66,15 +66,18 @@ __device__ __host__ auto as_vec(T* x)
     return reinterpret_cast<vec<T, N>*>(x);
 }
 
+template <class T, index_int N>
+using safe_vec = vec<std::conditional_t<std::is_same<T, bool>{}, uint8_t, T>, N>;
+
 template <class... Ts>
 constexpr auto vec_transform(Ts... xs)
 {
     return [=](auto f) {
         if constexpr(is_any_vec<Ts...>())
         {
             using type          = decltype(f(vec_at(xs, 0)...));
             constexpr auto size = common_vec_size<Ts...>();
-            vec<type, size> result = {0};
+            safe_vec<type, size> result = {0};
             for(int i = 0; i < size; i++)
                 result[i] = f(vec_at(xs, i)...);
             return result;
...
@@ -50,14 +50,14 @@ constexpr auto shape_step(Shape s, Axis)
     });
 }
 
-// Bools can not be used as a vector type so convert it to int8
+// Bools can not be used as a vector type so convert it to uint8
 template <class T>
 __device__ __host__ T* remove_bool(T* x)
 {
     return x;
 }
 
-inline __device__ __host__ int8_t* remove_bool(bool* x) { return reinterpret_cast<int8_t*>(x); }
+inline __device__ __host__ uint8_t* remove_bool(bool* x) { return reinterpret_cast<uint8_t*>(x); }
 
 template <index_int N, class T, class Axis>
 __device__ __host__ auto as_vec(T x, Axis axis)
...
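`bool` is not a legal element type for the GPU's builtin vectors, so bool buffers are reinterpreted byte-wise before vectorization; switching from `int8_t` to `uint8_t` keeps the 0/1 payload unsigned, in line with the new `safe_vec` alias in the previous file. A host-side sketch of the `remove_bool` idea:

```
#include <cstdint>
#include <iostream>

// Generic pointers pass through untouched...
template <class T>
T* remove_bool(T* x)
{
    return x;
}

// ...but bool* is viewed as uint8_t*: same size, well-defined 0/1
// bytes, and a type that vector loads can handle.
inline std::uint8_t* remove_bool(bool* x) { return reinterpret_cast<std::uint8_t*>(x); }

int main()
{
    bool flags[4]       = {true, false, true, true};
    std::uint8_t* bytes = remove_bool(flags);
    for(int i = 0; i < 4; i++)
        std::cout << int(bytes[i]) << " "; // 1 0 1 1
}
```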
@@ -44,7 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_POINTWISE_FUSION)
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
 
 struct id_pass
 {
@@ -100,7 +100,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_reshapes{},
         propagate_constant{},
         dead_code_elimination{},
-        enable_pass(enabled(MIGRAPHX_ENABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
+        enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         mlir_conv{&ctx},
         lowering{&ctx, options.offload_copy},
...
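Renaming the variable from ENABLE to DISABLE flips pointwise fusion from opt-in to on-by-default, with the environment variable now serving as the escape hatch. A sketch of the inverted gating (this `enabled` helper is a stand-in for MIGraphX's):

```
#include <cstdlib>
#include <iostream>

// Stand-in: a flag counts as enabled when the variable is set at all.
bool enabled(const char* name) { return std::getenv(name) != nullptr; }

int main()
{
    // Before: run only if MIGRAPHX_ENABLE_POINTWISE_FUSION was set.
    // After: run unless MIGRAPHX_DISABLE_POINTWISE_FUSION is set.
    bool run_fusion = not enabled("MIGRAPHX_DISABLE_POINTWISE_FUSION");
    std::cout << (run_fusion ? "fuse_pointwise: on" : "fuse_pointwise: off") << "\n";
}
```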
@@ -4,7 +4,7 @@
 TEST_CASE(load_save_default)
 {
-    std::string filename = "migraphx_api_load_save.dat";
+    std::string filename = "migraphx_api_load_save.mxr";
     auto p1 = migraphx::parse_onnx("conv_relu_maxpool_test.onnx");
     auto s1 = p1.get_output_shapes();
...
@@ -6,6 +6,12 @@
 #include <test.hpp>
 
+void run_pass(migraphx::program& p)
+{
+    migraphx::run_passes(
+        p, {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}});
+}
+
 void run_pass(migraphx::module& m)
 {
     migraphx::run_passes(
@@ -142,4 +148,59 @@ TEST_CASE(cse_test_literal)
     EXPECT(m1 == m2);
 }
 
+TEST_CASE(cse_test_submodule)
+{
+    migraphx::shape si{migraphx::shape::int64_type};
+    migraphx::shape s{migraphx::shape::int64_type, {1}};
+    migraphx::shape sc{migraphx::shape::bool_type};
+    auto create_program = [&](bool remove_literal = false) {
+        migraphx::program p;
+        std::vector<bool> vc    = {true};
+        std::vector<int64_t> vd = {3};
+        auto* mm     = p.get_main_module();
+        auto in_cond = mm->add_parameter("ccond", sc);
+        auto in_val  = mm->add_parameter("val", s);
+        auto b0      = mm->add_literal(migraphx::literal(sc, vc));
+        auto b1      = b0;
+        if(not(remove_literal))
+            b1 = mm->add_literal(migraphx::literal(sc, vc));
+
+        auto* body1 = p.create_module("loop_module1");
+        body1->add_parameter("#loop_module_in_1", sc);
+        auto in_v1 = body1->add_parameter("#loop_module_in_2", s);
+        auto l1    = body1->add_literal(migraphx::literal(si, vd));
+        auto ad1   = body1->add_instruction(migraphx::make_op("add"), l1, l1);
+        auto val1  = body1->add_instruction(migraphx::make_op("add"), in_v1, ad1);
+        auto cond1 = body1->add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b0);
+        auto cond2 = body1->add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
+        body1->add_return({cond1, cond2, val1, val1});
+
+        auto* body2 = p.create_module("loop_module2");
+        body2->add_parameter("#loop_module_in_1", sc);
+        auto in_v2 = body2->add_parameter("#loop_module_in_2", s);
+        auto l2    = body2->add_literal(migraphx::literal(si, vd));
+        auto ad2   = body2->add_instruction(migraphx::make_op("add"), l2, l2);
+        auto val2  = body2->add_instruction(migraphx::make_op("add"), in_v2, ad2);
+        auto cond3 = body2->add_instruction(
+            migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), b1);
+        body2->add_return({cond3, val2, val2});
+
+        auto loop1 = mm->add_instruction(
+            migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body1});
+        auto loop2 = mm->add_instruction(
+            migraphx::make_op("loop", {{"max_iterations", 1}}), {in_cond, in_val}, {body2});
+        mm->add_return({loop1, loop2});
+
+        return p;
+    };
+
+    auto p = create_program();
+    run_pass(p);
+    EXPECT(p == create_program(true));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }