Unverified Commit 60d5fa1a authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Merge branch 'develop' into dyn_ref_multibroadcast

parents 41268947 1820198e
......@@ -83,9 +83,10 @@ struct miopen_convolution
inline shape compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, op}.has(4).standard();
check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, op}.max_ndims(5);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts(
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}});
return migraphx::compute_shape<Op>(op, conv_inputs);
}
......@@ -144,12 +145,9 @@ struct miopen_convolution
#endif
}
inline void set_conv_descriptor()
void set_conv_descriptor()
{
if(cd == nullptr)
{
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
cd = (op.name() == "deconvolution") ? make_deconv(op) : make_conv(op);
}
value compile(migraphx::context& ctx, const shape& output, const std::vector<shape>& input)
......@@ -239,7 +237,6 @@ struct miopen_convolution
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen " + op.name() + " : find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
status = miopenConvolutionForwardGetSolutionCount(ctx.get_stream().get_miopen(),
......
......@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
concat<${axis}>(y, xs...);
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(${concat_args})(${post}, y, xs...);
});
}
......@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(vec)},
{"axis", v.at("axis").to<std::string>()}});
auto src = interpolate_string(
concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
......
......@@ -58,7 +58,7 @@ __global__ void ${kernel}(${params})
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
std::vector<std::string> names() const { return {"pointwise", "contiguous", "layout"}; }
static std::size_t oversubscribe_if(bool b)
{
......@@ -91,12 +91,12 @@ struct pointwise_compiler : compiler<pointwise_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
if(op.name() == "contiguous")
if(contains({"layout", "contiguous"}, op.name()))
{
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", "[](auto x) { return x; }"}, {"kernel", "contiguous_kernel"}}));
{{"lambda", "[](auto x) { return x; }"}, {"kernel", op.name() + "_kernel"}}));
}
else
{
......
......@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
return Start{} * output_shape.strides[Axis];
});
constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s);
MIGRAPHX_ASSERT(offset < out.get_shape().element_space());
MIGRAPHX_ASSERT((s.element_space() + offset) <= out.get_shape().element_space());
return make_tensor_view(out.data() + offset, s);
}
template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs)
{
return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
}
template <index_int Axis, class Input>
......@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>;
}
template <index_int Axis, class Output, class... Inputs>
__device__ void concat(Output output, Inputs... inputs)
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
auto idx = make_index();
fold([&](auto start, auto input) {
auto y = concat_slice<Axis>(output, input, start);
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; });
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input) {
concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(),
[&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
}
} // namespace migraphx
......
......@@ -29,19 +29,14 @@
#include <migraphx/instruction_ref.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/deconvolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/if_op.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/compiler.hpp>
......@@ -109,9 +104,9 @@ struct miopen_apply
add_extend_op("scatter_none");
add_extend_op("topk");
add_convolution_op<op::convolution>("convolution");
add_convolution_op<op::deconvolution>("deconvolution");
add_convolution_op<op::quant_convolution>("quant_convolution");
add_convolution_op("convolution");
add_convolution_op("deconvolution");
add_convolution_op("quant_convolution");
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
add_if_op();
......@@ -238,34 +233,19 @@ struct miopen_apply
});
}
template <typename Op>
void add_convolution_op(const std::string& name)
{
apply_map.emplace(name, [=](instruction_ref ins) {
operation conv =
miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), int8_x4_format};
migraphx::context ctx = get_context();
size_t ws_bytes = 0;
auto compile_conv_with_format = [&](bool format) {
conv = miopen_convolution<Op>{any_cast<Op>(ins->get_operator()), format};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
ws_bytes = ws.get("workspace", 0);
};
try
{ // for the regular convolution and deconvolution, this try would always succeed
compile_conv_with_format(int8_x4_format);
}
catch(migraphx::exception&)
{
// In case no solver supports the default format, retry using the other format.
compile_conv_with_format(not int8_x4_format);
}
operation conv = make_op(
"gpu::" + name,
{{"op", ins->get_operator().to_value()}, {"int8_x4_format", int8_x4_format}});
auto output = insert_allocation(ins, ins->get_shape());
auto args = ins->inputs();
auto output = insert_allocation(ins, ins->get_shape());
auto workspace = insert_allocation(ins, shape{shape::int8_type, {ws_bytes}});
return mod->replace_instruction(ins, conv, args[0], args[1], workspace, output);
return mod->replace_instruction(ins,
make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
ins->inputs().at(0),
ins->inputs().at(1),
output);
});
}
......
......@@ -101,7 +101,10 @@ struct mlir_handle
mlir_handle(T p) : handle(ptr{p}) {}
T get() const { return handle.get().get(); }
T get() const
{
return handle.get().get(); // NOLINT(readability-redundant-smartptr-get)
}
T release() { return handle.release().get(); }
......
......@@ -35,6 +35,7 @@
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/preallocate_param.hpp>
......@@ -50,6 +51,7 @@
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
......@@ -70,6 +72,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
......@@ -120,6 +123,9 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
simplify_algebra{},
simplify_reshapes{},
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
dead_code_elimination{},
simplify_reshapes{},
simplify_algebra{},
prefuse_ops{},
dead_code_elimination{},
......@@ -136,8 +142,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
eliminate_concat{concat_gpu_optimization{}},
dead_code_elimination{},
compile_miopen{&gctx},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
adjust_allocation{gpu_allocation_model{}},
dead_code_elimination{},
fuse_ops{&ctx, options.fast_math},
dead_code_elimination{},
replace_allocate{gpu_allocation_model{}, options.offload_copy},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/layout_nhwc.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
void run_pass(migraphx::module& m)
{
migraphx::run_passes(m, {migraphx::layout_nhwc{}, migraphx::dead_code_elimination{}});
}
migraphx::operation layout(std::vector<int64_t> permutation = {0, 1, 2, 3})
{
return migraphx::make_op("layout", {{"permutation", permutation}});
}
migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruction_ref ins)
{
return m.add_instruction(layout({0, 2, 3, 1}), ins);
}
TEST_CASE(conv_relu)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
m1.add_instruction(migraphx::make_op("relu"), conv);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
m2.add_instruction(migraphx::make_op("relu"), conv_layout);
}
EXPECT(m1.sort() == m2.sort());
}
TEST_CASE(conv_add)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}});
auto w = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {16, 8, 3, 3}}));
auto y = m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto b = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m1.add_instruction(migraphx::make_op("add"), conv, b);
}
run_pass(m1);
migraphx::module m2;
{
auto x = add_layout_nhwc(
m2, m2.add_parameter("x", {migraphx::shape::float_type, {1, 8, 16, 16}}));
auto w = add_layout_nhwc(m2,
m2.add_literal(migraphx::generate_literal(
{migraphx::shape::float_type, {16, 8, 3, 3}})));
auto y = m2.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {16}}));
auto conv = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
x,
w);
auto conv_layout = m2.add_instruction(layout(), conv);
auto b = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", conv->get_shape().lens()}}),
y);
m2.add_instruction(migraphx::make_op("add"), conv_layout, b);
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -5755,6 +5755,92 @@ def split_test_default():
return ([node], [x], [y1, y2])
@onnx_test
def split_test_no_attribute():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4) * 75
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_no_attribute_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [300, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [75, 15])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [75, 15])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [75, 15])
y4 = helper.make_tensor_value_info('y4', TensorProto.FLOAT, [75, 15])
split = np.ones(4)
split_tensor = helper.make_tensor(name="split",
data_type=TensorProto.INT64,
dims=split.shape,
vals=split.astype(np.int64))
const_node = helper.make_node("Constant",
inputs=[],
outputs=['split'],
value=split_tensor)
node = onnx.helper.make_node(
'Split',
inputs=['x', 'split'],
outputs=['y1', 'y2', 'y3', 'y4'],
)
return ([const_node, node], [x], [y1, y2, y3, y4])
@onnx_test
def split_test_invalid_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[1, 1, 1])
return ([node], [x], [y1, y2, y3])
@onnx_test
def split_test_no_attribute_invalid_input_split():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
y1 = helper.make_tensor_value_info('y1', TensorProto.FLOAT, [10, 7])
y2 = helper.make_tensor_value_info('y2', TensorProto.FLOAT, [10, 4])
y3 = helper.make_tensor_value_info('y3', TensorProto.FLOAT, [10, 4])
node = onnx.helper.make_node('Split',
inputs=['x'],
outputs=['y1', 'y2', 'y3'],
axis=1,
split=[])
return ([node], [x], [y1, y2, y3])
@onnx_test
def sqrt_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
......
......@@ -5607,6 +5607,31 @@ TEST_CASE(split_test)
EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape si{migraphx::shape::int64_type, {4}, {1}};
std::vector<int> ind = {75, 75, 75, 75};
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {300, 15}});
mm->add_literal(migraphx::literal(si, ind));
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {75}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {75}}, {"ends", {150}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {150}}, {"ends", {225}}}), input);
auto r4 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {225}}, {"ends", {300}}}), input);
mm->add_return({r1, r2, r3, r4});
auto prog = migraphx::parse_onnx("split_test_no_attribute.onnx");
EXPECT(p == prog);
}
TEST_CASE(split_test_default)
{
migraphx::program p;
......@@ -5622,6 +5647,23 @@ TEST_CASE(split_test_default)
EXPECT(p == prog);
}
TEST_CASE(split_test_no_attribute_invalid_split)
{
EXPECT(
test::throws([&] { migraphx::parse_onnx("split_test_no_attribute_invalid_split.onnx"); }));
}
TEST_CASE(split_test_invalid_split)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("split_test_invalid_split.onnx"); }));
}
TEST_CASE(split_test_no_attribute_invalid_input_split)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("split_test_no_attribute_invalid_input_split.onnx"); }));
}
TEST_CASE(sqrt_test)
{
migraphx::program p;
......
split_test_invalid_split:
5
xy1y2y3"Split*
axis*
split@@@split_test_invalid_splitZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
split_test_no_attribute:
0split"Constant*
value*:KKKKBsplit
!
x
splity1y2y3y4"Splitsplit_test_no_attributeZ
x


b
y1

K
b
y2

K
b
y3

K
b
y4

K
B
\ No newline at end of file
+split_test_no_attribute_invalid_input_split:
/
xy1y2y3"Split*
axis*
split+split_test_no_attribute_invalid_input_splitZ
x


b
y1


b
y2


b
y3


B
\ No newline at end of file
%split_test_no_attribute_invalid_split:
0split"Constant*
value*:Bsplit
!
x
splity1y2y3y4"Split%split_test_no_attribute_invalid_splitZ
x


b
y1

K
b
y2

K
b
y3

K
b
y4

K
B
\ No newline at end of file
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_concat_broadcast_add : verify_program<test_concat_broadcast_add>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 2, 4}};
migraphx::shape s1{migraphx::shape::float_type, {1, 6, 4}};
migraphx::shape s2{migraphx::shape::float_type, {6, 1}};
auto x = mm->add_parameter("x", s0);
auto y = mm->add_parameter("y", s0);
auto z = mm->add_parameter("z", s0);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
auto b = mm->add_literal(migraphx::generate_literal(s2, 15));
auto bb =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), b);
mm->add_instruction(migraphx::make_op("add"), concat, bb);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_slice_concat_add : verify_program<test_slice_concat_add>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 24, 2, 2}};
migraphx::shape s1{migraphx::shape::float_type, {1, 8, 2, 2}};
auto x = mm->add_parameter("x", s0);
auto y = mm->add_parameter("y", s1);
auto z = mm->add_parameter("z", s0);
auto slice = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), x);
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), slice, y, y);
mm->add_instruction(migraphx::make_op("add"), concat, z);
return p;
}
};
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import argparse
import onnx
from onnx import version_converter
def parse_args():
parser = argparse.ArgumentParser(
description=
'MIGraphX Onnx Model Convertion. Use to convert the opset of the input model to MIGraphX\'s'
)
req_args = parser.add_argument_group(title='required arguments')
req_args.add_argument('--model',
type=str,
required=True,
help='path to onnx file')
req_args.add_argument('--output',
type=str,
required=True,
help='path to output onnx file')
req_args.add_argument('--opset',
type=int,
required=True,
help='The output opset')
req_args.add_argument('--infer_shapes',
action='store_true',
help='Infer shapes for output model')
parser.add_argument('--verbose',
action='store_true',
help='show verbose information (for debugging)')
args = parser.parse_args()
return args
def main():
args = parse_args()
model_path = args.model
out_model_path = args.output
target_opset = args.opset
verbose = args.verbose
infer_shapes = args.infer_shapes
original_model = onnx.load(model_path)
if verbose:
print(f"The model before conversion:\n{original_model}")
# A full list of supported adapters can be found here:
# https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
# Apply the version conversion on the original model
converted_model = version_converter.convert_version(
original_model, target_opset)
if infer_shapes:
converted_model = onnx.shape_inference.infer_shapes(converted_model)
if verbose:
print(f"The model after conversion:\n{converted_model}")
# Save the ONNX model
onnx.save(converted_model, out_model_path)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment